拿什麼拯救我的4G顯示卡:PyTorch節省視訊記憶體的策略總結


MLNLP 

機器學習演算法與自然語言處理 

)社群是國內外知名自然語言處理社群,受眾覆蓋國內外NLP碩博生、高校老師以及企業研究人員。


社群的願景 是促進國內外自然語言處理,機器學習學術界、產業界和廣大愛好者之間的交流,特別是初學者同學們的進步。

本文轉載自 | 極市平臺
作者 | OpenMMLab
來源丨https://zhuanlan.zhihu.com/p/430123077
0
『前言』
本文涉及到的 PyTorch 節省視訊記憶體的策略包括:
  • 混合精度訓練
  • 大 batch 訓練或者稱為梯度累加
  • gradient checkpointing 梯度檢查點
1
『混合精度訓練』
混合精度訓練全稱為 Automatic Mixed Precision,簡稱為 AMP,也就是我們常說的 FP16。在前系列解讀中已經詳細分析了 AMP 原理、原始碼實現以及 MMCV 中如何一行程式碼使用 AMP,具體連結見:
OpenMMLab:PyTorch 原始碼解讀之 torch.cuda.amp: 自動混合精度詳解
https://zhuanlan.zhihu.com/p/348554267
OpenMMLab:OpenMMLab 中混合精度訓練 AMP 的正確開啟方式
https://zhuanlan.zhihu.com/p/375224982
由於前面兩篇文章已經分析的非常詳細了,本文只簡要描述原理和具體說明用法。
考慮到訓練過程中梯度幅值大部分是非常小的,故訓練預設是 FP32 格式,如果能直接以 FP16 格式精度進行訓練,理論上可以減少一半的記憶體,達到加速訓練和採用更大 batch size 的目的,但是直接以 FP16 訓練會出現溢位問題,導致 NAN 或者引數更新失敗問題,而 AMP 的出現就是為了解決這個問題,其核心思想是 混合精度訓練+動態損失放大
  1. 維護一個 FP32 數值精度模型的副本
  2. 在每個 iteration
  • 複製並且轉換成 FP16 模型
  • 前向傳播(FP16 的模型引數),此時 weights, activations 都是 FP16
  • loss 乘 scale factor s
  • 反向傳播(FP16 的模型引數和引數梯度), 此時 gradients 也是 FP16
  • 引數梯度乘 1/s
  • 利用 FP16 的梯度更新 FP32 的模型引數
在 MMCV 中使用 AMP 分成兩種情況:
  • 在 OpenMMLab 上游庫例如 MMDetection 中使用 MMCV 的 AMP
  • 使用者只想簡單呼叫 MMCV 中的 AMP,而不依賴上游庫

(1) OpenMMLab 上游庫如何使用 MMCV 的 AMP

以 MMDectection 為例,用法非常簡單,只需要在配置中設定:

