0705-深度网络模型持久化
2021/5/1 10:28:44
本文主要是介绍0705-深度网络模型持久化,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
0705-深度网络模型持久化
目录- 一、持久化概述
- 二、tensor 对象的保存和加载
- 三、Module 对象的保存和加载
- 四、Optimizer 对象的保存和加载
- 五、所有对象集合的保存和加载
- 六、第七章总结
pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html
一、持久化概述
在 torch 中,以下对象可以持久化到硬盘,并可以通过相应的方法把这些对象持久化到内存中:
- Tensor
- Variable
- nn.Module
- Optimizer
上述对象本质上最后都是保存为 Tensor。并且 Tensor 的保存和加载非常简单,使用 t.save
和 t.load
即可。
在 save/load 时可指定使用的 pickle 模块,在 load 时还可以把 GPU tensor 映射到 CPU 或者其他 GPU 上。
我们可以通过 t.save(obj, file_name)
保存任意可序列化的对象,然后通过 obj=t.load(file_name)
方法加载保存的数据。
对于 Module 和 Optimizer 对象,建议保存为对应的 state_dict,而不是直接保存整个 Module/Optimizer 对象。Optimizer 对象保存的是参数和动量信息,通过加载之前的动量信息,能够很有效地减少模型震荡。
二、tensor 对象的保存和加载
import torch as t a = t.Tensor(3, 4) if t.cuda.is_available(): a = a.cuda(1) # 把 a 转为 GPU1 上的 tensor t.save(a, 'a.pth') # 加载为 b,存储于 GPU1 上(因为保存时 tensor 就在 GPU1 上) b = t.load('a.pth') # 加载为 c,存储于 CPU c = t.load('a.pth', map_location=lambda storage, loc: storage) # 加载为 d,存储于 GPU0 上 d = t.load('a.pth', map_location={'cuda:1': 'cuda:0'})
三、Module 对象的保存和加载
t.set_default_tensor_type('torch.FloatTensor') from torchvision.models import AlexNet model = AlexNet() # module 的 state_dict 是一个字典 model.state_dict().keys() t.save(model.state_dict(), 'alexnet.pth') model.load_state_dict(t.load('alexnet.pth'))
<All keys matched successfully>
四、Optimizer 对象的保存和加载
optimizer = t.optim.Adam(model.parameters(), lr=0.1) t.save(optimizer.state_dict(), 'optimizer.pth') optimizer.load_state_dict(t.load('optimizer.pth'))
五、所有对象集合的保存和加载
all_data = dict(optimizer=optimizer.state_dict(), model=model.state_dict(), info=u'模型和优化器的所有参数') t.save(all_data, 'all.pth') all_data = t.load('all.pth') all_data.keys()
dict_keys(['optimizer', 'model', 'info'])
六、第七章总结
本章介绍了 torch 的很多工具模块,主要涉及数据加载、可视化和 GPU 加速相关的内容,合理地使用这些模块可以极大地提升我们的编码效率。
这篇关于0705-深度网络模型持久化的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-07-02springboot项目无法注册到nacos-icode9专业技术文章分享
- 2024-06-26结对编程到底难不难?答案在这里
- 2024-06-19《2023版Java工程师》课程升级公告
- 2024-06-15matplotlib作图不显示3D图,怎么办?
- 2024-06-1503-Loki 日志监控
- 2024-06-1504-让LLM理解知识 -Prompt
- 2024-06-05做软件测试需要懂代码吗?
- 2024-06-0514-ShardingSphere的分布式主键实现
- 2024-06-03为什么以及如何要进行架构设计权衡?
- 2024-05-31全网首发第二弹!软考2024年5月《软件设计师》真题+解析+答案!(11-20题)