pytorch_lightning模型訓練加速技巧與漲點技巧


MLNLP 

機器學習演算法與自然語言處理 

)社群是國內外知名自然語言處理社群,受眾覆蓋國內外NLP碩博生、高校老師以及企業研究人員。


社群的願景 是促進國內外自然語言處理,機器學習學術界、產業界和廣大愛好者之間的交流,特別是初學者同學們的進步。

轉載自 | 演算法美食屋
作者 | 梁雲1991
pytorch-lightning 是建立在pytorch之上的高層次模型介面。
pytorch-lightning 之於 pytorch,就如同keras之於 tensorflow.
pytorch-lightning 有以下一些引人注目的功能:
  • 可以不必編寫自定義迴圈,只要指定loss計算方法即可。
  • 可以透過callbacks非常方便地新增CheckPoint引數儲存、early_stopping 等功能。
  • 可以非常方便地在單CPU、多CPU、單GPU、多GPU乃至多TPU上訓練模型。
  • 可以透過呼叫torchmetrics庫,非常方便地新增Accuracy,AUC,Precision等各種常用評估指標。
  • 可以非常方便地實施多批次梯度累加、半精度混合精度訓練、最大batch_size自動搜尋等技巧,加快訓練過程。
  • 可以非常方便地使用SWA(隨機引數平均)、CyclicLR(學習率週期性排程策略)與auto_lr_find(最優學習率發現)等技巧 實現模型漲點。
一般按照如下方式 安裝和 引入 pytorch-lightning 庫。
#安裝

pip install pytorch-lightning

#引入
import

 pytorch_lightning 

as

 pl 

顧名思義,它可以幫助我們漂亮(pl)地進行深度學習研究。😋😋
You do the research. Lightning will do everything else.⭐️⭐️
參考文件:
  • pl_docs: https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction.html
  • pl_template:https://github.com/PyTorchLightning/deep-learning-project-template
  • torchmetrics: https://torchmetrics.readthedocs.io/en/latest/pages/lightning.html
1
『pytorch-lightning的設計哲學』
pytorch-lightning 的核心設計哲學是將 深度學習專案中的 研究程式碼(定義模型) 和 工程程式碼 (訓練模型) 相互分離。
使用者只需專注於研究程式碼(pl.LightningModule)的實現,而工程程式碼藉助訓練工具類(pl.Trainer)統一實現。
更詳細地說,深度學習專案程式碼可以分成如下4部分:
  • 研究程式碼 (Research code),使用者繼承LightningModule實現。
  • 工程程式碼 (Engineering code),使用者無需關注透過呼叫Trainer實現。
  • 非必要程式碼 (Non-essential research code,logging, etc…),使用者透過呼叫Callbacks實現。
  • 資料 (Data),使用者透過torch.utils.data.DataLoader實現,也可以封裝成pl.LightningDataModule。
2
『pytorch-lightning的使用範例』
下面我們使用minist圖片分類問題為例,演示pytorch-lightning的最佳實踐。

1,準備資料

import

 torch 

from

 torch 

import

 nn 

from

 torchvision 

import

 transforms 

as

 T

from

 torchvision.datasets 

import

 MNIST

from

 torch.utils.data 

import

 DataLoader,random_split

import

 pytorch_lightning 

as

 pl 

from

 torchmetrics 

import

 Accuracy 

classMNISTDataModule(pl.LightningDataModule):
def__init__

(self, data_dir: str = 

"./minist/"

                 batch_size: int = 

32

,

                 num_workers: int =

4

)

:

        super().__init__()

        self.data_dir = data_dir

        self.batch_size = batch_size

        self.num_workers = num_workers

defsetup(self, stage = None):

        transform = T.Compose([T.ToTensor()])

        self.ds_test = MNIST(self.data_dir, train=

False

,transform=transform,download=

True

)

        self.ds_predict = MNIST(self.data_dir, train=

False

,transform=transform,download=

True

)

        ds_full = MNIST(self.data_dir, train=

True

,transform=transform,download=

True

)

        self.ds_train, self.ds_val = random_split(ds_full, [

55000

5000

])

deftrain_dataloader(self):
return

 DataLoader(self.ds_train, batch_size=self.batch_size,

                          shuffle=

True

, num_workers=self.num_workers,

                          pin_memory=

True

)

defval_dataloader(self):
return

 DataLoader(self.ds_val, batch_size=self.batch_size,

                          shuffle=

False

, num_workers=self.num_workers,

                          pin_memory=

True

)

deftest_dataloader(self):
return

 DataLoader(self.ds_test, batch_size=self.batch_size,

                          shuffle=

False

, num_workers=self.num_workers,

                          pin_memory=

True

)

defpredict_dataloader(self):
return

 DataLoader(self.ds_predict, batch_size=self.batch_size,

                          shuffle=

False

, num_workers=self.num_workers,

                          pin_memory=

True

)

data_mnist = MNISTDataModule()

data_mnist.setup()

for

 features,labels 

in

 data_mnist.train_dataloader():

    print(features.shape)

    print(labels.shape)

break

torch.Size([32, 1, 28, 28])

torch.Size([32])

2,定義模型

