节约显存

训练节约显存

  • 将计算 loss 的部分放入forward 的部分

  • 训练一个任务的时候会有一张卡为主节点,接受其他节点的数据,占用显存会多一些.当主节点显存占用是从节点的 2 倍的时候,可以采用一下方式将每张卡都用满

    任务编号 使用卡编号
    0 0,2,3
    1 1,2,3
  • 混合精度

    1
    2
    3
    for epoch in range(epochs):
    with autocast():
    x0 = fun0(x)
  • 检查代码将不需要梯度下降的部分使用

    1
    2
    with torch.no_grad():
    a = b + 1
  • 避免中间变量

    1
    2
    3
    4
    5
    6
    7
    # 错误示例
    x0 = fun0(x)
    x1 = fun1(x0)

    # 改善为以下
    x = fun0(x)
    x = fun1(x)
  • 清理空间

    1
    2
    3
    4
    5
    torch.cuda.empty_cache()
    optimizer.zero_grad()

    import gc
    gc.collect()

节约显存
http://home.ustc.edu.cn/~ustcxwy0271/2022/10/30/train-tricks-1/
作者
Xu Weiye
发布于
2022年10月30日
许可协议