Home pytorch中如何持久化保存和加载模型
Post
Cancel

pytorch中如何持久化保存和加载模型

pytorch中保存和加载离线模型是常见的操作,相信很多读者和我一样平时多是直接调用方法,少有关心内部实现。这种习惯往往会使得对方法的理解不透彻进而印象不深,本文将深如的探讨下torch.save()torch.load()方法,以便加深印象和理解。

几个核心知识点:

  • torch.save:基于pickle包将Module,tensor,dict等类型的对象序列化到本地文件中。值得注意的是Pytorch1.6之后,torch.save使用了压缩格式的存储方式。所以如果要使用老的格式保存需要显示的设置:_use_new_zipfile_serialization=False,而load方法不受影响。

  • torch.load:基于pickle包将保存在磁盘的持久化文件发序列化为内存中的对象

  • torch.nn.Module.load_state_dict:用于加载反序列化之后的state_dict对象。Pytorch模型(Module或Optim)对象中的参数(注意是能够学习到的参数)都采用的是Python中字典类型存储,而字典中的key为每一层的名字,value为具体的参数。比如如下代码所示:

1
2
3
4
5
6
7
8
9
10
    conv1.weight     torch.Size([6, 3, 5, 5])
    conv1.bias   torch.Size([6])
    conv2.weight     torch.Size([16, 6, 5, 5])
    conv2.bias   torch.Size([16])
    fc1.weight   torch.Size([120, 400])
    fc1.bias     torch.Size([120])
    fc2.weight   torch.Size([84, 120])
    fc2.bias     torch.Size([84])
    fc3.weight   torch.Size([10, 84])
    fc3.bias     torch.Size([10])

几个Tips:

  • torch.save是一个非常好用的方法,在训练过程中记得多多的保存模型的当前状态。这个习惯会为调参带来极大的方便(血的教训)。

  • torch.save保存的文件名后缀通常是.pt或者.pth,使用这种约定俗成的习惯会方便看懂该文件的用途。

  • 记住必须在torch.save之前调用model.eval()方法,否则保存的不仅仅是模型的参数,这将使得每次的推理结果可能会不同(此时模型参数不定)。

  • load_state_dict()接受的参数是字典类型,即model.state_dict(),不要直接输入模型参数的路径PATH.

  • 由于state_dict是字典类型,如果在load_state_dict过程中出现key不匹配的问题,那么需要将strict参数设置为False,这样就不会加载不匹配部分的key对应的参数。

  • model.state_dict()在训练过程中会一直更新,所以在得到最好的结果的时候记得立即持久化,或者使用deepcopy(model.state_dict())拷贝。如果只是best_state_dict=model.state_dict(),那么best_state_dict将会在训练过程中改变。

torch.save和torch.load

如果不使用model.state_dict(),而直接保存整个模型可以直接调用如下方法:

1
torch.save(model, PATH)

重新加载则调用:

1
2
model = torch.load(PATH)
model.eval()

使用本组方法需要注意:

  • torch.save不会保存Module类,而是只保存类的路径。在torch.load时根据保存的路径加载Module类,所以是非常不灵活的。
  • 保存的文件后缀使用约定俗成的.pt或者.pth方法。
  • 在torch.load之后需要执行model.eval()方法,以保证推理结果的一致性。

TorchScript

如果使用TorchScript的方式保存和使用模型,则使用如下方法:

1
2
3
# 持久化模型
model_scripted = torch.jit.script(model) 
model_scripted.save('model_scripted.pt') 
1
2
3
# 加载模型
model = torch.jit.load('model_scripted.pt')
model.eval()

使用本组方法需要注意:

  • 保存的文件后缀使用约定俗成的.pt或者.pth方法。
  • 在torch.jit.load之后需要执行model.eval()方法,以保证推理结果的一致性。

保存训练状态

如前述所述,torch.save保存的是python的dict类型,因而我们也可以利用这个特点来保存模型训练过程中的torch.Optim和loss的状态。如下所示:

1
2
3
4
5
6
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, PATH)

而加载则利用:

1
2
3
4
5
6
7
8
model = ModelClass(*args, **kwargs)
optimizer = OptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

使用本组方法需要注意:

  • 使用.tar后缀名和单独的模型权重文件作为区分
  • 在load_state_dict之后,记得调用model.eval()方法之后进行推理或调用model.train()方法之后进行训练

GPU or CPU

模型可以加载在内存中也可以加载在GPU中,这就需要将state_dict拷贝到和module相同的地方(内存orGPU)才能进行赋值。torch.load通过map_location参数可以指定具体的地方。

1
2
3
4
model.cuda()
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.cpu()
model.load_state_dict(torch.load(PATH, map_location="cpu"))

特别注意torch.nn.DataParallel

torch.nn.DataParallel方法使用wrapper的方式使得module能够并行,因而需要将module单独提取出来。

1
torch.save(model.module.state_dict(), PATH)

[1]https://pytorch.org/tutorials/beginner/saving_loading_models.html [2]https://pytorch.org/docs/stable/generated/torch.save.html#torch.save [3]https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict [4]https://docs.python.org/3/library/io.html

This post is licensed under CC BY 4.0 by the author.