net = nn.Sequential(

    nn.Conv2d(in_channels=

1

,out_channels=

32

,kernel_size = 

3

),

    nn.MaxPool2d(kernel_size = 

2

,stride = 

2

),

    nn.Conv2d(in_channels=

32

,out_channels=

64

,kernel_size = 

5

),

    nn.MaxPool2d(kernel_size = 

2

,stride = 

2

),

    nn.Dropout2d(p = 

0.1

),

    nn.AdaptiveMaxPool2d((

1

,

1

)),

    nn.Flatten(),

    nn.Linear(

64

,

32

),

    nn.ReLU(),

    nn.Linear(

32

,

10

)

)

classModel(pl.LightningModule):

def__init__(self,net,learning_rate=1e-3):

        super().__init__()

        self.save_hyperparameters()

        self.net = net

        self.train_acc = Accuracy()

        self.val_acc = Accuracy()

        self.test_acc = Accuracy() 

defforward(self,x):

        x = self.net(x)

return

 x

#定義loss
deftraining_step(self, batch, batch_idx):

        x, y = batch

        preds = self(x)

        loss = nn.CrossEntropyLoss()(preds,y)

return

 {

"loss"

:loss,

"preds"

:preds.detach(),

"y"

:y.detach()}

#定義各種metrics
deftraining_step_end(self,outputs):

        train_acc = self.train_acc(outputs[

'preds'

], outputs[

'y'

]).item()    

        self.log(

"train_acc"

,train_acc,prog_bar=

True

)

return

 {

"loss"

:outputs[

"loss"

].mean()}

#定義optimizer,以及可選的lr_scheduler
defconfigure_optimizers(self):
return

 torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

defvalidation_step(self, batch, batch_idx):

        x, y = batch

        preds = self(x)

        loss = nn.CrossEntropyLoss()(preds,y)

return

 {

"loss"

:loss,

"preds"

:preds.detach(),

"y"

:y.detach()}

defvalidation_step_end(self,outputs):

        val_acc = self.val_acc(outputs[

'preds'

], outputs[

'y'

]).item()    

        self.log(

"val_loss"

,outputs[

"loss"

].mean(),on_epoch=

True

,on_step=

False

)

        self.log(

"val_acc"

,val_acc,prog_bar=

True

,on_epoch=

True

,on_step=

False

)

deftest_step(self, batch, batch_idx):

        x, y = batch

        preds = self(x)

        loss = nn.CrossEntropyLoss()(preds,y)

return

 {

"loss"

:loss,

"preds"

:preds.detach(),

"y"

:y.detach()}

deftest_step_end(self,outputs):

        test_acc = self.test_acc(outputs[

'preds'

], outputs[

'y'

]).item()    

        self.log(

"test_acc"

,test_acc,on_epoch=

True

,on_step=

False

)

        self.log(

"test_loss"

,outputs[

"loss"

].mean(),on_epoch=

True

,on_step=

False

)
model = Model(net)

#檢視模型大小

model_size = pl.utilities.memory.get_model_size_mb(model)

print(

"model_size = {} M \n"

.format(model_size))

model.example_input_array = [features]

summary = pl.utilities.model_summary.ModelSummary(model,max_depth=

-1

)

print(summary) 

model_size = 0.218447 M 
   | Name      | Type              | Params | In sizes         | Out sizes       

---------------------------------------------------------------------------------------

0  | net       | Sequential        | 54.0 K | [32, 1, 28, 28]  | [32, 10]        

1  | net.0     | Conv2d            | 320    | [32, 1, 28, 28]  | [32, 32, 26, 26]

2  | net.1     | MaxPool2d         | 0      | [32, 32, 26, 26] | [32, 32, 13, 13]

3  | net.2     | Conv2d            | 51.3 K | [32, 32, 13, 13] | [32, 64, 9, 9]  

4  | net.3     | MaxPool2d         | 0      | [32, 64, 9, 9]   | [32, 64, 4, 4]  

5  | net.4     | Dropout2d         | 0      | [32, 64, 4, 4]   | [32, 64, 4, 4]  

6  | net.5     | AdaptiveMaxPool2d | 0      | [32, 64, 4, 4]   | [32, 64, 1, 1]  

7  | net.6     | Flatten           | 0      | [32, 64, 1, 1]   | [32, 64]        

8  | net.7     | Linear            | 2.1 K  | [32, 64]         | [32, 32]        

9  | net.8     | ReLU              | 0      | [32, 32]         | [32, 32]        

10 | net.9     | Linear            | 330    | [32, 32]         | [32, 10]        

11 | train_acc | Accuracy          | 0      | ?                | ?               

12 | val_acc   | Accuracy          | 0      | ?                | ?               

13 | test_acc  | Accuracy          | 0      | ?                | ?               

---------------------------------------------------------------------------------------

54.0 K    Trainable params

0         Non-trainable params

54.0 K    Total params

0.216     Total estimated model params size (MB)

3,訓練模型

pl.seed_everything(

1234

)
ckpt_callback = pl.callbacks.ModelCheckpoint(

    monitor=

'val_loss'

,

    save_top_k=

1

,

    mode=

'min'

)

early_stopping = pl.callbacks.EarlyStopping(monitor = 

'val_loss'

,

               patience=

3

,

               mode = 

'min'

)

# gpus=0 則使用cpu訓練,gpus=1則使用1個gpu訓練,gpus=2則使用2個gpu訓練,gpus=-1則使用所有gpu訓練,
# gpus=[0,1]則指定使用0號和1號gpu訓練, gpus="0,1,2,3"則使用0,1,2,3號gpu訓練
# tpus=1 則使用1個tpu訓練

