Person_reID_baseline_pytorch 源码解析之 test.py
2022/1/6 12:03:38
本文主要是介绍Person_reID_baseline_pytorch 源码解析之 test.py,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
源码中有两个用于测试的脚本: test.py 和 evaluate_gpu.py 。其中, test.py 加载通过脚本 train.py 训练好的模型,实现对 query 和 gallery 图片的特征提取;本文对脚本 test.py 进行解析。
1. 加载模型和数据
首先需要载入训练好的模型,这里以基于 Resnet50 输出类别为 751 类的行人重识别模型 ft_net 为例。
model_structure = ft_net(751) model = load_network(model_structure)
然后需要载入经过预处理的 gallery 和 query 数据集
data_transforms = transforms.Compose([ transforms.Resize((256,128), interpolation=3), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']} dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, shuffle=False, num_workers=0) for x in ['gallery','query']}
加载预处理过的数据集和训练好的模型,然后使用函数 extract_feature 进行特征提取
with torch.no_grad(): gallery_feature = extract_feature(model,dataloaders['gallery']) query_feature = extract_feature(model,dataloaders['query'])
2. 完成特征提取
extract_feature 是 test.py 中非常重要的一个函数,用于提取图片的特征,下面对它逐行解析
def extract_feature(model,dataloaders): features = torch.FloatTensor() count = 0 # 加载数据集 for data in dataloaders: img, label = data n, c, h, w = img.size() count += n # 统计数据集图片数量 print(count) ff = torch.FloatTensor(n,512).zero_().cuda() for i in range(2): if(i==1): # 翻转图片 img = fliplr(img) # 将图片变成 Variable,准备加载到网络中 input_img = Variable(img.cuda()) # 缩放尺寸 multiple_scale for scale in ms: if scale != 1: # bicubic is only available in pytorch>= 1.1 input_img = nn.functional.interpolate(input_img, scale_factor=scale, mode='bicubic', align_corners=False) # 模型推理 outputs = model(input_img) # 拼接多尺度预测结果 ff += outputs # norm feature 特征归一化 fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) ff = ff.div(fnorm.expand_as(ff)) # 返回提取到的特征 features = torch.cat((features,ff.data.cpu()), 0) return features
3. 实现特征归一化
fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
这里是在输入张量 ff 的第 1 维进行 L2-norm,即 2 范数归一化。特征向量中每个元素均除以向量的L2范数。
pytorch 中使用 torch.norm 计算张量的范数。
fnorm = torch.norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None)
- input 输入张量
- p 是范数计算中的幂指数值,p = 2 时即为 2 范数
- dim 指定计算的维度,如果 dim 是整数值,则计算向量范数。当输入张量 input 超过2维,将在最后一维计算向量范数
- keepdim 指明是否保留输出张量的维度dim
- out 输出张量
- dtype 返回张量的期待数据类型
令特征向量除以向量的L2范数,expand_as 函数将范数 fnorm 扩展成张量 ff 相同的维度。
ff = ff.div(fnorm.expand_as(ff))
然后使用 tensor.div 完成除法。
Tensor.div(value, *, rounding_mode=None)
最后,使用 torch.cat 在第 0 维上拼接输入张量
features = torch.cat((features,ff.data.cpu()), 0)
4. 生成 Matlab 文件
通过上述步骤实现了 query 和 gallery 图片特征的提取,将特征矩阵存储到 pytorch_result.mat 文件中。
# Save to Matlab for check result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam} scipy.io.savemat('pytorch_result.mat',result)
为了评估模型效果,还要记录图片的 label 和 camera 。
这里使用 get_id 函数通过图片名称获取 label 和 camera 信息。
def get_id(img_path): camera_id = [] labels = [] for path, v in img_path: #filename = path.split('/')[-1] filename = os.path.basename(path) label = filename[0:4] camera = filename.split('c')[1] if label[0:2]=='-1': labels.append(-1) else: labels.append(int(label)) camera_id.append(int(camera[0])) return camera_id, labels gallery_path = image_datasets['gallery'].imgs query_path = image_datasets['query'].imgs gallery_cam,gallery_label = get_id(gallery_path) query_cam,query_label = get_id(query_path)
生成的 Matlab 文件将被脚本 evaluate_gpu.py 使用,用于计算模型的评估指标。
参考链接
- pytorch求范数函数——torch.norm
- pytorch torch.norm 文档
- Pytorch expand_as()函数
- torch.cat()函数的官方解释,详解以及例子
- torch.stack()的官方解释,详解以及例子
这篇关于Person_reID_baseline_pytorch 源码解析之 test.py的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-23DevExpress 怎么实现右键菜单(Context Menu)显示中文?-icode9专业技术文章分享
- 2024-12-22怎么通过控制台去看我的页面渲染的内容在哪个文件中呢-icode9专业技术文章分享
- 2024-12-22el-tabs 组件只被引用了一次,但有时会渲染两次是什么原因?-icode9专业技术文章分享
- 2024-12-22wordpress有哪些好的安全插件?-icode9专业技术文章分享
- 2024-12-22wordpress如何查看系统有哪些cron任务?-icode9专业技术文章分享
- 2024-12-21Svg Sprite Icon教程:轻松入门与应用指南
- 2024-12-20Excel数据导出实战:新手必学的简单教程
- 2024-12-20RBAC的权限实战:新手入门教程
- 2024-12-20Svg Sprite Icon实战:从入门到上手的全面指南
- 2024-12-20LCD1602显示模块详解