Common Fixes for PyTorch GPU Out-of-Memory (OOM) Errors
torch.cuda.memory_allocated(device)
Log the allocated GPU memory at each step of the forward/backward pass; the step with the largest jump tells you where the memory is going.
for i, batch in enumerate(train_loader):
    print("1:", torch.cuda.memory_allocated(0))   # before forward
    outputs = model(**batch)
    print("2:", torch.cuda.memory_allocated(0))   # after forward: activations held for backward
    loss = outputs.loss
    print("3:", torch.cuda.memory_allocated(0))
    loss.backward()
    print("4:", torch.cuda.memory_allocated(0))   # after backward: activations freed, gradients allocated
    optimizer.step()
    optimizer.zero_grad()
    print("5:", torch.cuda.memory_allocated(0))   # after the optimizer step
Accumulating the computation graph
Many tensors in PyTorch carry a reference to the computation graph that produced them, so it is easy to let the graph accumulate by accident.
For example, appending the loss tensor to a list inside a for loop, or summing losses across iterations, keeps every previous iteration's graph (and all of its intermediate tensors) alive; memory grows each step until the GPU runs out. A minimal sketch of the failure and the fix follows.
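The sketch below assumes a generic model, optimizer, and train_loader, and a hypothetical compute_loss helper; the only difference between the buggy and fixed versions is what gets appended.

losses = []
for batch in train_loader:
    loss = compute_loss(model, batch)   # hypothetical loss computation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # BUG: `loss` still references this step's entire computation graph,
    # so appending the tensor keeps every step's graph alive.
    # losses.append(loss)

    # FIX: .item() (or .detach()) drops the graph reference.
    losses.append(loss.item())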
with torch.no_grad()
It prevents gradients from being computed (no graph is built inside the block), but it does not release tensors that are already attached to an accumulated graph; only detach() or item() breaks that reference.
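A minimal sketch of the difference: a tensor created inside torch.no_grad() has no graph, but a tensor created outside still drags its graph along until you detach it.

import torch

x = torch.randn(4, 4, requires_grad=True, device="cuda")
y = (x * 2).sum()                # y is attached to a computation graph

with torch.no_grad():
    z = x * 3                    # built inside no_grad: z.requires_grad is False
# y still holds its graph here; no_grad() freed nothing

history = []
history.append(y)                # keeps y's whole graph alive
history.append(y.detach())       # tensor only; the graph can be freed
history.append(y.item())         # plain Python float, no GPU tensor kept at all

The customized Trainer below is a concrete case of this kind of accidental retention.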
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Sampler
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding

import swanlab
from swanlab.integration.transformers import SwanLabCallback
swanlab_callback = SwanLabCallback(
    project="map",
    experiment_name=config['name'],
)

callbacks = []
if ENV == "SERVER":
    callbacks.append(swanlab_callback)
class SupConTrainer(Trainer):
    def __init__(self, contrastive_weight=0.1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.contrastive_weight = contrastive_weight

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # ⚡ hidden states are needed during training
        # print("1:", torch.cuda.memory_allocated(0))

        # standard forward pass
        outputs = model(**inputs, output_hidden_states=True)
        # print("outputs:", outputs)
        # print("2:", torch.cuda.memory_allocated(0))
        logits = outputs.logits
        labels = inputs["labels"]

        # -------- 1. cross-entropy loss --------
        ce_loss = F.cross_entropy(logits, labels)
        # print("3:", torch.cuda.memory_allocated(0))

        # -------- 2. contrastive loss --------
        # take the last-layer hidden states (batch_size, hidden_dim)
        hidden_states = outputs.hidden_states[-1][:, 0, :]  # [CLS] vector
        # L2-normalize
        hidden_states = F.normalize(hidden_states, dim=-1)
        # print("4:", torch.cuda.memory_allocated(0))

        # similarity matrix (batch, batch)
        similarity_matrix = torch.matmul(hidden_states, hidden_states.T)
        # print("5:", torch.cuda.memory_allocated(0))

        # only same-class samples count as positives
        mask = labels.unsqueeze(0) == labels.unsqueeze(1)  # (batch, batch)

        # InfoNCE / SupCon loss
        logits_contrastive = similarity_matrix / 0.1  # temperature 0.1
        contrastive_loss = -torch.log_softmax(logits_contrastive, dim=1)[mask].mean()
        # print("6:", torch.cuda.memory_allocated(0))

        # -------- 3. total loss --------
        loss = ce_loss + self.contrastive_weight * contrastive_loss
        # print("7:", torch.cuda.memory_allocated(0))

        # else:
        #     # ⚡ eval only needs logits; avoid returning hidden_states
        #     outputs = model(**inputs, output_hidden_states=False)
        #     logits = outputs.logits
        #     labels = inputs["labels"]
        #     loss = self.ce_loss(logits, labels)

        # rebuild outputs so hidden_states is NOT handed back to the Trainer
        outputs = {
            "loss": outputs['loss'],
            "logits": logits,
        }
        return (loss, outputs) if return_outputs else loss
trainer = SupConTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_map3,
    # data_collator=data_collator,
    callbacks=callbacks,
    contrastive_weight=config['contrastive_weight'],  # weight of the contrastive loss
)
trainer.train()
The code above is a customized (heavily modified) Trainer. If the outputs dict is not rebuilt, the hidden_states are kept in the returned outputs and passed back outside compute_loss, where they keep accumulating batch after batch until the GPU runs out of memory, as the sketch below shows.
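For contrast, a minimal sketch of the leaky return versus the slimmed one (the `slim` name is made up here); the only difference is whether the object handed back to the Trainer still references outputs.hidden_states.

    # Inside compute_loss, after `loss` has been computed:

    # Leaky: `outputs` still carries hidden_states (one tensor of shape
    # (batch, seq_len, hidden_dim) per layer), so whoever stores the
    # returned outputs also stores all of them.
    # return (loss, outputs) if return_outputs else loss

    # Slim: hand back only what is needed downstream; the hidden states
    # can be freed as soon as this method returns.
    slim = {"loss": outputs['loss'], "logits": outputs.logits}
    return (loss, slim) if return_outputs else loss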