trainer = pl.Trainer(max_epochs=

20

,   

#gpus=0, #單CPU模式

     gpus=

0

#單GPU模式
#num_processes=4,strategy="ddp_find_unused_parameters_false", #多CPU(程序)模式
#gpus=[0,1,2,3],strategy="dp", #多GPU的DataParallel(速度提升效果一般)
#gpus=[0,1,2,3],strategy=“ddp_find_unused_parameters_false" #多GPU的DistributedDataParallel(速度提升效果好)

     callbacks = [ckpt_callback,early_stopping],

     profiler=

"simple"

#斷點續訓
#trainer = pl.Trainer(resume_from_checkpoint='./lightning_logs/version_31/checkpoints/epoch=02-val_loss=0.05.ckpt')

#訓練模型

trainer.fit(model,data_mnist)

Epoch 8: 100%

1876/1876 [01:44<00:00, 17.93it/s, loss=0.0603, v_num=0, train_acc=1.000, val_acc=0.985]

4,評估模型

result = trainer.test(model,data_mnist.train_dataloader(),ckpt_path=

'best'

)

--------------------------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9966545701026917, 

'test_loss'

: 0.010617421939969063}

--------------------------------------------------------------------------------

result = trainer.test(model,data_mnist.val_dataloader(),ckpt_path=

'best'

)

--------------------------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9865999817848206, 

'test_loss'

: 0.042671505361795425}

--------------------------------------------------------------------------------

result = trainer.test(model,data_mnist.test_dataloader(),ckpt_path=

'best'

)

--------------------------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.987500011920929, 

'test_loss'

: 0.047178059816360474}

--------------------------------------------------------------------------------

5,使用模型

data,label = next(iter(data_module.test_dataloader()))

model.eval()

prediction = model(data)

print(prediction)

tensor([[-13.0112,  -2.8257,  -1.8588,  -3.6137,  -0.3307,  -5.4953, -19.7282,

          15.9651,  -8.0379,  -2.2925],

        [ -6.0261,  -2.5480,  13.4140,  -5.5701, -10.2049,  -6.4469,  -3.7119,

          -6.0732,  -6.0826,  -7.7339],

          ...

        [-16.7028,  -4.9060,   0.4400,  24.4337, -12.8793,   1.5085, -17.9232,

          -3.0839,   0.5491,   1.9846],

        [ -5.0909,  10.1805,  -8.2528,  -9.2240,  -1.8044,  -4.0296,  -8.2297,

          -3.1828,  -5.9361,  -4.8410]], grad_fn=<AddmmBackward0>)

6,儲存模型

最優模型預設儲存在 trainer.checkpoint_callback.best_model_path 的目錄下,可以直接載入。

print(trainer.checkpoint_callback.best_model_path)

print(trainer.checkpoint_callback.best_model_score)

lightning_logs/version_10/checkpoints/epoch=8-step=15470.ckpt

tensor(0.0376, device=

'cuda:0'

)

model_clone = Model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

trainer_clone = pl.Trainer(max_epochs=

3

,gpus=

1

result = trainer_clone.test(model_clone,data_module.test_dataloader())

print(result)

--------------------------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9887999892234802, 

'test_loss'

: 0.03627564385533333}

--------------------------------------------------------------------------------

[{

'test_acc'

: 0.9887999892234802, 

'test_loss'

: 0.03627564385533333}]

3
『訓練加速技巧』
下面重點介紹pytorch_lightning 模型訓練加速的一些技巧。
  • 1,使用多程序讀取資料(num_workers=4)
  • 2,使用鎖業記憶體(pin_memory=True)
  • 3,使用加速器(gpus=4,strategy="ddp_find_unused_parameters_false")
  • 4,使用梯度累加(accumulate_grad_batches=6)
  • 5,使用半精度(precision=16,batch_size=2*batch_size)
  • 6,自動搜尋最大batch_size(auto_scale_batch_size='binsearch')
(注:過大的batch_size對模型學習是有害的。)
詳細原理,可以參考:
https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html
我們將訓練程式碼封裝成如下指令碼形式,方便後面測試使用。

%%writefile mnist_cnn.py

import

 torch 

from

 torch 

import

 nn 

from

 argparse 

import

 ArgumentParser

import

 torchvision 

from

 torchvision 

import

 transforms 

as

 T

from

 torchvision.datasets 

import

 MNIST

from

 torch.utils.data 

import

 DataLoader,random_split

import

 pytorch_lightning 

as

 pl

from

 torchmetrics 

import

 Accuracy

#================================================================================
# 一,準備資料
#================================================================================

classMNISTDataModule(pl.LightningDataModule):
def__init__

(self, data_dir: str = 

"./minist/"

                 batch_size: int = 

32

,

                 num_workers: int =

4

,

                 pin_memory:bool =True)

:

        super().__init__()

        self.data_dir = data_dir

        self.batch_size = batch_size

        self.num_workers = num_workers

        self.pin_memory = pin_memory

defsetup(self, stage = None):

        transform = T.Compose([T.ToTensor()])

        self.ds_test = MNIST(self.data_dir, download=

True

,train=

False

,transform=transform)

        self.ds_predict = MNIST(self.data_dir, download=

True

, train=

False

,transform=transform)

        ds_full = MNIST(self.data_dir, download=

True

, train=

True

,transform=transform)

        self.ds_train, self.ds_val = random_split(ds_full, [

55000

5000

])

