GFPGAN源码分析—第五篇

2021/12/7 1:18:22

本文主要是介绍GFPGAN源码分析—第五篇,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

2021SC@SDUSC

源码:archs\gfpganv1_clean_arch.py

本篇主要分析gfpganv1_clean_arch.py下的以下两个类

class StyleGAN2GeneratorCSFT (StyleGAN2GeneratorClean):StyleGan

class ResBlock(nn.Module):残差网络

目录

class StyleGAN2GeneratorCSFT (StyleGAN2GeneratorClean):

_init_( )

forward( )

latents with Style MLP layer">(1) style codes -> latents with Style MLP layer

(2)noise

(3) style truncation

(4)get style latent with injection

class ResBlock(nn.Module):

_init_( )

forward( )


class StyleGAN2GeneratorCSFT (StyleGAN2GeneratorClean):

继承了StyleGAN2GeneratorClean类

_init_( )

def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
    super(StyleGAN2GeneratorCSFT, self).__init__(
        out_size,
        num_style_feat=num_style_feat,
        num_mlp=num_mlp,
        channel_multiplier=channel_multiplier,
        narrow=narrow)

    self.sft_half = sft_half

forward( )

参数:

(self,
 styles,#(list[Tensor]): Sample codes of styles.
 conditions,
 input_is_latent=False,#(bool): Whether input is latent style.
 noise=None,#(Tensor | None): Input noise or None. 
 randomize_noise=True,#(bool): Randomize noise, used when 'noise' is False.
 truncation=1,
 truncation_latent=None,
 inject_index=None,#The injection index for mixing noise.
 return_latents=False)# Whether to return style latents.

(1) style codes -> latents with Style MLP layer

if not input_is_latent:
    styles = [self.style_mlp(s) for s in styles]

(2)noise

if noise is None:
    if randomize_noise:
        noise = [None] * self.num_layers  # for each style conv layer
    else:  # use the stored noise
        noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]

(3) style truncation

 if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - 
                truncation_latent))
            styles = style_truncation

(4)get style latent with injection

if len(styles) == 1:
    inject_index = self.num_latent

    if styles[0].ndim < 3:
        # repeat latent code for all the layers
        latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
    else:  # used for encoder with different latent code for each layer
        latent = styles[0]
elif len(styles) == 2:  # mixing noises
    if inject_index is None:
        inject_index = random.randint(1, self.num_latent - 1)
    latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
    latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
    latent = torch.cat([latent1, latent2], 1)

class ResBlock(nn.Module):

带有上/下采样的残差网络

Residual block with upsampling/downsampling

实际的一个单元(unit)即:con-relu-padding-con-relu

resNet本质上是为网络加了一个shortcut,相当于部分层数变成了一个直连接,从而防止出现升高神经网络层数反而效果变差的情况。所以需要让输入与输出的shape,channels都要保持一致。至于是否要退化部分网络,是由网络根据训练效果自身去选择的。

_init_( )

def __init__(self, in_channels, out_channels, mode='down'):
    super(ResBlock, self).__init__()
    
	#输入的通道数:in_channels
    #输出的通道数:out_channels
    #搭建卷积神经网络:卷积核[in_channels,3,3];步长为1;使用padding=1,边界增加一圈
    #padding=1保持输出与输入大小保持一致
    self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
    self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
    
    #卷积核[in_channels,1,1],bias=False不使用偏置(默认为True)
    self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
    
    #为上采样、下采样设置不同的scale_factor
    if mode == 'down':
        self.scale_factor = 0.5
    elif mode == 'up':
        self.scale_factor = 2

forward( )

前向传播函数

def forward(self, x):
    #使用relu函数做非线性变换
    out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
    # upsample/downsample:做上/下采样
    out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
    #再次使用relu函数对采样后的输出做非线性变换
    out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
    # skip,对传入的x进行处理
    x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
    skip = self.skip(x)
    out = out + skip
    return out



这篇关于GFPGAN源码分析—第五篇的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程