GFPGAN源码分析—第五篇
2021/12/7 1:18:22
本文主要是介绍GFPGAN源码分析—第五篇,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
2021SC@SDUSC
源码:archs\gfpganv1_clean_arch.py
本篇主要分析gfpganv1_clean_arch.py下的以下两个类
class StyleGAN2GeneratorCSFT (StyleGAN2GeneratorClean):StyleGan
class ResBlock(nn.Module):残差网络
目录
class StyleGAN2GeneratorCSFT (StyleGAN2GeneratorClean):
_init_( )
forward( )
latents with Style MLP layer">(1) style codes -> latents with Style MLP layer
(2)noise
(3) style truncation
(4)get style latent with injection
class ResBlock(nn.Module):
_init_( )
forward( )
class StyleGAN2GeneratorCSFT (StyleGAN2GeneratorClean):
继承了StyleGAN2GeneratorClean类
_init_( )
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False): super(StyleGAN2GeneratorCSFT, self).__init__( out_size, num_style_feat=num_style_feat, num_mlp=num_mlp, channel_multiplier=channel_multiplier, narrow=narrow) self.sft_half = sft_half
forward( )
参数:
(self, styles,#(list[Tensor]): Sample codes of styles. conditions, input_is_latent=False,#(bool): Whether input is latent style. noise=None,#(Tensor | None): Input noise or None. randomize_noise=True,#(bool): Randomize noise, used when 'noise' is False. truncation=1, truncation_latent=None, inject_index=None,#The injection index for mixing noise. return_latents=False)# Whether to return style latents.
(1) style codes -> latents with Style MLP layer
if not input_is_latent: styles = [self.style_mlp(s) for s in styles]
(2)noise
if noise is None: if randomize_noise: noise = [None] * self.num_layers # for each style conv layer else: # use the stored noise noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
(3) style truncation
if truncation < 1: style_truncation = [] for style in styles: style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) styles = style_truncation
(4)get style latent with injection
if len(styles) == 1: inject_index = self.num_latent if styles[0].ndim < 3: # repeat latent code for all the layers latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) else: # used for encoder with different latent code for each layer latent = styles[0] elif len(styles) == 2: # mixing noises if inject_index is None: inject_index = random.randint(1, self.num_latent - 1) latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) latent = torch.cat([latent1, latent2], 1)
class ResBlock(nn.Module):
带有上/下采样的残差网络
Residual block with upsampling/downsampling
实际的一个单元(unit)即:con-relu-padding-con-relu
resNet本质上是为网络加了一个shortcut,相当于部分层数变成了一个直连接,从而防止出现升高神经网络层数反而效果变差的情况。所以需要让输入与输出的shape,channels都要保持一致。至于是否要退化部分网络,是由网络根据训练效果自身去选择的。
_init_( )
def __init__(self, in_channels, out_channels, mode='down'): super(ResBlock, self).__init__() #输入的通道数:in_channels #输出的通道数:out_channels #搭建卷积神经网络:卷积核[in_channels,3,3];步长为1;使用padding=1,边界增加一圈 #padding=1保持输出与输入大小保持一致 self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) #卷积核[in_channels,1,1],bias=False不使用偏置(默认为True) self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) #为上采样、下采样设置不同的scale_factor if mode == 'down': self.scale_factor = 0.5 elif mode == 'up': self.scale_factor = 2
forward( )
前向传播函数
def forward(self, x): #使用relu函数做非线性变换 out = F.leaky_relu_(self.conv1(x), negative_slope=0.2) # upsample/downsample:做上/下采样 out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) #再次使用relu函数对采样后的输出做非线性变换 out = F.leaky_relu_(self.conv2(out), negative_slope=0.2) # skip,对传入的x进行处理 x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) skip = self.skip(x) out = out + skip return out
这篇关于GFPGAN源码分析—第五篇的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2025-01-11有哪些好用的家政团队管理工具?
- 2025-01-11营销人必看的GTM五个指标
- 2025-01-11办公软件在直播电商前期筹划中的应用与推荐
- 2025-01-11提升组织效率:上级管理者如何优化跨部门任务分配
- 2025-01-11酒店精细化运营背后的协同工具支持
- 2025-01-11跨境电商选品全攻略:工具使用、市场数据与选品策略
- 2025-01-11数据驱动酒店管理:在线工具的核心价值解析
- 2025-01-11cursor试用出现:Too many free trial accounts used on this machine 的解决方法
- 2025-01-11百万架构师第十四课:源码分析:Spring 源码分析:深入分析IOC那些鲜为人知的细节|JavaGuide
- 2025-01-11不得不了解的高效AI办公工具API