YOLOX改进之添加ASFF

2022/1/6 23:04:40

本文主要是介绍YOLOX改进之添加ASFF,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

文章内容:如何在YOLOX官网代码中添加ASFF模块

环境:pytorch1.8

修改内容

(1)在PAFPN尾部添加ASFF模块(YOLOX-s等版本)

(2)在FPN尾部添加ASFF模块(YOLOX-Darknet53版本)

参考链接

论文链接:https://arxiv.org/pdf/1911.09516v2.pdf

ASFF原理及代码参考:https://blog.csdn.net/weixin_44119362/article/details/114289607

示意图如下
(原文此处为ASFF结构示意图,图片未能显示)

使用方法:直接在PAFPN或FPN尾部添加即可(可自动进行维度匹配,不需要修改)

代码修改过程

1、在YOLOXS版本的PAFPN后添加ASFF模块

(注意:该版本的ASFF用于YOLOv5风格的PAFPN,不能用于YOLOv3的FPN)

步骤一:在YOLOX-main/yolox/models文件夹下创建ASFF.py文件,内容如下:

import torch
import torch.nn as nn
import torch.nn.functional as F


def autopad(k, p=None):  # kernel, padding
    """Return the padding to use for kernel size ``k``.

    An explicit ``p`` is returned unchanged. Otherwise the padding is half
    the kernel size (floor division): a single value for an int kernel, or
    a list with one value per entry for an iterable kernel.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [size // 2 for size in k]


class Conv(nn.Module):
    """Convolution followed by batch normalisation and an activation."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        """Create the conv / batch-norm / activation layers.

        Args:
            c1: Input channels.
            c2: Output channels.
            k: Kernel size.
            s: Stride.
            p: Padding; ``None`` lets ``autopad`` derive it from ``k``.
            g: Number of convolution groups.
            act: ``True`` selects ``nn.SiLU``; an ``nn.Module`` instance is
                used as given; anything else selects ``nn.Identity``.
        """
        super().__init__()
        padding = autopad(k, p)
        # The convolution has no bias: it is always followed by batch norm.
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            self.act = nn.SiLU()
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch norm and activation, in that order."""
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)

    def forward_fuse(self, x):
        """Apply convolution and activation only, skipping batch norm."""
        y = self.conv(x)
        return self.act(y)


