pytorch中保存和加载离线模型是常见的操作,相信很多读者和我一样平时多是直接调用方法,少有关心内部实现。这种习惯往往会使得对方法的理解不透彻进而印象不深,本文将深如的探讨下torch.save()
和torch.load()
方法,以便加深印象和理解。
几个核心知识点:
-
torch.save:基于pickle包将Module,tensor,dict等类型的对象序列化到本地文件中。值得注意的是Pytorch1.6之后,torch.save使用了压缩格式的存储方式。所以如果要使用老的格式保存需要显示的设置:
_use_new_zipfile_serialization=False
,而load方法不受影响。 -
torch.load:基于pickle包将保存在磁盘的持久化文件发序列化为内存中的对象
-
torch.nn.Module.load_state_dict:用于加载反序列化之后的
state_dict
对象。Pytorch模型(Module或Optim)对象中的参数(注意是能够学习到的参数)都采用的是Python中字典类型存储,而字典中的key为每一层的名字,value为具体的参数。比如如下代码所示:
1
2
3
4
5
6
7
8
9
10
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
几个Tips:
-
torch.save是一个非常好用的方法,在训练过程中记得多多的保存模型的当前状态。这个习惯会为调参带来极大的方便(血的教训)。
-
torch.save保存的文件名后缀通常是.pt或者.pth,使用这种约定俗成的习惯会方便看懂该文件的用途。
-
记住必须在torch.save之前调用model.eval()方法,否则保存的不仅仅是模型的参数,这将使得每次的推理结果可能会不同(此时模型参数不定)。
-
load_state_dict()接受的参数是字典类型,即model.state_dict(),不要直接输入模型参数的路径PATH.
-
由于state_dict是字典类型,如果在load_state_dict过程中出现key不匹配的问题,那么需要将strict参数设置为False,这样就不会加载不匹配部分的key对应的参数。
-
model.state_dict()在训练过程中会一直更新,所以在得到最好的结果的时候记得立即持久化,或者使用deepcopy(model.state_dict())拷贝。如果只是best_state_dict=model.state_dict(),那么best_state_dict将会在训练过程中改变。
torch.save和torch.load
如果不使用model.state_dict(),而直接保存整个模型可以直接调用如下方法:
1
torch.save(model, PATH)
重新加载则调用:
1
2
model = torch.load(PATH)
model.eval()
使用本组方法需要注意:
- torch.save不会保存Module类,而是只保存类的路径。在torch.load时根据保存的路径加载Module类,所以是非常不灵活的。
- 保存的文件后缀使用约定俗成的.pt或者.pth方法。
- 在torch.load之后需要执行model.eval()方法,以保证推理结果的一致性。
TorchScript
如果使用TorchScript的方式保存和使用模型,则使用如下方法:
1
2
3
# 持久化模型
model_scripted = torch.jit.script(model)
model_scripted.save('model_scripted.pt')
1
2
3
# 加载模型
model = torch.jit.load('model_scripted.pt')
model.eval()
使用本组方法需要注意:
- 保存的文件后缀使用约定俗成的.pt或者.pth方法。
- 在torch.jit.load之后需要执行model.eval()方法,以保证推理结果的一致性。
保存训练状态
如前述所述,torch.save保存的是python的dict类型,因而我们也可以利用这个特点来保存模型训练过程中的torch.Optim和loss的状态。如下所示:
1
2
3
4
5
6
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, PATH)
而加载则利用:
1
2
3
4
5
6
7
8
model = ModelClass(*args, **kwargs)
optimizer = OptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
使用本组方法需要注意:
- 使用.tar后缀名和单独的模型权重文件作为区分
- 在load_state_dict之后,记得调用model.eval()方法之后进行推理或调用model.train()方法之后进行训练
GPU or CPU
模型可以加载在内存中也可以加载在GPU中,这就需要将state_dict拷贝到和module相同的地方(内存orGPU)才能进行赋值。torch.load通过map_location参数可以指定具体的地方。
1
2
3
4
model.cuda()
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.cpu()
model.load_state_dict(torch.load(PATH, map_location="cpu"))
特别注意torch.nn.DataParallel
torch.nn.DataParallel方法使用wrapper的方式使得module能够并行,因而需要将module单独提取出来。
1
torch.save(model.module.state_dict(), PATH)
[1]https://pytorch.org/tutorials/beginner/saving_loading_models.html [2]https://pytorch.org/docs/stable/generated/torch.save.html#torch.save [3]https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict [4]https://docs.python.org/3/library/io.html