fp16 = dict(loss_scale=

512.

# 表示靜態 scale 

# 表示動態 scale 

fp16 = dict(loss_scale=

'dynamic'

)  



# 透過字典形式靈活開啟動態 scale 

fp16 = dict(loss_scale=dict(init_scale=

512.

,mode=

'dynamic'

))  


三種不同設定在大部分模型上效能都非常接近,如果不想設定 loss_scale,則可以簡單的採用 loss_scale='dynamic'

(2) 呼叫 MMCV 中的 AMP

直接呼叫 MMCV 中的 AMP,這通常意味著使用者可能在其他庫或者自己寫的程式碼庫中支援 AMP 功能。需要特別強調的是 PyTorch 官方僅僅在 1.6 版本及其之後版本中開始支援 AMP,而 MMCV 中的 AMP 支援 1.3 及其之後版本。如果你想在 1.3 或者 1.5 中使用 AMP,那麼使用 MMCV 是個非常不錯的選擇。
使用 MMCV 的 AMP 功能,只需要遵循以下幾個步驟即可:
  1. 將 auto_fp16 裝飾器應用到 model 的 forward 函式上
  2. 設定模型的 fp16_enabled 為 True 表示開啟 AMP 訓練,否則不生效
  3. 如果開啟了 AMP,需要同時配置對應的 FP16 最佳化器配置 Fp16OptimizerHook
  4. 在訓練的不同時刻,呼叫 Fp16OptimizerHook,如果你同時使用了 MMCV 中的 Runner 模組,那麼直接將第 3 步的引數輸入到 Runner 中即可
  5. (可選) 如果對應某些 OP 希望強制執行在 FP32 上,則可以在對應位置引入 force_fp32 裝飾器
# 1 作用到 forward 函式中
classExampleModule(nn.Module):

    @auto_fp16()
defforward(self, x, y):
return

 x, y



# 2 如果開啟 AMP,則需要加入開啟標誌

model.fp16_enabled = 

True

# 3 配置 Fp16OptimizerHook

optimizer_config = Fp16OptimizerHook(


    **cfg.optimizer_config, **fp16_cfg, distributed=distributed)



# 4 傳遞給 runner

runner.register_training_hooks(cfg.lr_config, optimizer_config,


                               cfg.checkpoint_config, cfg.log_config,


                               cfg.get(

'momentum_config'

None

))   



# 5 可選
classExampleModule(nn.Module):

    @auto_fp16()
defforward(self, x, y):

        features=self._forward(x, y)


        loss=self._loss(features,labels)


return

 loss



def_forward(self, x, y):
pass


    @force_fp32(apply_to=('features',))
def_loss(features,labels) :
pass

    注意 force_fp32 要生效,依然需要 fp16_enabled 為 True 才生效。
2
『大Batch訓練(梯度累加)』
大 Batch 訓練通常也稱為梯度累加策略,通常 PyTorch 一次迭代訓練流程為:

y_pred = model(xx)

loss = loss_fn(y_pred, y)

loss.backward()

optimizer.step() 

optimizer.zero_grad()

而梯度累加策略下常見的一次迭代訓練流程為:

y_pred = model(xx)

loss = loss_fn(y_pred, y)
loss = loss / cumulative_iters

loss.backward()

if

 current_iter % cumulative_iters==

0

    optimizer.step() 

    optimizer.zero_grad()

其核心思想就是對前幾次梯度進行累加,然後再統一進行引數更新,從而變相實現大 batch size 功能。需要注意的是如果模型中包括 BN 等考慮 batch 資訊的層,那麼效能可能會有輕微的差距。
細節可以參考:
https://github.com/open-mmlab/mmcv/pull/1221
在 MMCV 中已經實現了梯度累加功能,其核心程式碼位於 mmcv/runner/hooks/optimizer.py
GradientCumulativeOptimizerHook 中,和 AMP 實現一樣是採用 Hook 實現的。使用方法和 AMP 類似,只需要將第一節中的 Fp16OptimizerHook 替換為 GradientCumulativeOptimizerHook 或者 GradientCumulativeFp16OptimizerHook 即可。其核心實現如下所示:
@HOOKS.register_module()
classGradientCumulativeOptimizerHook(OptimizerHook):
def__init__(self, cumulative_iters=1, **kwargs):

        self.cumulative_iters = cumulative_iters


        self.divisible_iters = 

0# 剩餘的可以被 cumulative_iters 整除的訓練迭代次數

        self.remainder_iters = 

0# 剩餘累加次數

        self.initialized = 

False

defafter_train_iter(self, runner):
# 只需要執行一次即可
ifnot

 self.initialized:


            self._init(runner)



if

 runner.iter < self.divisible_iters:


            loss_factor = self.cumulative_iters


else

:


            loss_factor = self.remainder_iters



        loss = runner.outputs[

'loss'

]


        loss = loss / loss_factor


        loss.backward()



if

 (self.every_n_iters(runner, self.cumulative_iters)


or

 self.is_last_iter(runner)):



            runner.optimizer.step()


            runner.optimizer.zero_grad()    




def_init(self, runner):

        residual_iters = runner.max_iters - runner.iter



        self.divisible_iters = (


            residual_iters // self.cumulative_iters * self.cumulative_iters)


        self.remainder_iters = residual_iters - self.divisible_iters



        self.initialized = 

True

需要明白 divisible_iters 和 remainder_iters 的含義:

(1) 從頭訓練

此時在開始訓練時 iter=0,一共迭代 max_iters=102 次,梯度累加次數是 4,由於 102 無法被 4 整除,也就是最後的 102-(102 // 4)*4=2 個迭代是額外需要考慮的,在最後 2 個訓練迭代中 loss_factor 不能除以 4,而是 2,這樣才是最合理的做法。其中 remainder_iters=2,divisible_iters=100,residual_iters=102。

(2) resume 訓練

假設在梯度累加的中途退出,然後進行 resume 訓練,此時 iter 不是 0,由於最佳化器物件需要重新初始化,為了保證剩餘的不能被累加次數的訓練迭代次數能夠正常計算,需要重新計算 residual_iters。
3
『梯度檢查點』
梯度檢查點是一種用訓練時間換取視訊記憶體的辦法,其核心原理是在反向傳播時重新計算神經網路的中間啟用值而不用在前向時儲存,torch.utils.checkpoint 包中已經實現了對應功能。簡要實現過程是:在前向階段傳遞到 checkpoint 中的 forward 函式會以 _torch.no_grad_模式執行,並且僅僅儲存輸入引數和 forward 函式,在反向階段重新計算其 forward 輸出值。
具體用法非常簡單,以 ResNet 的 BasicBlock 為例:
defforward(self, x):
def_inner_forward(x):

        identity = x


        out = self.conv1(x)


        out = self.norm1(out)


        out = self.relu(out)


        out = self.conv2(out)


        out = self.norm2(out)


if

 self.downsample 

isnotNone

:


            identity = self.downsample(x)


        out += identity


return

 out



# x.requires_grad 這個判斷很有必要
if

 self.with_cp 

and

 x.requires_grad:


        out = cp.checkpoint(_inner_forward, x)


else

:


        out = _inner_forward(x)


    out = self.relu(out)


return

 out


self.with_cp 為 True,表示要開啟梯度檢查點功能。
checkpoint 在用法上面需要注意以下幾點:
  1. 模型的第一層不能用 checkpoint 或者說 forward 輸入中不能所有輸入的 requires_grad 屬性都是 False,因為其內部實現是依靠輸入的 requires_grad 屬性來判斷輸出返回是否需要梯度,而通常模型第一層輸入是 image tensor,其 requires_grad 通常是 False。一旦你第一層用了 checkpoint,那麼意味著這個 forward 函式不會有任何梯度,也就是說不會進行任何引數更新,沒有任何使用的必要,具體見 https://discuss.pytorch.org/t/use-of-torch-utils-checkpoint-checkpoint-causes-simple-model-to-diverge/116271。如果第一層用了 checkpoint, PyTorch 會列印 None of the inputs have requires_grad=True. Gradients will be Non 警告
  2. 對於 dropout 這種 forward 存在隨機性的層,需要保證 preserve_rng_state 為 True (預設就是 True,所以不用擔心),一旦標誌位設定為 True,在 forward 會儲存 RNG 狀態,然後在反向傳播的時候讀取該 RNG,保證兩次 forward 輸出一致。如果你確定不需要儲存 RNG,則可以設定 preserve_rng_state 為 False,省掉一些不必要的執行邏輯
  3. 其他注意事項,可以參考官方文件 https://pytorch.org/docs/stable/checkpoint.html#
其核心實現如下所示:
classCheckpointFunction(torch.autograd.Function):

    @staticmethod
defforward(ctx, run_function, preserve_rng_state, *args):
# 檢查輸入引數是否需要梯度

        check_backward_validity(args)


# 儲存必要的狀態

        ctx.run_function = run_function


        ctx.save_for_backward(*args)


with

 torch.no_grad():


# 以 no_grad 模型執行一遍

            outputs = run_function(*args)


return

 outputs



    @staticmethod
defbackward(ctx, *args):
# 讀取輸入引數

        inputs = ctx.saved_tensors


# Stash the surrounding rng state, and mimic the state that was
# present at this time during forward.  Restore the surrounding state
# when we're done.

        rng_devices = []


with

 torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):


# detach 掉當前不需要考慮的節點

            detached_inputs = detach_variable(inputs)


# 重新執行一遍
with

 torch.enable_grad():


                outputs = ctx.run_function(*detached_inputs)



if

 isinstance(outputs, torch.Tensor):


            outputs = (outputs,)


# 計算該子圖梯度

        torch.autograd.backward(outputs, args)


        grads = tuple(inp.grad 

if

 isinstance(inp, torch.Tensor) 

else

 inp


for

 inp 

in

 detached_inputs)


return

 (

None

None

) + grads


4
『實驗驗證』
為了驗證上述策略是否真的能夠省視訊記憶體,採用 mmdetection 庫進行驗證,基本環境如下:

顯示卡: GeForce GTX 1660

PyTorch: 1.7.1

CUDA Runtime 10.1

MMCV: 1.3.16

MMDetection: 2.17.0

(1) base

  • 資料集:pascal voc
  • 演算法是 retinanet,對應配置檔案為 retinanet_r50_fpn_1x_voc0712.py
  • 為了防止 lr 過大導致訓練出現 nan,需要將 lr 設定為 0.01/8=0.00125
  • bs 設定為 2

(2) 混合精度 AMP

在 base 配置基礎上新增如下配置即可:

fp16 = dict(loss_scale=512.)

(3) 梯度累加

在 base 配置基礎上替換 optimizer_config 為如下:
# 累加2次

optimizer_config = dict(

type

=

'GradientCumulativeOptimizerHook'

, cumulative_iters=2)

(4) 梯度檢查點

在 base 配置基礎上在 backbone 部分開啟 with_cp 標誌即可:

model = dict(backbone=dict(with_cp=True),

             bbox_head=dict(num_classes=20))

每個實驗總共迭代 1300 次,統計佔用視訊記憶體、訓練總時長。
配置 視訊記憶體佔用(MB) 訓練時長
base 2900 7 分 45 秒
混合精度 AMP 2243 36 分
梯度累加 3177 7 分 32 秒
梯度檢查點 2590 8 分 37 秒
  1. 對比 base 和 AMP 可以發現,由於實驗顯示卡是不支援 AMP 的,故只能節省視訊記憶體,速度會特別慢,如果本身顯示卡支援 AMP 則可以實現在節省視訊記憶體的同時提升訓練速度
  2. 對比 base 和梯度累加可以發現,在相同 bs 情況下,梯度累加 2 次相當於 bs 擴大一倍,但是視訊記憶體增加不多。如果將 bs 縮小一倍,則可以實現在相同 bs 情況下節省大概一倍視訊記憶體
  3. 對比 base 和梯度檢查點可以發現,可以節省一定的視訊記憶體,但是訓練時長會增加一些
從上面簡單實驗可以發現,AMP、梯度累加和梯度檢查點確實可以在不同程度減少視訊記憶體,而且這三個策略是正交的,可以同時使用。
5
『總結』
本文簡要描述了三個在 MMCV 中整合且可以透過配置一行開啟的節省視訊記憶體策略,這三個策略比較常用也比較成熟。隨著模型規模的不斷增長,也出現了很多新的策略,例如模型引數壓縮、動態視訊記憶體最佳化、使用 CPU 記憶體暫存策略以及分散式情況下 PyTorch 1.10 最新支援的 ZeroRedundancyOptimizer 等等。
快速連結直達 MMCV 演算法庫,歡迎大家 Star:
https://github.com/open-mmlab/mmcv
技術交流群邀請函
△長按新增小助手
掃描二維碼新增小助手微信
請備註:姓名-學校/公司-研究方向
(如:小張-哈工大-對話系統)
即可申請加入自然語言處理/Pytorch等技術交流群

關於我們

MLNLP社群  機器學習演算法與自然語言處理 ) 是由國內外自然語言處理學者聯合構建的民間學術社群,目前已經發展為國內外知名自然語言處理社群,旗下包括  萬人頂會交流群、AI臻選匯、AI英才匯  以及  AI學術匯  等知名品牌,旨在促進機器學習,自然語言處理學術界、產業界和廣大愛好者之間的進步。
社群可以為相關從業者的深造、就業及研究等方面提供開放交流平臺。歡迎大家關注和加入我們。

相關文章