unet 网络
2021/10/31 23:13:39
本文主要是介绍unet 网络,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
问题1:'Keyword argument not understood:', 'input
删去input=,output=
论文的几个创新点:数据增强,
论文用的一些方法:数据增强,数据的二值化
神经网络的训练:
遇到的一些问题:1、读取数据时一开始用libtiff读取tif格式图片一直导入不成功,后来更换成cv.imgread
2.tensorflow 调用adam函数时出现错误,通过查阅资料得知是Adam.keras版本需要匹配。
网络的版本号
主要结构:一个语义分割模型,encoder-decoder结构,u字形(论文中的输入大小是512*512,但这副图里给的是572*572,图片数据经过处理
特点:U型结构和skip-connection
蓝/白色框表示 feature map;蓝色箭头表示 3x3 卷积,用于特征提取;灰色箭头表示 skip-connection,用于特征融合;红色箭头表示池化 pooling,用于降低维度;绿色箭头表示上采样 upsample,用于恢复维度;青色箭头表示 1x1 卷积,用于输出结果。
UNet的encoder下采样4次,一共下采样16倍,对称地,其decoder也相应上采样4次,将encoder得到的高级语义特征图恢复到原图片的分辨率。
Skip connection:打破了网络的对称性,提升了网络的表征能力,关于对称性引发的特征退化问题 残差连接(skip connect)/(residual connections)_赵凯月的博客-CSDN博客
医疗影像有什么样的特点:图像语义较为简单、结构较为固定。我们做脑的,就用脑CT和脑MRI,做胸片的只用胸片CT,做眼底的只用眼底OCT,都是一个固定的器官的成像,而不是全身的。由于器官本身结构固定和语义信息没有特别丰富,所以高级语义信息和低级特征都显得很重要(UNet的skip connection(残差连接)和U型结构就派上了用场)。举两个例子直观感受下。
U-net 基于pytorch的代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
class double_conv2d_bn(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, strides=1, padding=1):
初始化网络
super(double_conv2d_bn, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size,
stride=strides, padding=padding, bias=True)
self.conv2 = nn.Conv2d(out_channels, out_channels,
kernel_size=kernel_size,
stride=strides, padding=padding, bias=True)
self.bn1 = nn.BatchNorm2d(out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
return out
class deconv2d_bn(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=2, strides=2):
super(deconv2d_bn, self).__init__()
self.conv1 = nn.ConvTranspose2d(in_channels, out_channels,
kernel_size=kernel_size,
stride=strides, bias=True)
self.bn1 = nn.BatchNorm2d(out_channels)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
return out
class Unet(nn.Module):
def __init__(self):
super(Unet, self).__init__()
self.layer1_conv = double_conv2d_bn(1, 8)
self.layer2_conv = double_conv2d_bn(8, 16)
self.layer3_conv = double_conv2d_bn(16, 32)
self.layer4_conv = double_conv2d_bn(32, 64)
self.layer5_conv = double_conv2d_bn(64, 128)
self.layer6_conv = double_conv2d_bn(128, 64)
self.layer7_conv = double_conv2d_bn(64, 32)
self.layer8_conv = double_conv2d_bn(32, 16)
self.layer9_conv = double_conv2d_bn(16, 8)
self.layer10_conv = nn.Conv2d(8, 1, kernel_size=3,
stride=1, padding=1, bias=True)
self.deconv1 = deconv2d_bn(128, 64)
self.deconv2 = deconv2d_bn(64, 32)
self.deconv3 = deconv2d_bn(32, 16)
self.deconv4 = deconv2d_bn(16, 8)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
conv1 = self.layer1_conv(x)
pool1 = F.max_pool2d(conv1, 2)
conv2 = self.layer2_conv(pool1)
pool2 = F.max_pool2d(conv2, 2)
conv3 = self.layer3_conv(pool2)
pool3 = F.max_pool2d(conv3, 2)
conv4 = self.layer4_conv(pool3)
pool4 = F.max_pool2d(conv4, 2)
conv5 = self.layer5_conv(pool4)
convt1 = self.deconv1(conv5)
concat1 = torch.cat([convt1, conv4], dim=1)
conv6 = self.layer6_conv(concat1)
convt2 = self.deconv2(conv6)
concat2 = torch.cat([convt2, conv3], dim=1)
conv7 = self.layer7_conv(concat2)
convt3 = self.deconv3(conv7)
concat3 = torch.cat([convt3, conv2], dim=1)
conv8 = self.layer8_conv(concat3)
convt4 = self.deconv4(conv8)
concat4 = torch.cat([convt4, conv1], dim=1)
conv9 = self.layer9_conv(concat4)
outp = self.layer10_conv(conv9)
outp = self.sigmoid(outp)
return outp
model = Unet()
inp = torch.rand(10, 1, 224, 224)
outp = model(inp)
print(outp.shape)
这篇关于unet 网络的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-11-23Springboot应用的多环境打包入门
- 2024-11-23Springboot应用的生产发布入门教程
- 2024-11-23Python编程入门指南
- 2024-11-23Java创业入门:从零开始的编程之旅
- 2024-11-23Java创业入门:新手必读的Java编程与创业指南
- 2024-11-23Java对接阿里云智能语音服务入门详解
- 2024-11-23Java对接阿里云智能语音服务入门教程
- 2024-11-23JAVA对接阿里云智能语音服务入门教程
- 2024-11-23Java副业入门:初学者的简单教程
- 2024-11-23JAVA副业入门:初学者的实战指南