PyTorch 深度度量学习无敌 Buff:九大模块、随意调用
2021/3/19 5:11:51
本文主要是介绍PyTorch 深度度量学习无敌 Buff:九大模块、随意调用,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
内容导读从度量学习到深度度量学习,本文介绍了一个 PyTorch 中的程序包,它可以极大简化使用深度度量学习的难度。
本文首发自微信公众号「PyTorch 开发者社区」
度量学习(Metric Learning)是机器学习过程中经常用到的一种方法,它可以借助一系列观测,构造出对应的度量函数,从而学习数据间的距离或差异,有效地描述样本之间的相似度。
CUB200 数据集样本示例,常被用作度量学习的 benchmark
这个度量函数对于相似度高的观测值,会返回一个小的距离值;对于差异巨大的观测值,则会返回一个大的距离值。
当样本量不大时,度量学习在处理分类任务的准确率和高效率上,展现出了显著优势。
DML:为多类别、小样本分类而生
然而,如果要处理的分类任务十分复杂,具有多类别、小样本等特征时,结合深度学习和度量学习的深度度量学习((Deep Metric Learning,简称 DML)),才是真正的王者。
深度度量学习又被称为距离度量学习(Distance Metric Learning)。相较于度量学习,深度度量学习可以对输入特征做非线性映射。
通过训练一个基于 CNN 的非线性特征提取模块或编码器,深度度量学习可以将提取的图像特征(Embedding)嵌入到近邻位置,同时借助欧氏距离、cosine 等距离度量方法,将不同的图像特征区分开来。
接下来,深度度量学习再结合 k 最近邻、支持向量机等分类算法,就可以在不考虑类别数量的基础上,利用提取的图像特征,来完成目标识别任务了。
import numpy as np # 随机定义A, B 两个向量 A = np.random.randn(10) B = np.random.randn(10) # 欧几里得距离(Euclidean distance) dist = np.square(np.sum(A - B)**2) # 曼哈顿距离(Manhattan distance) dist = np.sum(np.abs(A - B)) # 切比雪夫距离(Chebyshev distance) dist = np.max(np.abs(A - B)) # cosine距离 similarity = (np.sum(A * B))/(np.linalg.norm(A)) / (np.linalg.norm(A))
深度度量学习中的常用距离函数
深度度量学习在 CV 领域的一些极端分类任务(类别众多、样本量不足)中表现优异,应用遍及人脸识别、行人重识别、图像检索、目标跟踪、特征匹配等场景。
以往要在程序中使用深度度量学习,主要依赖工程师从零到一写代码,不光耗时久还容易出 bug。现在则可以依赖一个封装了多个常用模块的开源库,直接进行调用,省时省力。
PML:让深度度量学习易如反掌
pytorch-metric-learning(PML)是一个开源库,可以让各种繁琐复杂的深度度量学习算法,变得更加简单友好。
pytorch-metric-learning 具有两大特点
1、易于使用
只需添加 2 行代码,就可以在程序中使用度量学习;调用单个函数,就可以挖掘 pairs 和 triplets。
2、高度灵活
融合了 loss、miner、trainer 等多个模块,可以在已有代码中实现各种算法组合。
PML 包括 9 个模块,每个模块既可以单独使用,也可以组合成一个完整的训练/测试 workflow
pytorch-metric-learning 中的 9 大模块
1、Loss:可以应用的各种损失函数
from pytorch_metric_learning.distances import CosineSimilarity from pytorch_metric_learning.reducers import ThresholdReducer from pytorch_metric_learning.regularizers import LpRegularizer from pytorch_metric_learning import losses loss_func = losses.TripletMarginLoss(distance = CosineSimilarity(), reducer = ThresholdReducer(high=0.3), embedding_regularizer = LpRegularizer())
自定义损失函数 TripletMarginLoss 代码示例
2、Distance: 包括计算 pairwise distance 或输入 embedding 之间相似性的各种类别
3、Reducer:从几个损失值变为单个损失值
4、Regularizer:对权重和嵌入向量进行正则化
5、Miner:PML 提供两种类型的挖掘函数:子集批处理 miner 及 tuple miner
from pytorch_metric_learning import miners, losses miner = miners.MultiSimilarityMiner() loss_func = losses.TripletMarginLoss() # your training loop for i, (data, labels) in enumerate(dataloader): optimizer.zero_grad() embeddings = model(data) hard_pairs = miner(embeddings, labels) loss = loss_func(embeddings, labels, hard_pairs) loss.backward() optimizer.step()
用 Tripletmurginloss 损失函数添加挖掘功能
6、Sampler:_torch.utils.data.Sampler_ 类的扩展,决定样本的 batch 的组成形态
7、Trainer:提供对度量学习算法的访问,如数据增强、附加网络等
8、Tester:输入模型和数据集,找到基于最近邻的准确度指标(使用该模块需要安装 faiss 安装包)
9、Util:
- _AccuracyCalculator_:给定一个 query 和推理嵌入向量(reference embedding),计算数个准确度指标
- _Inference model_:_utils.inference_ 包含用于在 batch 或一组 pair 中,找到匹配对(matching pairs )的类
- _Logging Preset_:提供日志数据 hook,模型训练、验证和存储期间的提前停止日志。
损失函数可以自定义使用 Distance、Reducer 及 Regularizer 三个模块
PML 上手实践
PyTorch 版本要求
pytorch-metric-learning v0.9.90 版本及以上:torch ≥ 1.6
pytorch-metric-learning v0.9.90 版本以下:没有版本要求,但是测试版本 torch ≥ 1.2
Pip
pip install pytorch-metric-learning
获得最新版本
pip install pytorch-metric-learning --pre
在 Windows 上安装
pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html pip install pytorch-metric-learning
增加评估和日志功能,需要安装 faiss-gpu 的非官方 pypi 版本
pip install pytorch-metric-learning[with-hooks]
或 faiss-CPU
pip install pytorch-metric-learning[with-hooks-cpu]
Conda
conda install pytorch-metric-learning -c metric-learning -c pytorch
GitHub 地址:
https://github.com/KevinMusgr...
Google Colab:
https://github.com/KevinMusgr...
相关论文:
https://arxiv.org/pdf/2008.09...
参考:http://html.rhhz.net/tis/html...
https://analyticsindiamag.com...
这篇关于PyTorch 深度度量学习无敌 Buff:九大模块、随意调用的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-22怎么通过控制台去看我的页面渲染的内容在哪个文件中呢-icode9专业技术文章分享
- 2024-12-22el-tabs 组件只被引用了一次,但有时会渲染两次是什么原因?-icode9专业技术文章分享
- 2024-12-22wordpress有哪些好的安全插件?-icode9专业技术文章分享
- 2024-12-22wordpress如何查看系统有哪些cron任务?-icode9专业技术文章分享
- 2024-12-21Svg Sprite Icon教程:轻松入门与应用指南
- 2024-12-20Excel数据导出实战:新手必学的简单教程
- 2024-12-20RBAC的权限实战:新手入门教程
- 2024-12-20Svg Sprite Icon实战:从入门到上手的全面指南
- 2024-12-20LCD1602显示模块详解
- 2024-12-20利用Gemini构建处理各种PDF文档的Document AI管道