deftrain_dataloader(self):
return

 DataLoader(self.ds_train, batch_size=self.batch_size,

                          shuffle=

True

, num_workers=self.num_workers,

                          pin_memory=self.pin_memory)

defval_dataloader(self):
return

 DataLoader(self.ds_val, batch_size=self.batch_size,

                          shuffle=

False

, num_workers=self.num_workers,

                          pin_memory=self.pin_memory)

deftest_dataloader(self):
return

 DataLoader(self.ds_test, batch_size=self.batch_size,

                          shuffle=

False

, num_workers=self.num_workers,

                          pin_memory=self.pin_memory)

defpredict_dataloader(self):
return

 DataLoader(self.ds_predict, batch_size=self.batch_size,

                          shuffle=

False

, num_workers=self.num_workers,

                          pin_memory=self.pin_memory)

    @staticmethod
defadd_dataset_args(parent_parser):

        parser = ArgumentParser(parents=[parent_parser], add_help=

False

)

        parser.add_argument(

'--batch_size'

, type=int, default=

32

)

        parser.add_argument(

'--num_workers'

, type=int, default=

4

)

        parser.add_argument(

'--pin_memory'

, type=bool, default=

True

)

return

 parser

#================================================================================
# 二,定義模型
#================================================================================

net = nn.Sequential(

    nn.Conv2d(in_channels=

1

,out_channels=

32

,kernel_size = 

3

),

    nn.MaxPool2d(kernel_size = 

2

,stride = 

2

),

    nn.Conv2d(in_channels=

32

,out_channels=

64

,kernel_size = 

5

),

    nn.MaxPool2d(kernel_size = 

2

,stride = 

2

),

    nn.Dropout2d(p = 

0.1

),

    nn.AdaptiveMaxPool2d((

1

,

1

)),

    nn.Flatten(),

    nn.Linear(

64

,

32

),

    nn.ReLU(),

    nn.Linear(

32

,

10

)

)

classModel(pl.LightningModule):

def__init__(self,net,learning_rate=1e-3):

        super().__init__()

        self.save_hyperparameters()

        self.net = net

        self.train_acc = Accuracy()

        self.val_acc = Accuracy()

        self.test_acc = Accuracy() 

defforward(self,x):

        x = self.net(x)

return

 x

#定義loss
deftraining_step(self, batch, batch_idx):

        x, y = batch

        preds = self(x)

        loss = nn.CrossEntropyLoss()(preds,y)

return

 {

"loss"

:loss,

"preds"

:preds.detach(),

"y"

:y.detach()}

#定義各種metrics
deftraining_step_end(self,outputs):

        train_acc = self.train_acc(outputs[

'preds'

], outputs[

'y'

]).item()    

        self.log(

"train_acc"

,train_acc,prog_bar=

True

)

return

 {

"loss"

:outputs[

"loss"

].mean()}

#定義optimizer,以及可選的lr_scheduler
defconfigure_optimizers(self):
return

 torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

defvalidation_step(self, batch, batch_idx):

        x, y = batch

        preds = self(x)

        loss = nn.CrossEntropyLoss()(preds,y)

return

 {

"loss"

:loss,

"preds"

:preds.detach(),

"y"

:y.detach()}

defvalidation_step_end(self,outputs):

        val_acc = self.val_acc(outputs[

'preds'

], outputs[

'y'

]).item()    

        self.log(

"val_loss"

,outputs[

"loss"

].mean(),on_epoch=

True

,on_step=

False

)

        self.log(

"val_acc"

,val_acc,prog_bar=

True

,on_epoch=

True

,on_step=

False

)

deftest_step(self, batch, batch_idx):

        x, y = batch

        preds = self(x)

        loss = nn.CrossEntropyLoss()(preds,y)

return

 {

"loss"

:loss,

"preds"

:preds.detach(),

"y"

:y.detach()}

deftest_step_end(self,outputs):

        test_acc = self.test_acc(outputs[

'preds'

], outputs[

'y'

]).item()    

        self.log(

"test_acc"

,test_acc,on_epoch=

True

,on_step=

False

)

        self.log(

"test_loss"

,outputs[

"loss"

].mean(),on_epoch=

True

,on_step=

False

)

    @staticmethod
defadd_model_args(parent_parser):

        parser = ArgumentParser(parents=[parent_parser], add_help=

False

)

        parser.add_argument(

'--learning_rate'

, type=float, default=

1e-3

)

return

 parser

#================================================================================
# 三,訓練模型
#================================================================================

defmain(hparams):

    pl.seed_everything(

1234

)
    data_mnist = MNISTDataModule(batch_size=hparams.batch_size,

                                 num_workers=hparams.num_workers)
    model = Model(net,learning_rate=hparams.learning_rate)
    ckpt_callback = pl.callbacks.ModelCheckpoint(

        monitor=

'val_loss'

,

        save_top_k=

1

,

        mode=

'min'

    )

    early_stopping = pl.callbacks.EarlyStopping(monitor = 

'val_loss'

,

                   patience=

3

,

                   mode = 

'min'

)
    trainer = pl.Trainer.from_argparse_args( 

        hparams,

        max_epochs=

10

,
        callbacks = [ckpt_callback,early_stopping]

    ) 

if

 hparams.auto_scale_batch_size 

isnotNone

:

