【pytorch】读取RGB图片,并输入到简单的网络中进行处理

2022/4/13 6:19:47

本文主要是介绍【pytorch】读取RGB图片,并输入到简单的网络中进行处理,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

使用PIL读取RBG图片

from PIL import Image
image=Image.open("./xxx.png")   #读取图片
img_data = np.array(image)      #将图片转换为np对象 (此时img_data的大小为 [H,W,3],其中W为图片的宽,H为图片的高,3为RGB通道数)

将三维的RGB图片增加一维成四维

为什么要增加成四维呢?
因为pytorch中的数据为tensor(张量),而张量的描述格式为(batch_size,色彩通道数量,高度,宽度),而一张图片一般是3维结构(高度,宽度,色彩通道数量),明显差一个维度,因此需要在第一个位置增加一个维度。

此外,还注意到tensor的第二个参数为通道数,而RGB的第三个才是通道数,因此需要在此处转换一下。

转换步骤:将三个通道的数据拆开,再拼起来

img_R = img_data[:,:,0]
img_G = img_data[:,:,1]
img_B = img_data[:,:,2]
img = np.array([img_R,img_G,img_B])   # 此时img的大小为[3,H,W]

使用unsqueeze()来增加维度:x = torch.from_numpy(img).float().unsqueeze(0),其中的参数0是指“在第0个维度增加一维”

搭建一个简单的网络

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class NET(nn.Module):       # 搭建网络结构
    def __init__(self):
        super(NET, self).__init__()
        self.conv11 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3,padding=1)
        self.conv12 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3,padding=1)

    
    def forward(self,inputs):
        x11 = self.conv11(inputs)     # 卷积
        x11 = F.relu(x11)      # relu激活
        x12 = self.conv12(x11)
        x12 = F.relu(x12)
        flatten = torch.flatten(x12)     # 平坦化
        output = F.log_softmax(flatten)     # softmax处理(使用log_softmax能够防止单纯使用softmax时的边界溢出问题)
        return output

输入数据到网络中

net = NET()   # 实例化网络
output = net(img)  # 此处的img为之前经过“转换步骤”转换过的数据


这篇关于【pytorch】读取RGB图片,并输入到简单的网络中进行处理的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程