使用PyTorch加载自定义的图片及其标签
2020/11/3 15:03:40
本文主要是介绍使用PyTorch加载自定义的图片及其标签,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
由于我下载的imagenet2012验证集所有图片都在一个文件夹,所有标签数据都在一个txt里面,因此我使用了自定义的DataSet和DataLoader进行读取。
import os from torch.utils import data from PIL import Image import torch.nn as nn from torchvision import datasets, transforms import numpy as np import matplotlib.pyplot as plt
transform=transforms.Compose([ transforms.Resize((224, 224)), transforms.CenterCrop(224), transforms.ToTensor(), ])
class MyDataSet(data.Dataset): def __init__(self,root,target_transform=None): fh = open('imagenet/caffe_ilsvrc12/val.txt', 'r') imgs = [] for line in fh: line = line.rstrip() words = line.split() words[0]=os.path.join(root, words[0]) print('img path:',words[0],'label:',words[1]) imgs.append((words[0], int(words[1]))) self.imgs = imgs self.transforms = transform self.target_transform = target_transform def __getitem__(self, index): #print('index:',index) img_path,label = self.imgs[index] pil_img = Image.open(img_path).convert('L') if self.transforms: data = self.transforms(pil_img) else: pil_img = np.asarray(pil_img) data = torch.from_numpy(pil_img) return data,label def __len__(self): return len(self.imgs)
自定义的MyDataSet类继承于torch.utils.data.DataSet类。由于图片本身一部分是三通道的,一部分却是单通道的。因此如果不在读取的时候统一读入灰度图,就会报一个错误:
RuntimeError: stack expects each tensor to be equal size, but got [3, 224, 224] at entry 0 and [1, 224, 224] at entry 25
我原本是想读入彩色图,以便在下面直接进行展示。这个错误我不知道如何解决,因此统一读入时使用
pil_img = Image.open(img_path).convert('L')
读取单通道图片,在后面的展示中显示的也就是灰度图片。
train_dataset = MyDataSet('imagenet/val') print(len(train_dataset)) valid_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
for image in valid_loader: valid_image, valid_label = image[0], image[1] print('valid_label:', valid_label) print('valid_image shape', valid_image.shape) print(valid_image[0].shape) plt.imshow(valid_image[0].squeeze(), cmap='gray') plt.show() break
valid_label: tensor([658, 283, 202, 619, 32, 758, 646, 690, 100, 546, 942, 728, 343, 969, 80, 530, 296, 412, 163, 128, 858, 702, 507, 500, 303, 478, 342, 10, 524, 703, 277, 777, 600, 806, 768, 353, 718, 981, 598, 519, 413, 817, 774, 302, 263, 366, 31, 600, 48, 986, 98, 602, 409, 39, 894, 747, 200, 384, 140, 386, 191, 952, 128, 990]) valid_image shape torch.Size([64, 1, 224, 224]) torch.Size([1, 224, 224])
这篇关于使用PyTorch加载自定义的图片及其标签的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2025-01-01使用 SVN合并操作时,怎么解决冲突的情况?-icode9专业技术文章分享
- 2025-01-01告别Anaconda?试试这些替代品吧
- 2024-12-31自学记录鸿蒙API 13:实现人脸比对Core Vision Face Comparator
- 2024-12-31自学记录鸿蒙 API 13:骨骼点检测应用Core Vision Skeleton Detection
- 2024-12-31自学记录鸿蒙 API 13:实现人脸检测 Core Vision Face Detector
- 2024-12-31在C++中的双端队列是什么意思,跟消息队列有关系吗?-icode9专业技术文章分享
- 2024-12-31内存泄漏(Memory Leak)是什么,有哪些原因和优化办法?-icode9专业技术文章分享
- 2024-12-31计算机中的内存分配方式堆和栈有什么关系和特点?-icode9专业技术文章分享
- 2024-12-31QT布局器的具体使用原理和作用是什么?-icode9专业技术文章分享
- 2024-12-30用PydanticAI和Gemini 2.0构建Airflow的AI助手