使用PyTorch进行深度学习任务的主要流程
2021/10/17 23:12:53
本文主要是介绍使用PyTorch进行深度学习任务的主要流程,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
引入
完成一项机器学习任务时的步骤:
- 数据准备(可以导入、也可以通过爬虫爬取)。
- 数据预处理(数据格式的统一、必要的数据转换),并划分训练集和测试集。
- 选择模型,并设定损失函数和优化函数以及对应的超参数。
- 用模型拟合训练集数据,在验证集/测试集上计算模型表现。
- 利用可视化,对训练结果进行评价。
深度学习和机器学习的差异:
- 代码实现上,深度学习样本量大;batch训练策略需要在训练时每次读取固定数量的样本。
- 模型训练上,深度神经网络层数较多,有一些用于实现特定功能的层(如卷积层、池化层、批正则化层、LSTM层等),需要进行定制化。
- 训练时,深度学习需要“放入”GPU进行训练,将损失函数反向传播回网络最前面的层,同时使用优化器调整网络参数。后续计算一些指标还需要把数据“放回”CPU。
深度学习任务步骤:
1.基本配置
导入必须的包
import os import numpy as np import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader import torch.optim as optimizer
设置超参数
batch_size = 16 # batch中的样本数 lr = 1e-4 # 初始学习率 max_epochs = 100 # 训练次数
配置GPU
# 方案一:使用x,这种情况如果使用GPU不需要设置 os.environ['CUDA_VISIBLE_DEVICES'] = '1' # Only device 1 will be seen os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' # Devices 0 and 1 will be visible # 方案二:使用“device”,后续对要使用GPU的变量用.to(device)即可 device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
2.数据读入
构建Dataset
1.直接使用PyTorch仓库中准备好的数据
FashionMNIST 是一个替代 MNIST 手写数字 的图像数据集
training_data = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor() )
root
:训练/测试数据的存储路径train
:强调是训练集还是测试集download=True
:如果根路径的数据不可用,就从互联网下载数据transform
andtarget_transform
:指定特征和标签的转换
2.自定义dataset类进行数据的读取以及初始化
__init__
: 用于向类中传入外部参数,同时定义样本集__getitem__
: 用于逐个读取样本集合中的元素,可以进行一定的变换,并将返回训练/验证所需的数据__len__
: 用于返回数据集的样本数
使用DataLoader按批次读入数据
DataLoader使用iterative的方式不断读入批次数据
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
- batch_size:样本是按“批”读入的,batch_size就是每次读入的样本数
- num_workers:有多少个进程用于读取数据
- shuffle:是否将读入的数据打乱
- drop_last:对于样本最后一部分没有达到批次数的样本,不再参与训练
3.模型构建
神经网络的构造
PyTorch中神经网络构造一般是基于 Module 类的模型来完成的,它让模型构造更加灵活。
Module 类是 nn 模块里提供的一个模型构造类,是所有神经网络模块的基类,我们可以继承它来定义我们想要的模型。
下面为继承 Module 类构造多层感知机的示例。这里定义的 MLP 类重载了 Module 类的 init 函数和 forward 函数。它们分别用于创建模型参数和定义前向计算(正向传播)。
import torch from torch import nn class MLP(nn.Module): # 声明带有模型参数的层,这里声明了两个全连接层 def __init__(self, **kwargs): # 调用MLP父类Block的构造函数来进行必要的初始化。这样在构造实例例时还可以指定其他函数 super(MLP, self).__init__(**kwargs) self.hidden = nn.Linear(784, 256) self.act = nn.ReLU() self.output = nn.Linear(256,10) # 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出 def forward(self, x): o = self.act(self.hidden(x)) return self.output(o)
神经网络中常见的层
深度学习的一个魅力在于神经网络中各式各样的层,例如全连接层、卷积层、池化层与循环层等等。虽然PyTorch提供了大量常用的层,但有时候我们依然希望自定义层。
下面是使用 Module 来自定义层,从而可以被反复调用的示例:
-
不含模型参数的层
import torch from torch import nn class MyLayer(nn.Module): def __init__(self, **kwargs): super(MyLayer, self).__init__(**kwargs) def forward(self, x): return x - x.mean()
-
含模型参数的层
Parameter 类其实是 Tensor 的子类,如果一 个 Tensor 是 Parameter,那么它会自动被添加到模型的参数列表里。所以在自定义含模型参数的层时,我们应该将参数定义成 Parameter,除了直接定义成 Parameter 类外,还可以使用ParameterList 和 ParameterDict 分别定义参数的列表和字典。
class MyListDense(nn.Module): def __init__(self): super(MyListDense, self).__init__() self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)]) self.params.append(nn.Parameter(torch.randn(4, 1))) def forward(self, x): for i in range(len(self.params)): x = torch.mm(x, self.params[i]) return x net = MyListDense() print(net)
-
二维卷积层
二维卷积层将输入和卷积核做互相关运算,并加上一个标量偏差来得到输出。卷积层的模型参数包括了卷积核和标量偏差。在训练模型的时候,通常我们先对卷积核随机初始化,然后不不断迭代卷积核和偏差。
-
池化层
池化层每次对输入数据的一个固定形状窗口(又称池化窗口)中的元素计算输出。不同于卷积层里计算输入和核的互相关性,池化层直接计算池化窗口内元素的最大值或者平均值。
4.损失函数
损失函数 | 功能 | 代码 |
---|---|---|
二分类交叉熵损失函数 | 计算二分类任务时的交叉熵(Cross Entropy)函数。 | torch.nn.BCELoss() |
交叉熵损失函数 | 计算交叉熵的函数。 | torch.nn.CrossEntropyLoss() |
L1损失函数 | 计算输出y 和真实标签target 之间的差值的绝对值。 |
torch.nn.L1Loss() |
MSE损失函数 | 计算输出y 和真实标签target 之差的平方。 |
torch.nn.MSELoss() |
平滑L1 (Smooth L1)损失函数 | L1的平滑输出,其功能是减轻离群点带来的影响。 | torch.nn.SmoothL1Loss() |
目标泊松分布的负对数似然损失函数 | 泊松分布的负对数似然损失函数。 | torch.nn.PoissonNLLLoss() |
KL散度 | 计算相对熵。用于连续分布的距离度量,并且对离散采用的连续输出空间分布进行回归通常很有用。 | torch.nn.KLDivLoss() |
MarginRankingLoss | 计算两个向量之间的相似度,用于排序任务。该方法计算两组数据之间的差异。 | torch.nn.MarginRankingLoss() |
多标签边界损失函数 | 对于多标签分类问题计算损失函数。 | torch.nn.MultiLabelMarginLoss() |
二分类损失函数 | 计算二分类的 logistic 损失。 | torch.nn.SoftMarginLoss() |
多分类的折页损失 | 计算多分类的折页损失。 | torch.nn.MultiMarginLoss() |
三元组损失 | 计算三元组损失。 | torch.nn.TripletMarginLoss() |
HingEmbeddingLoss | 对输出的embedding结果做Hing损失计算。 | torch.nn.HingeEmbeddingLoss() |
余弦相似度 | 对于两个向量做余弦相似度损失计算。 | torch.nn.CosineEmbeddingLoss() |
CTC损失函数 | 用于解决时序类数据的分类。计算连续时间序列和目标序列之间的损失。 | torch.nn.CTCLoss() |
5.优化器
优化器就是根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值,使得模型输出更加接近真实标签。
Pytorch提供的优化器库torch.optim:
-
torch.optim.ASGD
-
torch.optim.Adadelta
-
torch.optim.Adagrad
-
torch.optim.Adam
-
torch.optim.AdamW
-
torch.optim.Adamax
-
torch.optim.LBFGS
-
torch.optim.RMSprop
-
torch.optim.Rprop
-
torch.optim.SGD
-
torch.optim.SparseAdam
优化器的选择是需要根据模型进行改变的,不存在绝对的好坏之分,需要多进行一些测试。
6.训练与评估
在PyTorch中设置模型状态:
- 训练状态,模型的参数支持反向传播的修改;
- 验证/测试状态,则不能修改模型参数。
model.train() # 训练状态 model.eval() # 验证/测试状态
用for循环读取DataLoader中的全部数据。
for data, label in train_loader:
之后将数据放到GPU上用于后续计算,此处以.cuda()为例
data, label = data.cuda(), label.cuda()
开始用当前批次数据做训练时,应当先将优化器的梯度置零:
optimizer.zero_grad()
之后将data送入模型中训练:
output = model(data)
根据预先定义的criterion计算损失函数:
loss = criterion(output, label)
将loss反向传播回网络:
loss.backward()
使用优化器更新模型参数:
optimizer.step()
这样一个训练过程就完成了,后续还可以计算模型准确率等指标。
验证/测试的流程基本与训练过程一致,不同点在于:
- 需要预先设置torch.no_grad,以及将model调至eval模式
- 不需要将优化器的梯度置零
- 不需要将loss反向回传到网络
- 不需要更新模型参数
7.可视化
某些任务在训练完成后,需要对一些必要的内容进行可视化,比如分类的ROC曲线,卷积网络中的卷积核,以及训练/验证过程的损失函数曲线等。
参考链接:
- Datawhale深入浅出PyTorch第三章:https://github.com/datawhalechina/thorough-pytorch/tree/main/%E7%AC%AC%E4%B8%89%E7%AB%A0%20PyTorch%E7%9A%84%E4%B8%BB%E8%A6%81%E7%BB%84%E6%88%90%E6%A8%A1%E5%9D%97
- PyTorch官方文档:https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#:~:text=PyTorch%20provides%20two%20data%20primitives%3A%20torch.utils.data.DataLoader%20and%20torch.utils.data.Dataset,Dataset%20to%20enable%20easy%20access%20to%20the%20samples.
- Dataset类的使用:https://www.jianshu.com/p/4818a1a4b5bd
这篇关于使用PyTorch进行深度学习任务的主要流程的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-11-27Rocket消息队列资料:新手入门指南
- 2024-11-27rocket消息队资料详解与入门指南
- 2024-11-27RocketMQ底层原理资料详解入门教程
- 2024-11-27RocketMQ项目开发资料:新手入门教程
- 2024-11-27RocketMQ项目开发资料详解
- 2024-11-27RocketMQ消息中间件资料入门教程
- 2024-11-27初学者指南:深入了解RocketMQ源码资料
- 2024-11-27Rocket消息队列学习入门指南
- 2024-11-26Rocket消息中间件教程:新手入门详解
- 2024-11-26RocketMQ项目开发教程:新手入门指南