#搜尋不發生OOM的最大batch_size

        max_batch_size = trainer.tuner.scale_batch_size(model,data_mnist,

                        mode=hparams.auto_scale_batch_size)

        data_mnist.batch_size = max_batch_size

#等價於
#trainer.tune(model,data_mnist)

#gpus=0, #單CPU模式
#gpus=1, #單GPU模式
#num_processes=4,strategy="ddp_find_unused_parameters_false", #多CPU(程序)模式
#gpus=4,strategy="dp", #多GPU(dp速度提升效果一般)
#gpus=4,strategy=“ddp_find_unused_parameters_false" #多GPU(ddp速度提升效果好)

    trainer.fit(model,data_mnist)

    result = trainer.test(model,data_mnist,ckpt_path=

'best'

)

if

 __name__ == 

"__main__"

:

    parser = ArgumentParser()

    parser = MNISTDataModule.add_dataset_args(parser)

    parser = Model.add_model_args(parser)

    parser = pl.Trainer.add_argparse_args(parser)

    hparams = parser.parse_args()

    main(hparams)

1,使用多程序讀取資料(num_workers=4)

使用多程序讀取資料,可以避免資料載入過程成為效能瓶頸。
  • 單程序讀取資料(num_workers=0, gpus=1): 1min 18s
  • 多程序讀取資料(num_workers=4, gpus=1): 59.7s

%%time

#單程序讀取資料(num_workers=0)

!python3 mnist_cnn.py --num_workers=

0

 --gpus=

1

------------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9857000112533569, 

'test_loss'

: 0.04885349050164223}

--------------------------------------------------------------------------------
CPU 

times

: user 4.67 s, sys: 2.14 s, total: 6.81 s

Wall time: 2min 50s

%%time

#多程序讀取資料(num_workers=4)

!python3 mnist_cnn.py --num_workers=

4

 --gpus=

1

---------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9764000177383423, 

'test_loss'

: 0.0820135846734047}

--------------------------------------------------------------------------------

Testing: 100%|███████████████████████████████| 313/313 [00:01<00:00, 163.40it/s]

CPU 

times

: user 1.56 s, sys: 647 ms, total: 2.21 s

Wall time: 59.7 s

2,使用鎖業記憶體(pin_memory=True)

鎖頁記憶體存放的內容在任何情況下都不會與主機的虛擬記憶體進行交換(注:虛擬記憶體就是硬碟)
因此鎖業記憶體比非鎖業記憶體讀寫效率更高,copy到GPU上也更快速。
當計算機的記憶體充足的時候,可以設定pin_memory=True。當系統卡住,或者交換記憶體使用過多的時候,設定pin_memory=False。
因為pin_memory與電腦硬體效能有關,pytorch開發者不能確保每一個煉丹玩家都有高階裝置,因此pin_memory預設為False。
  • 非鎖業記憶體儲存資料(pin_memory=False, gpus=1): 1min
  • 鎖業記憶體儲存資料(pin_memory=True, gpus=1): 59.5s

%%time

#非鎖業記憶體儲存資料(pin_memory=False)

!python3 mnist_cnn.py --pin_memory=

False

 --gpus=

1

----------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9812999963760376, 

'test_loss'

: 0.06231774762272835}

--------------------------------------------------------------------------------

Testing: 100%|███████████████████████████████| 313/313 [00:01<00:00, 171.69it/s]

CPU 

times

: user 1.59 s, sys: 619 ms, total: 2.21 s

Wall time: 1min

%%time

#鎖業記憶體儲存資料(pin_memory=True)

!python3 mnist_cnn.py --pin_memory=

True

 --gpus=

1

---------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9757999777793884, 

'test_loss'

: 0.08017424494028091}

--------------------------------------------------------------------------------

Testing: 100%|███████████████████████████████| 313/313 [00:01<00:00, 174.58it/s]

CPU 

times

: user 1.54 s, sys: 677 ms, total: 2.22 s

Wall time: 59.5 s

3,使用加速器(gpus=4,strategy="ddp_find_unused_parameters_false")

pl 可以很方便地應用單CPU、多CPU、單GPU、多GPU乃至多TPU上訓練模型。
以下幾種情況訓練耗時統計如下:
  • 單CPU: 2min 17s
  • 單GPU:  59.4 s
  • 4個GPU(dp模式): 1min
  • 4個GPU(ddp模式): 38.9 s
一般情況下,如果是單機多卡,建議使用 ddp模式,因為dp模式需要非常多的data和model傳輸,非常耗時。

%%time

#單CPU

!python3 mnist_cnn.py --gpus=

0

-----------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9790999889373779, 

'test_loss'

: 0.07223792374134064}

--------------------------------------------------------------------------------

Testing: 100%|████████████████████████████████| 313/313 [00:05<00:00, 55.95it/s]

CPU 

times

: user 2.67 s, sys: 740 ms, total: 3.41 s

Wall time: 2min 17s

%%time

#單GPU

!python3 mnist_cnn.py --gpus=

1

---------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9778000116348267, 

'test_loss'

: 0.06929327547550201}

--------------------------------------------------------------------------------

Testing: 100%|███████████████████████████████| 313/313 [00:01<00:00, 171.04it/s]

CPU 

times

: user 1.83 s, sys: 488 ms, total: 2.32 s

Wall time: 1min 3s

%%time

#多GPU,dp模式(為公平比較,batch_size=32*4)