class ASFF(nn.Module):
    """Adaptively Spatial Feature Fusion for one level of a 3-level pyramid.

    An instance resamples the three input feature maps to the resolution and
    channel count of its target level, predicts per-pixel softmax weights
    over the three maps, and returns their weighted sum passed through a
    3x3 ``Conv``. One instance produces exactly one output level.
    """

    def __init__(self, level, multiplier=1, rfb=False, vis=False, act_cfg=True):
        """Build the resampling, weighting and output layers.

        Args:
            level: Target level, one of 0, 1 or 2. The levels have
                ``int(1024 * multiplier)``, ``int(512 * multiplier)`` and
                ``int(256 * multiplier)`` channels respectively.
            multiplier: Channel width multiplier applied to the base
                channel counts 1024 / 512 / 256 (e.g. 1 -> 1024, 512, 256;
                0.5 -> 512, 256, 128).
            rfb: If true, the per-level weight branches use 8 channels
                instead of 16 to save memory.
            vis: If true, ``forward`` also returns the fusion weights and
                the channel-summed fused map.
            act_cfg: Not used by this implementation; kept so existing
                callers that pass it keep working.

        Raises:
            ValueError: If ``level`` is not 0, 1 or 2.
        """
        super(ASFF, self).__init__()
        # Reject bad levels up front: a negative level would otherwise index
        # ``self.dim`` from the end and silently build a module without its
        # resampling / ``expand`` layers, failing only inside ``forward``.
        if level not in (0, 1, 2):
            raise ValueError(
                "level must be 0, 1 or 2, got {!r}".format(level))

        self.level = level
        self.dim = [int(1024 * multiplier), int(512 * multiplier),
                    int(256 * multiplier)]
        ch_level_0, ch_level_1, ch_level_2 = self.dim

        # Channel count every input is brought to before fusion.
        self.inter_dim = self.dim[self.level]
        if level == 0:
            # Stride-2 convs shrink the higher-resolution levels.
            self.stride_level_1 = Conv(ch_level_1, self.inter_dim, 3, 2)
            self.stride_level_2 = Conv(ch_level_2, self.inter_dim, 3, 2)
            self.expand = Conv(self.inter_dim, ch_level_0, 3, 1)
        elif level == 1:
            self.compress_level_0 = Conv(ch_level_0, self.inter_dim, 1, 1)
            self.stride_level_2 = Conv(ch_level_2, self.inter_dim, 3, 2)
            self.expand = Conv(self.inter_dim, ch_level_1, 3, 1)
        elif level == 2:
            # 1x1 convs reduce channels; upsampling happens in forward().
            self.compress_level_0 = Conv(ch_level_0, self.inter_dim, 1, 1)
            self.compress_level_1 = Conv(ch_level_1, self.inter_dim, 1, 1)
            self.expand = Conv(self.inter_dim, ch_level_2, 3, 1)

        # when adding rfb, we use half number of channels to save memory
        compress_c = 8 if rfb else 16
        self.weight_level_0 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, compress_c, 1, 1)

        # Maps the concatenated weight features to one logit per level.
        self.weight_levels = Conv(compress_c * 3, 3, 1, 1)
        self.vis = vis

    def forward(self, x):  # l, m, s
        """Fuse three pyramid features into the target level.

        Args:
            x: Indexable of three feature maps. ``x[0]`` is treated as
                level 2, ``x[1]`` as level 1 and ``x[2]`` as level 0, i.e.
                ordered from fewest channels (256 * multiplier) to most
                (1024 * multiplier). Resampling uses fixed factors of 2 and
                4, so each level is assumed to have twice the spatial size
                of the next-lower-numbered level.

        Returns:
            The fused feature map, or when ``self.vis`` is true the tuple
            ``(out, levels_weight, fused_out_reduced.sum(dim=1))``.
        """
        x_level_0 = x[2]  # level with the most channels
        x_level_1 = x[1]  # middle level
        x_level_2 = x[0]  # level with the fewest channels

        if self.level == 0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
            # Level 2 needs a 4x reduction: max-pool (2x) then stride-2 conv.
            level_2_downsampled_inter = F.max_pool2d(
                x_level_2, 3, stride=2, padding=1)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
        elif self.level == 1:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=2, mode='nearest')
            level_1_resized = x_level_1
            level_2_resized = self.stride_level_2(x_level_2)
        elif self.level == 2:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=4, mode='nearest')
            x_level_1_compressed = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(
                x_level_1_compressed, scale_factor=2, mode='nearest')
            level_2_resized = x_level_2

        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)

        levels_weight_v = torch.cat(
            (level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        # Softmax over the level axis: the three weights sum to 1 per pixel.
        levels_weight = F.softmax(levels_weight, dim=1)

        # Each single-channel weight slice broadcasts over all channels.
        fused_out_reduced = (
            level_0_resized * levels_weight[:, 0:1, :, :]
            + level_1_resized * levels_weight[:, 1:2, :, :]
            + level_2_resized * levels_weight[:, 2:, :, :]
        )

        out = self.expand(fused_out_reduced)

        if self.vis:
            return out, levels_weight, fused_out_reduced.sum(dim=1)
        return out

步骤二:在YOLOX-main/yolox/models/yolo_pafpn.py中调用ASFF模块

(1)导入

from .ASFF import ASFF

(2)在init中实例化

        # ############ 2、实例化ASFF
        self.asff_1 = ASFF(level = 0, multiplier = width)
        self.asff_2 = ASFF(level = 1, multiplier = width)
        self.asff_3 = ASFF(level = 2, multiplier = width)

    def forward(self, input):

(3)直接在PAFPN输出outputs后接上ASFF模块

        outputs = (pan_out2, pan_out1, pan_out0)

        # asff
        pan_out0 = self.asff_1(outputs)
        pan_out1 = self.asff_2(outputs)
        pan_out2 = self.asff_3(outputs)
        outputs = (pan_out2, pan_out1, pan_out0)
        
        return outputs

2、在YOLOX-Darknet53FPN后添加ASFF模块

(注意:这里是用于YOLOv3的FPN)

步骤一:在YOLOX-main/yolox/models文件夹下创建ASFF_darknet.py文件,内容如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from .network_blocks import BaseConv

# 输入的是从FPN得到的特征
# 输出是要给head的特征,与之前不变,注意单个ASFF只能输出一个特征,level=0对应最底层的特征,这里是512*20*20,尺度大小 level_0 < level_1 < level_2

class ASFF(nn.Module):
    """Adaptively Spatial Feature Fusion for the 3-level Darknet53 FPN.

    An instance resamples the three FPN feature maps (512, 256 and 128
    channels) to the resolution and channel count of its target level,
    predicts per-pixel softmax weights over the three maps, and returns
    their weighted sum passed through a 3x3 conv block. One instance
    produces exactly one output level.
    """

    def __init__(self, level, rfb=False, vis=False):
        """Build the resampling, weighting and output layers.

        Args:
            level: Target level, one of 0, 1 or 2, selecting 512, 256 or
                128 channels respectively.
            rfb: If true, the per-level weight branches use 8 channels
                instead of 16 to save memory.
            vis: If true, ``forward`` also returns the fusion weights and
                the channel-summed fused map.

        Raises:
            ValueError: If ``level`` is not 0, 1 or 2.
        """
        super(ASFF, self).__init__()
        # Reject bad levels up front: a negative level would otherwise index
        # ``self.dim`` from the end and silently build a module without its
        # resampling / ``expand`` layers, failing only inside ``forward``.
        if level not in (0, 1, 2):
            raise ValueError(
                "level must be 0, 1 or 2, got {!r}".format(level))

        self.level = level
        self.dim = [512, 256, 128]
        # Channel count every input is brought to before fusion.
        self.inter_dim = self.dim[self.level]
        if level == 0:
            self.stride_level_1 = self._make_cbl(256, self.inter_dim, 3, 2)
            self.stride_level_2 = self._make_cbl(128, self.inter_dim, 3, 2)
            # Output channels match the input level (512 -> 512) so the
            # head receives the same widths as before.
            self.expand = self._make_cbl(self.inter_dim, 512, 3, 1)
        elif level == 1:
            self.compress_level_0 = self._make_cbl(512, self.inter_dim, 1, 1)
            self.stride_level_2 = self._make_cbl(128, self.inter_dim, 3, 2)
            # Output stays at 256 channels for the head.
            self.expand = self._make_cbl(self.inter_dim, 256, 3, 1)
        elif level == 2:
            self.compress_level_0 = self._make_cbl(512, self.inter_dim, 1, 1)
            self.compress_level_1 = self._make_cbl(256, self.inter_dim, 1, 1)
            # Output stays at 128 channels for the head.
            self.expand = self._make_cbl(self.inter_dim, 128, 3, 1)

        # when adding rfb, we use half number of channels to save memory
        compress_c = 8 if rfb else 16

        self.weight_level_0 = self._make_cbl(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = self._make_cbl(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = self._make_cbl(self.inter_dim, compress_c, 1, 1)

        # Plain 1x1 conv producing one logit per level.
        self.weight_levels = nn.Conv2d(
            compress_c * 3, 3, kernel_size=1, stride=1, padding=0)
        self.vis = vis

    def _make_cbl(self, _in, _out, ks, stride):
        """Return a ``BaseConv`` block using the "lrelu" activation."""
        return BaseConv(_in, _out, ks, stride, act="lrelu")

    def forward(self, x_level_0, x_level_1, x_level_2):
        """Fuse three FPN features into the target level.

        Args:
            x_level_0: 512-channel feature map (e.g. 512x20x20).
            x_level_1: 256-channel feature map (e.g. 256x40x40).
            x_level_2: 128-channel feature map (e.g. 128x80x80).

        Resampling uses fixed factors of 2 and 4, so each level is assumed
        to have twice the spatial size of the previous one.

        Returns:
            The fused feature map, or when ``self.vis`` is true the tuple
            ``(out, levels_weight, fused_out_reduced.sum(dim=1))``.
        """
        if self.level == 0:
            level_0_resized = x_level_0
            # Stride-2 conv: e.g. 256x40x40 -> 512x20x20.
            level_1_resized = self.stride_level_1(x_level_1)
            # Level 2 needs a 4x reduction: max-pool (e.g. 128x80x80 ->
            # 128x40x40) then stride-2 conv (-> 512x20x20).
            level_2_downsampled_inter = F.max_pool2d(
                x_level_2, 3, stride=2, padding=1)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
        elif self.level == 1:
            # e.g. 512x20x20 -> 256x20x20 -> 256x40x40.
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=2, mode='nearest')
            level_1_resized = x_level_1
            # e.g. 128x80x80 -> 256x40x40.
            level_2_resized = self.stride_level_2(x_level_2)
        elif self.level == 2:
            # e.g. 512x20x20 -> 128x20x20 -> 128x80x80.
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=4, mode='nearest')
            # e.g. 256x40x40 -> 128x40x40 -> 128x80x80.
            level_1_compressed = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(
                level_1_compressed, scale_factor=2, mode='nearest')
            level_2_resized = x_level_2

        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)
        levels_weight_v = torch.cat(
            (level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        # Softmax over the level axis: the three weights sum to 1 per pixel.
        levels_weight = F.softmax(levels_weight, dim=1)

        # Each single-channel weight slice broadcasts over all channels.
        fused_out_reduced = (
            level_0_resized * levels_weight[:, 0:1, :, :]
            + level_1_resized * levels_weight[:, 1:2, :, :]
            + level_2_resized * levels_weight[:, 2:, :, :]
        )

        out = self.expand(fused_out_reduced)

        if self.vis:
            return out, levels_weight, fused_out_reduced.sum(dim=1)
        return out

步骤二:在YOLOX-main/yolox/models/yolo_fpn.py中调用ASFF模块

(1)导入

from .ASFF_darknet import ASFF  ### 1、导入

(2)实例化ASFF对象

        #######################  2、实例化ASFF
        self.assf_5 = ASFF(level = 0)
        self.assf_4 = ASFF(level = 1)
        self.assf_3 = ASFF(level = 2)
        ########################

    def _make_cbl(self, _in, _out, ks):
        """Build a ``BaseConv`` block with stride 1 and "lrelu" activation.

        ``_in``, ``_out`` and ``ks`` are forwarded positionally (presumably
        input channels, output channels and kernel size; confirm against
        ``BaseConv``).
        """
        return BaseConv(_in, _out, ks, stride=1, act="lrelu")

(3)在outputs后直接添加asff

        outputs = (out_dark3, out_dark4, x0)  # 特征图尺度逐渐变小(128,256,512)  ### 该行为初始的FPN输出,使用ASFF则注释掉

        ################################################
        # 3、对FPN特征金字塔进行ASFF操作,注释掉原FPN输出outputs
        
        out_assf_5 = self.assf_5(x0, out_dark4, out_dark3)
        out_assf_4 = self.assf_4(x0, out_dark4, out_dark3)
        out_assf_3 = self.assf_3(x0, out_dark4, out_dark3)

        outputs = (out_assf_3, out_assf_4, out_assf_5)
        #################################################
        return outputs

效果:根据个人数据集而定。对我的数据集没变化。

权重大小变化:yoloxs(68.8M->110M)

速度变化:有所下降

上述代码链接
链接:https://pan.baidu.com/s/1ykfb-YHpJaLj4sQpMsCIKw
提取码:qrvg



这篇关于YOLOX改进之添加ASFF的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程