torch保存加载模型

2022/3/9 23:45:50

本文主要是介绍torch保存加载模型,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

目录
  • 三个核心函数
  • 状态字典定义
  • 只保存/加载模型参数(推荐做法)
  • 保存/加载整个模型
  • 断点训练checkpoint使用
  • 同一个文件中保存多个模型
  • 用一个模型的参数来初始化另一个不同模型
  • 不同设备保存/加载模型

三个核心函数

torch.save() 
torch.load()
torch.nn.Module.load_state_dict()

状态字典定义

状态字典本质上就是普通的python字典。

  • 对于具有可学习参数的网络层来说,状态字典的键就是网络层,值就是对应的参数张量。
    大概如下图所示,网络层的可学习参数包括权重和偏置等。

    当然batchnorm层也有需要保存的参数,比如running_mean。
  • 对于优化器对象也有自己的状态字典。其中包含了优化器状态信息和超参数。优化器的状态字典一般只有断点训练的时候才使用,毕竟推理也用不到优化器。

只保存/加载模型参数(推荐做法)

# 保存模型参数
torch.save(model.state_dict(), PATH)  
# 加载模型参数并用于推理
model = MyModel()
model.load_static_dict(torch.load(PATH))
model.eval()
  • torch.save()保存的文件后缀通常是 .pt 或 .pth
  • 保存模型参数的对象model和加载模型参数的对象model应该是同一个类的实例。
  • load_static_dict()方法的参数是一个字典,必须先用torch.load()把保存的参数转化成python字典。
  • 进行推理之前,必须先用model.eval()把dropout和BN层置为验证模式。

保存/加载整个模型

# 保存整个模型
torch.save(model, PATH)
# 加载整个模型
model = torch.load(PATH)
model.eval()

断点训练checkpoint使用

# 保存断点状态,保存的文件后缀一般是.tar。
torch.save({
  'epoch': epoch,
  'model_state_dict': model.state_dict(),
  'loss': loss,
  ...
}, PATH)

# 加载断点
model = MyModel()
optimizer = MyOptimizer()

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.train()
# model.eval() # 恢复断点之后直接推理也是可以的

同一个文件中保存多个模型

# 其实本质上跟checkpoint的使用是一样的
torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)
# 加载多个模型,本质上跟checkpoint也是一样的,保存文件后缀名也是.tar
modelA = MyModel()
modelB = MyModel()
optimizerA = MyOptimizer()
optimizerB = MyOptimizer()

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

用一个模型的参数来初始化另一个不同模型

# 保存模型参数
torch.save(modelA.state_dict(), PATH)
# 加载模型参数
modelB = MyModel()
modelB.load_state_dict(torch.load(PATH), strict=False)
  • load_state_dict()方法中strict=False表示忽略不匹配的网络层,毕竟两个网络不一样

不同设备保存/加载模型

  • 保存时候没区别,反正都是保存到磁盘上
    torch.save(model.state_dict(), PATH)
  • 加载模型到cpu上
    device = torch.device('cpu')
    model = MyModel()
    model.load_state_dict(torch.load(PATH, map_location=device))
    
  • 加载模型到GPU上
    # 有点奇怪,为啥不用map_location参数,而要先加载再转移到GPU上
    device = torch.device('cuda')
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH)
    model.to(device)
    


这篇关于torch保存加载模型的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程