GFPGAN源码分析—第八篇
2021/12/27 1:07:34
本文主要是介绍GFPGAN源码分析—第八篇,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
2021SC@SDUSC
源码:
models\init.py
models\gfpgan_model.py
本篇主要分析init.py与models\gfpgan_model.py下的
class GFPGANModel(BaseModel) 类init(self, opt) 方法
目录
init.py
gfpgan_model.py
class GFPGANModel(BaseModel)
init(self, opt)
init_training_settings(self)
init.py
自动扫描和导入注册表的模型模块
#在models文件夹下扫描所有以 '_model.py' 结尾的文件 model_folder = osp.dirname(osp.abspath(__file__)) model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] # 导入所有模型模块 _model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]
那么实际上就是导入models文件夹下gfpgan_model.py文件,接下来我们来看一下
gfpgan_model.py
本文件中只包含GFPGANModel(BaseModel)一个类
创建了一个MODEL_REGISTRY对象,并在类定义的时候用装饰器装饰它,以装饰器的形式调用MODEL_REGISTRY类的register函数
@MODEL_REGISTRY.register() class GFPGANModel(BaseModel): """GFPGAN model for <Towards real-world blind faces restoratin with generative facial prior>"""
class GFPGANModel(BaseModel)
基于生成性人脸先验信息的真实盲脸修复 的 GFPGAN 模型
init(self, opt)
简单看一下代码
super(GFPGANModel, self).__init__(opt) self.idx = 0 # 网络定义 self.net_g = build_network(opt['network_g']) self.net_g = self.model_to_device(self.net_g) self.print_network(self.net_g) # 读取预训练的模型 load_path = self.opt['path'].get('pretrain_network_g', None) #如果路径不为空 if load_path is not None: param_key = self.opt['path'].get('param_key_g', 'params') self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) self.log_size = int(math.log(self.opt['network_g']['out_size'], 2)) if self.is_train: self.init_training_settings()
在读取预训练的模型时,实际上就是从train_gfpgan_v1.yml配置文件中读取到相应的参数的数值与路径。
init_training_settings(self)
初始化训练设置
1.读取opt['train']
train_opt = self.opt['train']
2.定义net_d
#构建网络 self.net_d = build_network(self.opt['network_d']) #将模型放到gpu(cuda)上 self.net_d = self.model_to_device(self.net_d) self.print_network(self.net_d) # 读取与训练好的模型 load_path = self.opt['path'].get('pretrain_network_d', None) if load_path is not None: self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
3.定义net_g
# net_g_ema 仅用于在一个GPU上测试并保存 # 不需要使用DistributedDataParallel进行包装 self.net_g_ema = build_network(self.opt['network_g']).to(self.device) # 读取预训练模型 load_path = self.opt['path'].get('pretrain_network_g', None) if load_path is not None: self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') else: self.model_ema(0) # copy net_g weight self.net_g.train() self.net_d.train() self.net_g_ema.eval()
根据配置文件:net_g读取预训练模型为arcface_resnet18.pth
4.面部组件网络
if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt): self.use_facial_disc = True else: self.use_facial_disc = False if self.use_facial_disc: # left eye self.net_d_left_eye = build_network(self.opt['network_d_left_eye']) self.net_d_left_eye = self.model_to_device(self.net_d_left_eye) self.print_network(self.net_d_left_eye) load_path = self.opt['path'].get('pretrain_network_d_left_eye') if load_path is not None: self.load_network(self.net_d_left_eye, load_path, True, 'params') # right eye self.net_d_right_eye = build_network(self.opt['network_d_right_eye']) self.net_d_right_eye = self.model_to_device(self.net_d_right_eye) self.print_network(self.net_d_right_eye) load_path = self.opt['path'].get('pretrain_network_d_right_eye') if load_path is not None: self.load_network(self.net_d_right_eye, load_path, True, 'params') # mouth self.net_d_mouth = build_network(self.opt['network_d_mouth']) self.net_d_mouth = self.model_to_device(self.net_d_mouth) self.print_network(self.net_d_mouth) load_path = self.opt['path'].get('pretrain_network_d_mouth') if load_path is not None: self.load_network(self.net_d_mouth, load_path, True, 'params') self.net_d_left_eye.train() self.net_d_right_eye.train() self.net_d_mouth.train() # ----------- 定义面部组件的 gan loss ----------- # self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
5.定义损失
if train_opt.get('pixel_opt'): self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) else: self.cri_pix = None if train_opt.get('perceptual_opt'): self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) else: self.cri_perceptual = None # pyramid loss, component style loss, identity loss 都使用L1损失 self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device) # gan loss (wgan) self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
6.identity loss的定义
if 'network_identity' in self.opt: self.use_identity = True else: self.use_identity = False if self.use_identity: # 定义 identity network self.network_identity = build_network(self.opt['network_identity']) self.network_identity = self.model_to_device(self.network_identity) self.print_network(self.network_identity) load_path = self.opt['path'].get('pretrain_network_identity') if load_path is not None: self.load_network(self.network_identity, load_path, True, None) self.network_identity.eval() for param in self.network_identity.parameters(): param.requires_grad = False # 正则化权重 self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator self.net_d_iters = train_opt.get('net_d_iters', 1) self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) self.net_d_reg_every = train_opt['net_d_reg_every'] # 设置优化器和调度程序 self.setup_optimizers() self.setup_schedulers()
这篇关于GFPGAN源码分析—第八篇的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-11-23Springboot应用的多环境打包入门
- 2024-11-23Springboot应用的生产发布入门教程
- 2024-11-23Python编程入门指南
- 2024-11-23Java创业入门:从零开始的编程之旅
- 2024-11-23Java创业入门:新手必读的Java编程与创业指南
- 2024-11-23Java对接阿里云智能语音服务入门详解
- 2024-11-23Java对接阿里云智能语音服务入门教程
- 2024-11-23JAVA对接阿里云智能语音服务入门教程
- 2024-11-23Java副业入门:初学者的简单教程
- 2024-11-23JAVA副业入门:初学者的实战指南