pytorch 中 DataLoader 和 Dataset 的使用
2021/7/26 23:09:52
本文主要是介绍pytorch 中 DataLoader 和 Dataset 的使用,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
加载顺序
pytorch中加载数据的顺序是:
①创建一个dataset对象
②创建一个dataloader对象
③循环调用dataloader对象,获取data,label数据拿到模型中去训练
Dataset
你需要自己定义一个class继承父类Dataset,其中至少需要重写以下3个函数:
①__init__:传入数据,或者加载数据
②__len__:返回这个数据集一共有多少个item
③__getitem__: 返回一条训练数据,并将其转换成tensor
示例代码:
class MyData(Dataset): def __init__(self, x_patches, y_patches, transform = None): self.y_patches = clean_patches self.x_patches = blur_patches self.transform = transform def __len__(self): return len(self.y_patches) def __getitem__(self, idx): y_image = self.y_patches[idx] x_image = self.x_patches[idx] y_image = np.asarray(y_image) x_image = np.asarray(x_image) y_image = Image.fromarray(y_image.astype(np.uint8)) x_image = Image.fromarray(x_image.astype(np.uint8)) if self.transform: y_image = self.transform(y_image) x_image = self.transform(x_image) return x_image, y_image
DataLoader
参数:
dataset:传入的数据
shuffle = True:是否打乱数据
collate_fn:这个参数可以自己操作每个batch的数据 参考:https://blog.csdn.net/kahuifu/article/details/108654421
示例代码:
dataset = MyData(x_patches, y_patches, transform=transforms.Compose( [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])) bs = 16 data_loader = DataLoader(dataset, batch_size=bs, shuffle=True) num_batches = len(data_loader)
调用DateLoader
最后循环调用dataloader ,拿到数据放入模型进行训练
for n_batch, (x_batch, y_batch) in enumerate(data_loader): x_data = x_batch.float().cuda() y_data = y_batch.float().cuda()
这篇关于pytorch 中 DataLoader 和 Dataset 的使用的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-22怎么通过控制台去看我的页面渲染的内容在哪个文件中呢-icode9专业技术文章分享
- 2024-12-22el-tabs 组件只被引用了一次,但有时会渲染两次是什么原因?-icode9专业技术文章分享
- 2024-12-22wordpress有哪些好的安全插件?-icode9专业技术文章分享
- 2024-12-22wordpress如何查看系统有哪些cron任务?-icode9专业技术文章分享
- 2024-12-21Svg Sprite Icon教程:轻松入门与应用指南
- 2024-12-20Excel数据导出实战:新手必学的简单教程
- 2024-12-20RBAC的权限实战:新手入门教程
- 2024-12-20Svg Sprite Icon实战:从入门到上手的全面指南
- 2024-12-20LCD1602显示模块详解
- 2024-12-20利用Gemini构建处理各种PDF文档的Document AI管道