torch保存加载模型
2022/3/9 23:45:50
本文主要是介绍torch保存加载模型,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
目录- 三个核心函数
- 状态字典定义
- 只保存/加载模型参数(推荐做法)
- 保存/加载整个模型
- 断点训练checkpoint使用
- 同一个文件中保存多个模型
- 用一个模型的参数来初始化另一个不同模型
- 不同设备保存/加载模型
三个核心函数
torch.save() torch.load() torch.nn.Module.load_state_dict()
状态字典定义
状态字典本质上就是普通的python字典。
- 对于具有可学习参数的网络层来说,状态字典的键就是网络层,值就是对应的参数张量。
大概如下图所示,网络层的可学习参数包括权重和偏置等。
当然batchnorm层也有需要保存的参数,比如running_mean。 - 对于优化器对象也有自己的状态字典。其中包含了优化器状态信息和超参数。优化器的状态字典一般只有断点训练的时候才使用,毕竟推理也用不到优化器。
只保存/加载模型参数(推荐做法)
# 保存模型参数 torch.save(model.state_dict(), PATH)
# 加载模型参数并用于推理 model = MyModel() model.load_static_dict(torch.load(PATH)) model.eval()
- torch.save()保存的文件后缀通常是 .pt 或 .pth
- 保存模型参数的对象model和加载模型参数的对象model应该是同一个类的实例。
- load_static_dict()方法的参数是一个字典,必须先用torch.load()把保存的参数转化成python字典。
- 进行推理之前,必须先用model.eval()把dropout和BN层置为验证模式。
保存/加载整个模型
# 保存整个模型 torch.save(model, PATH)
# 加载整个模型 model = torch.load(PATH) model.eval()
断点训练checkpoint使用
# 保存断点状态,保存的文件后缀一般是.tar。 torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'loss': loss, ... }, PATH)
# 加载断点 model = MyModel() optimizer = MyOptimizer() checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] model.train() # model.eval() # 恢复断点之后直接推理也是可以的
同一个文件中保存多个模型
# 其实本质上跟checkpoint的使用是一样的 torch.save({ 'modelA_state_dict': modelA.state_dict(), 'modelB_state_dict': modelB.state_dict(), 'optimizerA_state_dict': optimizerA.state_dict(), 'optimizerB_state_dict': optimizerB.state_dict(), ... }, PATH)
# 加载多个模型,本质上跟checkpoint也是一样的,保存文件后缀名也是.tar modelA = MyModel() modelB = MyModel() optimizerA = MyOptimizer() optimizerB = MyOptimizer() checkpoint = torch.load(PATH) modelA.load_state_dict(checkpoint['modelA_state_dict']) modelB.load_state_dict(checkpoint['modelB_state_dict']) optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']) optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']) modelA.eval() modelB.eval() # - or - modelA.train() modelB.train()
用一个模型的参数来初始化另一个不同模型
# 保存模型参数 torch.save(modelA.state_dict(), PATH)
# 加载模型参数 modelB = MyModel() modelB.load_state_dict(torch.load(PATH), strict=False)
- load_state_dict()方法中strict=False表示忽略不匹配的网络层,毕竟两个网络不一样
不同设备保存/加载模型
- 保存时候没区别,反正都是保存到磁盘上
torch.save(model.state_dict(), PATH)
- 加载模型到cpu上
device = torch.device('cpu') model = MyModel() model.load_state_dict(torch.load(PATH, map_location=device))
- 加载模型到GPU上
# 有点奇怪,为啥不用map_location参数,而要先加载再转移到GPU上 device = torch.device('cuda') model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH) model.to(device)
这篇关于torch保存加载模型的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-11-23增量更新怎么做?-icode9专业技术文章分享
- 2024-11-23压缩包加密方案有哪些?-icode9专业技术文章分享
- 2024-11-23用shell怎么写一个开机时自动同步远程仓库的代码?-icode9专业技术文章分享
- 2024-11-23webman可以同步自己的仓库吗?-icode9专业技术文章分享
- 2024-11-23在 Webman 中怎么判断是否有某命令进程正在运行?-icode9专业技术文章分享
- 2024-11-23如何重置new Swiper?-icode9专业技术文章分享
- 2024-11-23oss直传有什么好处?-icode9专业技术文章分享
- 2024-11-23如何将oss直传封装成一个组件在其他页面调用时都可以使用?-icode9专业技术文章分享
- 2024-11-23怎么使用laravel 11在代码里获取路由列表?-icode9专业技术文章分享
- 2024-11-22怎么实现ansible playbook 备份代码中命名包含时间戳功能?-icode9专业技术文章分享