!python3 mnist_cnn.py --gpus=

4

 --strategy=

"dp"

 --batch_size=

128

------------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9790999889373779, 

'test_loss'

: 0.06855566054582596}

--------------------------------------------------------------------------------

Testing: 100%|██████████████████████████████████| 79/79 [00:02<00:00, 38.55it/s]

CPU 

times

: user 1.2 s, sys: 553 ms, total: 1.75 s

Wall time: 1min

%%time

#多GPU,ddp模式

!python3 mnist_cnn.py --gpus=

4

 --strategy=

"ddp_find_unused_parameters_false"

---------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9732000231742859, 

'test_loss'

: 0.08606339246034622}

--------------------------------------------------------------------------------

Testing: 100%|██████████████████████████████████| 79/79 [00:00<00:00, 85.79it/s]

CPU 

times

: user 784 ms, sys: 387 ms, total: 1.17 s

Wall time: 38.9 s

4,使用梯度累加(accumulate_grad_batches=6)

梯度累加就是累加多個batch的梯度,然後用累加的梯度更新一次引數,使用梯度累加相當於增大batch_size.
由於更新引數的計算量略大於簡單梯度求和的計算量(對於大部分最佳化器而言),使用梯度累加會讓速度略有提升。
  • 4個GPU(ddp模式): 38.9 s
  • 4個GPU(ddp模式)+梯度累加: 36.9 s

%%time

#多GPU,ddp模式, 考慮梯度累加

!python3 mnist_cnn.py --accumulate_grad_batches=

6

 --gpus=

4

 --strategy=

"ddp_find_unused_parameters_false"

----------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9603000283241272, 

'test_loss'

: 0.1400066614151001}

--------------------------------------------------------------------------------

Testing: 100%|██████████████████████████████████| 79/79 [00:00<00:00, 89.10it/s]

CPU 

times

: user 749 ms, sys: 402 ms, total: 1.15 s

Wall time: 36.9 s

5,使用半精度(precision=16)

透過precision可以設定 double (64), float (32), bfloat16 ("bf16"), half (16) 精度的訓練。
預設是float(32) 標準精度,bfloat16 ("bf16")是混合精度。
如果選擇 half(16) 半精度,並同時增大batch_size為原來2倍, 通常訓練速度會提升3倍左右。

%%time 

#半精度

!python3 mnist_cnn.py --precision=

16

 --batch_size=

64

 --gpus=

1

6,自動搜尋最大batch_size(auto_scale_batch_size="power")

!python3 mnist_cnn.py --auto_scale_batch_size=

"power"

  --gpus=

1

4
『訓練漲分技巧』
pytorch_lightning 可以非常容易地支援以下訓練漲分技巧:
  • SWA(隨機引數平均): 呼叫pl.callbacks.stochastic_weight_avg.StochasticWeightAveraging實現。
  • CyclicLR(學習率週期性排程策略): 設定 lr_scheduler 為 torch.optim.lr_scheduler.CyclicLR實現。
  • auto_lr_find最優學習率發現: 設定 pl.Trainer(auto_lr_find = True)實現。
參考論文:
  • Cyclical Learning Rates for Training Neural Networks 【https://arxiv.org/pdf/1506.01186.pdf】
  • Averaging Weights Leads to Wider Optima and Better Generalization【https://arxiv.org/abs/1803.05407】
我們將程式碼整理成如下形式,以便後續測試使用。

%%writefile mnist_cnn.py

import

 torch 

from

 torch 

import

 nn 

from

 argparse 

import

 ArgumentParser

import

 numpy 

as

 np 

import

 torchvision 

from

 torchvision 

import

 transforms 

as

 T

from

 torchvision.datasets 

import

 MNIST

from

 torch.utils.data 

import

 DataLoader,random_split

import

 pytorch_lightning 

as

 pl

from

 torchmetrics 

import

 Accuracy

#================================================================================
# 一,準備資料
#================================================================================

classMNISTDataModule(pl.LightningDataModule):
def__init__

(self, data_dir: str = 

"./minist/"

                 batch_size: int = 

32

,

                 num_workers: int =

4

,

                 pin_memory:bool =True)

:

        super().__init__()

        self.data_dir = data_dir

        self.batch_size = batch_size

        self.num_workers = num_workers

        self.pin_memory = pin_memory

defsetup(self, stage = None):

        transform = T.Compose([T.ToTensor()])

        self.ds_test = MNIST(self.data_dir, download=

True

,train=

False

,transform=transform)

        self.ds_predict = MNIST(self.data_dir, download=

True

, train=

False

,transform=transform)

        ds_full = MNIST(self.data_dir, download=

True

, train=

True

,transform=transform)

        ds_train, self.ds_val = random_split(ds_full, [

59000

1000

])

#為加速訓練,隨機取10000個

        indices = np.arange(

59000

)

        np.random.shuffle(indices)

        self.ds_train = torch.utils.data.dataset.Subset(

            ds_train,indices = indices[:

3000

]) 

deftrain_dataloader(self):
return

 DataLoader(self.ds_train, batch_size=self.batch_size,

                          shuffle=

True

, num_workers=self.num_workers,

                          pin_memory=self.pin_memory)

defval_dataloader(self):
return

 DataLoader(self.ds_val, batch_size=self.batch_size,

                          shuffle=

False

, num_workers=self.num_workers,

                          pin_memory=self.pin_memory)

