Coordinate Attention + ResNet 的 PyTorch 实现
2021/5/25 10:27:32
本文主要介绍 Coordinate Attention + ResNet 的 PyTorch 实现,对大家解决编程问题具有一定的参考价值,需要的程序猿们跟随小编一起来学习吧!
# CA (coordinate attention)
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

try:
    # Third-party and only needed for the model-summary demo; keep it optional so
    # importing this module does not require torchsummary to be installed.
    from torchsummary import summary
except ImportError:
    summary = None

# torchvision ImageNet checkpoints, keyed by architecture name.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``, a piecewise-linear sigmoid."""

    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        # ``x + 3`` is a fresh tensor, so an in-place ReLU6 never modifies ``x``.
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    """Hard swish: ``x * h_sigmoid(x)``."""

    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAttention(nn.Module):
    """Coordinate Attention: input re-weighting by separate height and width gates.

    The input is average-pooled along each spatial axis separately. The two
    1-D descriptors are concatenated and encoded jointly by a shared
    1x1 conv + BatchNorm + h_swish. The result is split again and turned into a
    height gate ``(n, out_channels, H, 1)`` and a width gate
    ``(n, out_channels, 1, W)`` (sigmoid-activated) that multiply the input.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: channels of each gate. The gates are multiplied with the
            input, so this must broadcast against it (normally equal to
            ``in_channels``). Defaults to ``in_channels``.
        reduction: reduction ratio of the shared encoder; its width is
            ``max(8, in_channels // reduction)``.
    """

    def __init__(self, in_channels, out_channels=None, reduction=32):
        super(CoordAttention, self).__init__()
        if out_channels is None:
            out_channels = in_channels
        # (1, None) averages over height -> (n, c, 1, W);
        # (None, 1) averages over width  -> (n, c, H, 1).
        self.pool_w, self.pool_h = nn.AdaptiveAvgPool2d((1, None)), nn.AdaptiveAvgPool2d((None, 1))
        temp_c = max(8, in_channels // reduction)
        self.conv1 = nn.Conv2d(in_channels, temp_c, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(temp_c)
        self.act1 = h_swish()
        self.conv2 = nn.Conv2d(temp_c, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(temp_c, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``x`` multiplied by its width gate and height gate."""
        short = x
        _, _, H, W = x.shape
        x_h = self.pool_h(x)                      # (n, c, H, 1)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)  # (n, c, W, 1)
        # Stack both descriptors along dim 2 so a single conv/BN encodes them.
        y = torch.cat([x_h, x_w], dim=2)          # (n, c, H + W, 1)
        y = self.act1(self.bn1(self.conv1(y)))
        y_h, y_w = torch.split(y, [H, W], dim=2)
        y_w = y_w.permute(0, 1, 3, 2)             # back to (n, temp_c, 1, W)
        gate_h = torch.sigmoid(self.conv2(y_h))   # (n, out_channels, H, 1)
        gate_w = torch.sigmoid(self.conv3(y_w))   # (n, out_channels, 1, W)
        return short * gate_w * gate_h
# Build CA-ResNet: bottleneck residual blocks with Coordinate Attention.


class BottleneckBlock(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) with Coordinate Attention.

    The residual branch is re-weighted by ``CoordAttention`` before it is added
    to the (optionally downsampled) identity path.

    Args:
        inplanes: channels of the block input.
        planes: base width; the block outputs ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input so the identity path
            matches the residual branch.
        groups, base_width: grouped-conv settings; the inner width is
            ``int(planes * base_width / 64) * groups``.
        dilation: dilation (and padding) of the 3x3 convolution.
        norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BottleneckBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(width, width, 3, padding=dilation, stride=stride,
                               groups=groups, dilation=dilation, bias=False)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, out_planes, 1, bias=False)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride
        self.ca = CoordAttention(in_channels=out_planes, out_channels=out_planes)

    def forward(self, x):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.ca(out)  # attention on the residual branch, before the skip-add
        out += identity
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet backbone built from ``block`` with optional pooling and classifier.

    Args:
        block: residual block class exposing an ``expansion`` attribute and
            accepting ``(inplanes, planes, stride, downsample, groups,
            base_width, dilation, norm_layer)``.
        depth: key of the stage table (18, 34, 50, 101, 152); other values
            raise ``KeyError``.
        n_class: number of classifier outputs; ``<= 0`` omits ``fc``.
        with_pool: append global average pooling when True.

    Note:
        The stage table only fixes the number of blocks per stage. With
        ``BottleneckBlock``, depth 34 (``[3, 4, 6, 3]``) is the ResNet-50
        bottleneck layout, not a BasicBlock ResNet-34. With ``with_pool=False``
        and ``n_class > 0`` the flattened feature map is fed to ``fc``, which
        only fits when the final map is 1x1.
    """

    def __init__(self, block, depth, n_class=1000, with_pool=True):
        super(ResNet, self).__init__()
        layer_cfg = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }
        layers = layer_cfg[depth]
        self.num_classes = n_class
        self.with_pool = with_pool
        self._norm_layer = nn.BatchNorm2d

        self.inplanes = 64
        self.dilation = 1

        # Stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        if with_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        if n_class > 0:
            self.fc = nn.Linear(512 * block.expansion, n_class)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` blocks; only the first may stride/downsample.

        When ``dilate`` is True the stride is replaced by dilation for the
        following blocks. Updates ``self.inplanes`` to the stage's output width.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # Project the identity path when spatial size or channel count changes.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride=stride, bias=False),
                norm_layer(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample, 1, 64,
                        previous_dilation, norm_layer)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.with_pool:
            x = self.avgpool(x)

        if self.num_classes > 0:
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def ca_resnet34(**kwargs):
    """Build a CA-ResNet from ``BottleneckBlock`` with the depth-34 stage table.

    This is the ``[3, 4, 6, 3]`` bottleneck (ResNet-50) layout with Coordinate
    Attention in every block. ``kwargs`` are forwarded to ``ResNet``.
    """
    return ResNet(BottleneckBlock, 34, **kwargs)


def _select_compatible_weights(pretrained_dict, model_dict):
    """Return the checkpoint entries whose key exists in ``model_dict`` with the same shape."""
    return {k: v for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape}


def resnet_CA_instance(n_class, pretrained=False, **kwargs):
    """Build the CA-ResNet (bottleneck, depth-34 table) with an ``n_class`` head.

    Args:
        n_class: number of classifier outputs; ``<= 0`` builds the model
            without a classifier.
        pretrained: if True, initialise every layer whose name and shape match
            the torchvision ``resnet50`` ImageNet checkpoint. This checkpoint
            shares this network's bottleneck layout. The CA modules and any
            mismatching head keep their random init.
        **kwargs: forwarded to ``ResNet`` (e.g. ``with_pool``).

    Returns:
        The constructed ``ResNet``.
    """
    model = ResNet(BottleneckBlock, 34, n_class, **kwargs)
    if pretrained:
        # The BasicBlock resnet34 checkpoint cannot be loaded into bottleneck
        # blocks (tensor shapes differ), so use the layout-matching resnet50
        # weights and skip any tensor whose shape does not fit.
        pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
        model_dict = model.state_dict()
        model_dict.update(_select_compatible_weights(pretrained_dict, model_dict))
        model.load_state_dict(model_dict)
    if n_class > 0:
        # Fresh classifier head with uniform init. NOTE(review): the bound uses
        # 1000 (presumably the ImageNet class count), not the head's fan-in.
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, n_class)
        stdv = 1.0 / math.sqrt(1000)
        with torch.no_grad():
            for p in model.fc.parameters():
                p.uniform_(-stdv, stdv)
    return model


if __name__ == "__main__":
    # Demo: inspect the model, run one forward pass and print a layer summary.
    from torchsummary import summary  # third-party; only needed for this demo

    ca_res34 = ca_resnet34(n_class=15)
    print(ca_res34)
    x = torch.rand(1, 3, 224, 224)
    logits = ca_res34(x)
    print(logits.shape)
    # torchsummary defaults to device="cuda"; the model here lives on the CPU.
    summary(ca_res34, (3, 224, 224), device="cpu")
引用请注明作者名:叫我小张就行了
这篇关于Coordinate Attention +resnet+pytorch实现的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-22怎么通过控制台去看我的页面渲染的内容在哪个文件中呢-icode9专业技术文章分享
- 2024-12-22el-tabs 组件只被引用了一次,但有时会渲染两次是什么原因?-icode9专业技术文章分享
- 2024-12-22wordpress有哪些好的安全插件?-icode9专业技术文章分享
- 2024-12-22wordpress如何查看系统有哪些cron任务?-icode9专业技术文章分享
- 2024-12-21Svg Sprite Icon教程:轻松入门与应用指南
- 2024-12-20Excel数据导出实战:新手必学的简单教程
- 2024-12-20RBAC的权限实战:新手入门教程
- 2024-12-20Svg Sprite Icon实战:从入门到上手的全面指南
- 2024-12-20LCD1602显示模块详解
- 2024-12-20利用Gemini构建处理各种PDF文档的Document AI管道