PyTorch程序设计
2021/5/22 1:55:50
本文主要是介绍PyTorch程序设计,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
PyTorch程序设计
- 1. 程序示例
- 2. 加载数据
- 3. 创建网络(BN、卷积、ReLU)
- 4. 优化器(优化器比较、参数lr、权重衰减)
- 5. 损失函数比较
- 6. 模型保存与加载
1. 程序示例
import torch, os from torch.utils.data import Dataset, DataLoader import numpy as np class MyDataSet(Dataset): def __init__(self, images, labels): self.images = images self.labels = labels def __getitem__(self, item): image = self.images[item].reshape(1, 2, 2) # cwh label = self.labels[item].reshape(1, 1, 1) return image, label def __len__(self): return len(self.labels) class Net(torch.nn.Module): def __init__(self): super(Net, self).__init__() layers = [] layers.append(torch.nn.Conv2d(1, 1, kernel_size=2, stride=1, bias=True, padding=0)) self.net = torch.nn.ModuleList(layers) def forward(self, input): return self.net[0](input) def train(): # 1.输入数据,[w, h, c] = [2, 2, 1]。 x = np.random.randint(low=0, high=255, size=(50, 2, 2), dtype=np.int) x = np.divide(x, 255).astype(np.float32) # 2.权重和偏置。 w = np.array([[1, 2], [3, 4]], dtype=np.float32) b = 5 # 3.标签。 y = np.array([(np.multiply(w, x[i])).sum() + b for i in range(x.shape[0])]) # 4.加载数据、网络、优化器、损失函数。 train_data = MyDataSet(x, y) data_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False) net = Net() if os.path.exists('parameters.pth'): net.load_state_dict(torch.load('parameters.pth')) optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0005) loss_op = torch.nn.L1Loss(size_average=True, reduce=True) # 5.迭代。 for epoch in range(400): optimizer.param_groups[0]['lr'] = 0.001 # 设置学习率。 for x, y in data_loader: out = net(x) loss = loss_op(y, out) optimizer.zero_grad() loss.backward() optimizer.step() loss_numpy = loss.cpu().detach().numpy() if loss_numpy < 1e-10: break print('--epoch:', epoch, 'loss:', loss_numpy) for k, v in net.named_parameters(): print(k, v.cpu().detach().numpy()) torch.save(net.state_dict(), 'parameters.pth') # 保存参数。 torch.save(net, 'model.pth') # 保存网络结构和参数。 def testNet(): net = torch.load('model.pth') y = net(torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).reshape(1, 1, 2, 2)) print(y) def testParams(): net = Net() net.load_state_dict(torch.load('parameters.pth')) y = net(torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).reshape(1, 1, 2, 2)) print(y) if __name__ == '__main__': train() testNet() testParams()
2. 加载数据
3. 创建网络(BN、卷积、ReLU)
4. 优化器(优化器比较、参数lr、权重衰减)
5. 损失函数比较
6. 模型保存与加载
这篇关于PyTorch程序设计的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-22怎么通过控制台去看我的页面渲染的内容在哪个文件中呢-icode9专业技术文章分享
- 2024-12-22el-tabs 组件只被引用了一次,但有时会渲染两次是什么原因?-icode9专业技术文章分享
- 2024-12-22wordpress有哪些好的安全插件?-icode9专业技术文章分享
- 2024-12-22wordpress如何查看系统有哪些cron任务?-icode9专业技术文章分享
- 2024-12-21Svg Sprite Icon教程:轻松入门与应用指南
- 2024-12-20Excel数据导出实战:新手必学的简单教程
- 2024-12-20RBAC的权限实战:新手入门教程
- 2024-12-20Svg Sprite Icon实战:从入门到上手的全面指南
- 2024-12-20LCD1602显示模块详解
- 2024-12-20利用Gemini构建处理各种PDF文档的Document AI管道