deftest_dataloader(self):
return

 DataLoader(self.ds_test, batch_size=self.batch_size,

                          shuffle=

False

, num_workers=self.num_workers,

                          pin_memory=self.pin_memory)

defpredict_dataloader(self):
return

 DataLoader(self.ds_predict, batch_size=self.batch_size,

                          shuffle=

False

, num_workers=self.num_workers,

                          pin_memory=self.pin_memory)

    @staticmethod
defadd_dataset_args(parent_parser):

        parser = ArgumentParser(parents=[parent_parser], add_help=

False

)

        parser.add_argument(

'--batch_size'

, type=int, default=

32

)

        parser.add_argument(

'--num_workers'

, type=int, default=

8

)

        parser.add_argument(

'--pin_memory'

, type=bool, default=

True

)

return

 parser

#================================================================================
# 二,定義模型
#================================================================================

net = nn.Sequential(

    nn.Conv2d(in_channels=

1

,out_channels=

32

,kernel_size = 

3

),

    nn.MaxPool2d(kernel_size = 

2

,stride = 

2

),

    nn.Conv2d(in_channels=

32

,out_channels=

64

,kernel_size = 

5

),

    nn.MaxPool2d(kernel_size = 

2

,stride = 

2

),

    nn.Dropout2d(p = 

0.1

),

    nn.AdaptiveMaxPool2d((

1

,

1

)),

    nn.Flatten(),

    nn.Linear(

64

,

32

),

    nn.ReLU(),

    nn.Linear(

32

,

10

)

)

classModel(pl.LightningModule):

def__init__

(self,net,

                 learning_rate=

1e-3

,

                 use_CyclicLR = False,

                 epoch_size=

500

)

:

        super().__init__()

        self.save_hyperparameters() 

#自動建立self.hparams

        self.net = net

        self.train_acc = Accuracy()

        self.val_acc = Accuracy()

        self.test_acc = Accuracy() 

defforward(self,x):

        x = self.net(x)

return

 x

#定義loss
deftraining_step(self, batch, batch_idx):

        x, y = batch

        preds = self(x)

        loss = nn.CrossEntropyLoss()(preds,y)

return

 {

"loss"

:loss,

"preds"

:preds.detach(),

"y"

:y.detach()}

#定義各種metrics
deftraining_step_end(self,outputs):

        train_acc = self.train_acc(outputs[

'preds'

], outputs[

'y'

]).item()    

        self.log(

"train_acc"

,train_acc,prog_bar=

True

)

return

 {

"loss"

:outputs[

"loss"

].mean()}

#定義optimizer,以及可選的lr_scheduler
defconfigure_optimizers(self):

        optimizer = torch.optim.RMSprop(self.parameters(), lr=self.hparams.learning_rate)

ifnot

 self.hparams.use_CyclicLR:

return

 optimizer 
        max_lr = self.hparams.learning_rate

        base_lr = max_lr/

4.0

        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,

            base_lr=base_lr,max_lr=max_lr,

            step_size_up=

5

*self.hparams.epoch_size,cycle_momentum=

False

)

        self.print(

"set lr = "

+str(max_lr))

return

 ([optimizer],[scheduler])

defvalidation_step(self, batch, batch_idx):

        x, y = batch

        preds = self(x)

        loss = nn.CrossEntropyLoss()(preds,y)

return

 {

"loss"

:loss,

"preds"

:preds.detach(),

"y"

:y.detach()}

defvalidation_step_end(self,outputs):

        val_acc = self.val_acc(outputs[

'preds'

], outputs[

'y'

]).item()    

        self.log(

"val_loss"

,outputs[

"loss"

].mean(),on_epoch=

True

,on_step=

False

)

        self.log(

"val_acc"

,val_acc,prog_bar=

True

,on_epoch=

True

,on_step=

False

)

deftest_step(self, batch, batch_idx):

        x, y = batch

        preds = self(x)

        loss = nn.CrossEntropyLoss()(preds,y)

return

 {

"loss"

:loss,

"preds"

:preds.detach(),

"y"

:y.detach()}

deftest_step_end(self,outputs):

        test_acc = self.test_acc(outputs[

'preds'

], outputs[

'y'

]).item()    

        self.log(

"test_acc"

,test_acc,on_epoch=

True

,on_step=

False

)

        self.log(

"test_loss"

,outputs[

"loss"

].mean(),on_epoch=

True

,on_step=

False

)

    @staticmethod
defadd_model_args(parent_parser):

        parser = ArgumentParser(parents=[parent_parser], add_help=

False

)

        parser.add_argument(

'--learning_rate'

, type=float, default=

7e-3

)

        parser.add_argument(

'--use_CyclicLR'

, type=bool, default=

False

)

return

 parser

#================================================================================
# 三,訓練模型
#================================================================================

defmain(hparams):

    pl.seed_everything(

1234

)
    data_mnist = MNISTDataModule(batch_size=hparams.batch_size,

                                 num_workers=hparams.num_workers)

    data_mnist.setup()

    epoch_size = len(data_mnist.ds_train)//data_mnist.batch_size
    model = Model(net,learning_rate=hparams.learning_rate,

                  use_CyclicLR = hparams.use_CyclicLR,

                  epoch_size=epoch_size)
    ckpt_callback = pl.callbacks.ModelCheckpoint(

        monitor=

'val_acc'

,

        save_top_k=

3

,

        mode=

'max'

    )
    early_stopping = pl.callbacks.EarlyStopping(monitor = 

'val_acc'

,

                   patience=

16

,

                   mode = 

'max'

)

    callbacks = [ckpt_callback,early_stopping]

