关于跑图像去雾算法DCPDN的教程及Bug解决
2021/4/25 22:27:24
本文主要是介绍关于跑图像去雾算法DCPDN的教程及Bug解决,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
前情提要
最近刚刚开始图像去雾方面的研究,自然少不了阅读这一领域的经典文献和GitHub源码。DCPDN是其中比较很有价值的一篇,在阅读文献过程中,希望跑通它的代码,结合代码来帮助我理解这一算法的实现原理。但是我在配置环境跑程序的过程中出现了许多问题,花费了许多时间和精力解决了其中部分问题,因此想在此记录下来,同时也希望对同样遇到这些问题的你有所帮助,谢谢~
论文:Densely Connected Pyramid Dehazing Network
github源码:https://github.com/hezhangsprinter/DCPDN
参考的相关博客:1. 一步一步教你跑DCPDN深度学习去雾网络
2. DCPDN项目
代码运行环境
- Ubuntu 18.04.3
- python 3.6
- torch 0.3.1
- torchvision 0.2.1
直接把我配环境的命令行语句贴出来吧
- 新建conda环境
conda create -n xxx(你的环境名) python=3.6
- 进入新建的环境(激活环境)
conda activate xxx(你的环境名)
- 开始安装各种依赖的包
pip install https://download.pytorch.org/whl/cu90/torch-0.3.1-cp36-cp36m-linux_x86_64.whl pip install torchvision==0.2.1 pip install h5py pip install scipy
(没记错的话应该就是上面这些,比较重要的是torch版本和torchvision版本,因为作者的代码是基于早期的低版本,由于新版本做了一些改动,如果不按低版本来装的话,会遇到更多麻烦的问题,亲身经历。等自己研究透了,看看能否高版本的复现一下?哈哈哈)
原始代码可能使用的是Python2版本,因此某些语法与现在的Python3不兼容,(Python2使用<>
作为“不等于”的符号,而Python3使用!=
)需要修改一下,分别位于train.py
文件的第312行、329行、353行和366行。
问题1:Missing key(s) in state_dict: xxxxxxxxxxxxx; Unexpected key(s) in state_dict: xxxxxxxxxxxxx
报错原因: 预训练模权重的字典关键字与所创建的网络模型的字典关键字不匹配。(简单来说,我们的模型要使用现有的预训练权重来进行参数的初始化,这个过程需要两者各层级的网络名称相对应,否则就会出现上述错误。)
问题出在train.py文件的第124行:
if opt.netG != '': netG.load_state_dict(torch.load(opt.netG))
opt.netG是预训练权重文件netG_epoch_8.pth的所在路径,torch.load(opt.netG)
是加载这一权重文件。这个权重文件在以前训练的时候可能还是采用旧的字典关键字,如:‘norm.1’, ‘relu.1’, ‘conv.1’, ‘norm.2’, ‘relu.2’, ‘conv.2’,但是现在的网络模型在创建时已经不再允许使用“.”了,所以需要修改预训练权重的关键字,使其与我们的网络匹配。
通过正则修改,将上面的代码修改成以下内容:(参考torchvision.models.densenet中的做法)
if opt.netG != '': checkpoint = torch.load(opt.netG) pattern = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.' r'(?:weight|bias|running_mean|running_var))$') for key in list(checkpoint.keys()): res = pattern.match(key) if res: new_key = res.group(1) + res.group(2) checkpoint[new_key] = checkpoint[key] del checkpoint[key] netG.load_state_dict(checkpoint)
经过修改后,这一段代码就能顺利执行了。
问题2:"python3.6/site-packages/torch/utils/data/dataloader.py", line 271, in __next__ raise StopIteration
关于这个问题,上面参考的博客2中作出了解答。因为博客1的作者训练网络时使用以下的命令:
python train.py --dataroot ./facades/train512 --valDataroot ./facades/test512 --exp ./checkpoints_new --netG ./demo_model/netG_epoch_8.pth
其中,--valDataroot
传入的是./facades/test512
这个路径,但是源代码的作者并没有提供这一文件,只有./facades/val512
,因此把命令改成:
python train.py --dataroot ./facades/train512 --valDataroot ./facades/val512 --exp ./checkpoints_new --netG ./demo_model/netG_epoch_8.pth
就能成功解决这个问题。
最后,再次感谢上面两位博主的博客~
这篇关于关于跑图像去雾算法DCPDN的教程及Bug解决的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-11-15Typescript 类型教程:轻松入门与实践指南
- 2024-11-15AntDesign-icons项目实战:新手入门教程
- 2024-11-14用Scratch编写语言模型:爪爪(Clawed)式简易教程
- 2024-11-14用大型语言模型在Amazon Bedrock上分类Jira工单
- 2024-11-14从数据到行动:亚马逊Bedrock代理如何自动化复杂工作流
- 2024-11-14Databricks与优化后的Snowflake性能大比拼
- 2024-11-14亚马逊 Inspector 解析:提升您的 AWS 负载安全的利器
- 2024-11-14揭秘VS Code for Web - Azure:轻松开发云端应用的新利器
- 2024-11-14揭秘指南:如何让Databricks中的数据为最终用户所用
- 2024-11-14OpenTelemetry扩展进入CI/CD可观测性领域