GFPGAN源码分析—第六篇
2021/12/7 1:16:59
本文主要是介绍GFPGAN源码分析—第六篇,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
2021SC@SDUSC
源码:archs\gfpganv1_clean_arch.py
本篇主要分析gfpganv1_clean_arch.py下的
class GFPGANv1Clean(nn.Module)类_init_()方法
目录
class GFPGANv1Clean(nn.Module)
init()
(1)channels的设置
(2)调用torch.nn.Conv2d()创建了一层卷积神经网络
(3)下采样(downsample)
(4)上采样(upsample)
(5)全连接层
(6)创建self.stylegan_decoder
(7)如果decoder_load_path不为空则读取
(8)for SFT(SFT layer)
class GFPGANv1Clean(nn.Module)
继承自nn.Module类,使得我们可以使用很多现成的类,比如本类中使用的Conv2d以及RelU激活函数等等。
init()
参数:
self, out_size, num_style_feat=512, channel_multiplier=1, decoder_load_path=None, fix_decoder=True, # for stylegan decoder num_mlp=8, input_is_latent=False, different_w=False, narrow=1, sft_half=False
在class GFPGANer()-init()中被调用时:
self.gfpgan = GFPGANv1Clean( out_size=512, num_style_feat=512, channel_multiplier=channel_multiplier, decoder_load_path=None, fix_decoder=False, num_mlp=8, input_is_latent=True, different_w=True, narrow=1, sft_half=True)
(1)channels的设置
实际调用的时候narrow=1,
channels保存了经过convolution层后的输出的通道数
unet_narrow = narrow * 0.5 channels = { '4': int(512 * unet_narrow), '8': int(512 * unet_narrow), '16': int(512 * unet_narrow), '32': int(512 * unet_narrow), '64': int(256 * channel_multiplier * unet_narrow), '128': int(128 * channel_multiplier * unet_narrow), '256': int(64 * channel_multiplier * unet_narrow), '512': int(32 * channel_multiplier * unet_narrow), '1024': int(16 * channel_multiplier * unet_narrow) }
(2)调用torch.nn.Conv2d()搭建卷积神经网络
#out_size=512,so log_size=9 self.log_size = int(math.log(out_size, 2)) #first_out_size = 512 first_out_size = 2 ** (int(math.log(out_size, 2))) #channels['512']=32*2*0.5=32 self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
在这里介绍一下nn.Conv2d()的几个参数
in_channels: int,#输入的通道数目【必选】 out_channels: int,# 输出的通道数目【必选】 kernel_size: _size_2_t,#卷积核的大小,类型为int(方形边长) 或者元组(长和宽)【必选】 stride: _size_2_t = 1,#步长 padding: Union[str, _size_2_t] = 0,#边界增益,可以控制输出结果的尺寸 dilation: _size_2_t = 1,#控制卷积核之间的间距 groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', # TODO: refine this type device=None, dtype=None
那么可以得知
self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
#实际上是传入通道为3(RGB)的输入,使用边长为1的卷积核,最后获得通道为32的输出
#由于卷积核边长为1,我们输入与输入的图片大小仍然保持一致,但增加了通道数
(3)下采样(downsample)
可以看到实际上是调用ResBlock做了下采样
# 输入图片的通道数(实际为32) in_channels = channels[f'{first_out_size}'] #创建ModuleList容器 self.conv_body_down = nn.ModuleList() # i从self.log_size(9)->3 :7次循环 for i in range(self.log_size, 2, -1): out_channels = channels[f'{2 ** (i - 1)}'] #调用ResBlock残差网络做下采样,并将该module添加到设置的ModuleList self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down')) #这一层的输出管道数作为下一层输入的管道数 in_channels = out_channels
介绍一下nn.ModuleList()
nn.ModuleList,它是一个储存不同module,并自动将每个 module 的 parameters 添加到网络之中的容器。你可以把任意 nn.Module 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,无非是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是会自动注册到整个网络上的,同时 module 的 parameters 也会自动添加到整个网络中。
#注意nn.ModuleList则没有实现内部forward函数,所以需要手动实现
最后一层卷积层的搭建:
#最终输出通道数为channels['4']=256,使用边长为3的卷积核,步长为1,padding为1,保证维度不变 self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
(4)上采样(upsample)
#输入通道数为channels['4']=256,即下采样的输出的通道数 in_channels = channels['4'] #创建ModuleList容器 self.conv_body_up = nn.ModuleList() # i从3->self.log_size(9) :7次循环 for i in range(3, self.log_size + 1): # 定义输出的通道数 out_channels = channels[f'{2 ** i}'] # 调用带有上采样ResBlock残差网络,并将该module添加到设置的ModuleList self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up')) #这一层的输出管道数作为下一层输入的管道数 in_channels = out_channels
(5)全连接层
根据传入的参数different_w,选择每个输出样本的大小,并搭建相应的全连接层。
if different_w: #16*512=8192 linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat print(linear_out_channel) else: #512 linear_out_channel = num_style_feat #全连接层size of each input sample:4096,size of each output sample:8192 self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
(6)创建self.stylegan_decoder
self.stylegan_decoder = StyleGAN2GeneratorCSFT( out_size=out_size, num_style_feat=num_style_feat, num_mlp=num_mlp, channel_multiplier=channel_multiplier, narrow=narrow, sft_half=sft_half)
(7)如果decoder_load_path不为空则读取
if decoder_load_path: self.stylegan_decoder.load_state_dict( torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) if fix_decoder: for name, param in self.stylegan_decoder.named_parameters(): param.requires_grad = False
(8)for SFT(SFT layer)
#ModuleList self.condition_scale = nn.ModuleList() self.condition_shift = nn.ModuleList() # i从3->self.log_size(9) :7次循环 for i in range(3, self.log_size + 1): # 定义输出的通道数 out_channels = channels[f'{2 ** i}'] #输出通道数是否减半 if sft_half: sft_out_channels = out_channels else: sft_out_channels = out_channels * 2 #使用nn.Sequential搭建网络,并添加到ModuleList self.condition_scale.append( nn.Sequential( #卷积核边长为3,步长为1,输出与输出保持相同维度 nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True), nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1))) self.condition_shift.append( nn.Sequential( nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True), nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
nn.Sequential是一个有序的容器,其中传入的是构造器类(各种用来处理input的类),最终input会被Sequential中的构造器依次执行。
这篇关于GFPGAN源码分析—第六篇的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2025-01-11有哪些好用的家政团队管理工具?
- 2025-01-11营销人必看的GTM五个指标
- 2025-01-11办公软件在直播电商前期筹划中的应用与推荐
- 2025-01-11提升组织效率:上级管理者如何优化跨部门任务分配
- 2025-01-11酒店精细化运营背后的协同工具支持
- 2025-01-11跨境电商选品全攻略:工具使用、市场数据与选品策略
- 2025-01-11数据驱动酒店管理:在线工具的核心价值解析
- 2025-01-11cursor试用出现:Too many free trial accounts used on this machine 的解决方法
- 2025-01-11百万架构师第十四课:源码分析:Spring 源码分析:深入分析IOC那些鲜为人知的细节|JavaGuide
- 2025-01-11不得不了解的高效AI办公工具API