if

 hparams.use_swa:

        callbacks.append(pl.callbacks.StochasticWeightAveraging())
    trainer = pl.Trainer.from_argparse_args( 

        hparams,

        max_epochs=

1000

,

        callbacks = callbacks) 

    print(

"hparams.auto_lr_find="

,hparams.auto_lr_find)

if

 hparams.auto_lr_find:

#搜尋學習率範圍

        lr_finder = trainer.tuner.lr_find(model,

          datamodule = data_mnist,

          min_lr=

1e-08

,

          max_lr=

1

,

          num_training=

100

,

          mode=

'exponential'

,

          early_stop_threshold=

4.0

          )

        lr_finder.plot() 

        lr = lr_finder.suggestion()

        model.hparams.learning_rate = lr 

        print(

"suggest lr="

,lr)

del

 model 
        hparams.learning_rate = lr

        model = Model(net,learning_rate=hparams.learning_rate,

                  use_CyclicLR = hparams.use_CyclicLR,

                  epoch_size=epoch_size)

#等價於
#trainer.tune(model,data_mnist)

    trainer.fit(model,data_mnist)

    train_result = trainer.test(model,data_mnist.train_dataloader(),ckpt_path=

'best'

)

    val_result = trainer.test(model,data_mnist.val_dataloader(),ckpt_path=

'best'

)

    test_result = trainer.test(model,data_mnist.test_dataloader(),ckpt_path=

'best'

)
    print(

"train_result:\n"

)

    print(train_result)

    print(

"val_result:\n"

)

    print(val_result)

    print(

"test_result:\n"

)

    print(test_result)

if

 __name__ == 

"__main__"

:

    parser = ArgumentParser()

    parser.add_argument(

'--use_swa'

, default=

False

, type=bool)

    parser = MNISTDataModule.add_dataset_args(parser)

    parser = Model.add_model_args(parser)

    parser = pl.Trainer.add_argparse_args(parser)

    hparams = parser.parse_args()

    main(hparams)

1,SWA 隨機權重平均 (pl.callbacks.stochastic_weight_avg.StochasticWeightAveraging)

  • 平凡方式訓練:test_acc = 0.9581000208854675
  • SWA隨機權重:test_acc = 0.963100016117096
#平凡方式訓練

!python3 mnist_cnn.py --gpus=

2

 --strategy=

"ddp_find_unused_parameters_false"

------------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9581000208854675, 

'test_loss'

: 0.14859822392463684}

--------------------------------------------------------------------------------

#使用SWA隨機權重

!python3 mnist_cnn.py --gpus=

2

 --strategy=

"ddp_find_unused_parameters_false"

 --use_swa=

True

-----------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.963100016117096, 

'test_loss'

: 0.18146753311157227}

--------------------------------------------------------------------------------

2,CyclicLR學習率排程策略(torch.optim.lr_scheduler.CyclicLR)

  • 平凡方式訓練:test_acc = 0.9581000208854675
  • SWA隨機權重:test_acc = 0.963100016117096
  • SWA隨機權重 + CyClicLR學習率排程策略: test_acc = 0.9688000082969666

!python3 mnist_cnn.py --gpus=

2

 --strategy=

"ddp_find_unused_parameters_false"

 --use_swa=

True

 --use_CyclicLR=

True

------------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9688000082969666, 

'test_loss'

: 0.11470437049865723}

--------------------------------------------------------------------------------

3, 最優學習率搜尋(auto_lr_find=True)

  • 平凡方式訓練:test_acc = 0.9581000208854675
  • SWA隨機權重:test_acc = 0.963100016117096
  • SWA隨機權重 + CyClicLR學習率排程策略: test_acc = 0.9688000082969666
  • SWA隨機權重 + CyClicLR學習率排程策略 + 最優學習率搜尋:test_acc = 0.9693999886512756

!python3 mnist_cnn.py --gpus=

1

  --auto_lr_find=

True

 --use_swa=

True

 --use_CyclicLR=

True

---------------------------------------------------------------

DATALOADER:0 TEST RESULTS

{

'test_acc'

: 0.9693999886512756, 

'test_loss'

: 0.11024412512779236}

--------------------------------------------------------------------------------

Testing: 100%|███████████████████████████████| 313/313 [00:02<00:00, 137.85it/s]

以上。
萬水千山總是情,點個在看行不行?😋
技術交流群邀請函
△長按新增小助手
掃描二維碼新增小助手微信
請備註:姓名-學校/公司-研究方向
(如:小張-哈工大-對話系統)
即可申請加入自然語言處理/Pytorch等技術交流群

關於我們

MLNLP社群  機器學習演算法與自然語言處理 ) 是由國內外自然語言處理學者聯合構建的民間學術社群,目前已經發展為國內外知名自然語言處理社群,旗下包括  萬人頂會交流群、AI臻選匯、AI英才匯  以及  AI學術匯  等知名品牌,旨在促進機器學習,自然語言處理學術界、產業界和廣大愛好者之間的進步。
社群可以為相關從業者的深造、就業及研究等方面提供開放交流平臺。歡迎大家關注和加入我們。

相關文章