节约显存
训练节约显存
将计算
loss
的部分放入forward
的部分训练一个任务的时候会有一张卡为主节点,接受其他节点的数据,占用显存会多一些.当主节点显存占用是从节点的 2 倍的时候,可以采用一下方式将每张卡都用满
任务编号 使用卡编号 0 0,2,3 1 1,2,3 混合精度
1
2
3for epoch in range(epochs):
with autocast():
x0 = fun0(x)检查代码将不需要梯度下降的部分使用
1
2with torch.no_grad():
a = b + 1避免中间变量
1
2
3
4
5
6
7# 错误示例
x0 = fun0(x)
x1 = fun1(x0)
# 改善为以下
x = fun0(x)
x = fun1(x)清理空间
1
2
3
4
5torch.cuda.empty_cache()
optimizer.zero_grad()
import gc
gc.collect()
节约显存
http://home.ustc.edu.cn/~ustcxwy0271/2022/10/30/train-tricks-1/