【pytorch】读取RGB图片,并输入到简单的网络中进行处理
2022/4/13 6:19:47
本文主要是介绍【pytorch】读取RGB图片,并输入到简单的网络中进行处理,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
使用PIL读取RBG图片
from PIL import Image image=Image.open("./xxx.png") #读取图片 img_data = np.array(image) #将图片转换为np对象 (此时img_data的大小为 [H,W,3],其中W为图片的宽,H为图片的高,3为RGB通道数)
将三维的RGB图片增加一维成四维
为什么要增加成四维呢?
因为pytorch中的数据为tensor(张量),而张量的描述格式为(batch_size,色彩通道数量,高度,宽度)
,而一张图片一般是3维结构(高度,宽度,色彩通道数量)
,明显差一个维度,因此需要在第一个位置增加一个维度。
此外,还注意到tensor的第二个参数为通道数,而RGB的第三个才是通道数,因此需要在此处转换一下。
转换步骤:将三个通道的数据拆开,再拼起来
img_R = img_data[:,:,0] img_G = img_data[:,:,1] img_B = img_data[:,:,2] img = np.array([img_R,img_G,img_B]) # 此时img的大小为[3,H,W]
使用unsqueeze()来增加维度:x = torch.from_numpy(img).float().unsqueeze(0)
,其中的参数0是指“在第0个维度增加一维”
搭建一个简单的网络
import torch import torch.nn as nn import torch.nn.functional as F import numpy as np class NET(nn.Module): # 搭建网络结构 def __init__(self): super(NET, self).__init__() self.conv11 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3,padding=1) self.conv12 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3,padding=1) def forward(self,inputs): x11 = self.conv11(inputs) # 卷积 x11 = F.relu(x11) # relu激活 x12 = self.conv12(x11) x12 = F.relu(x12) flatten = torch.flatten(x12) # 平坦化 output = F.log_softmax(flatten) # softmax处理(使用log_softmax能够防止单纯使用softmax时的边界溢出问题) return output
输入数据到网络中
net = NET() # 实例化网络 output = net(img) # 此处的img为之前经过“转换步骤”转换过的数据
这篇关于【pytorch】读取RGB图片,并输入到简单的网络中进行处理的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-05-25Elevate Your Lead Generation Game with Maps Scraper AI
- 2024-05-15PingCAP 黄东旭参与 CCF 秀湖会议,共探开源教育未来
- 2024-05-13PingCAP 戴涛:构建面向未来的金融核心系统
- 2024-05-09flutter3.x_macos桌面os实战
- 2024-05-09Rust中的并发性:Sync 和 Send Traits
- 2024-05-08使用Ollama和OpenWebUI在CPU上玩转Meta Llama3-8B
- 2024-05-08完工标准(DoD)与验收条件(AC)究竟有什么不同?
- 2024-05-084万 star 的 NocoDB 在 sealos 上一键起,轻松把数据库编程智能表格
- 2024-05-08Mac 版Stable Diffusion WebUI的安装
- 2024-05-08解锁CodeGeeX智能问答中3项独有的隐藏技能