paddlepaddle 9 MC Dropout的使用
2022/2/8 23:50:22
本文主要是介绍paddlepaddle 9 MC Dropout的使用,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
MC Dropout是指蒙特卡罗Dropout,其可以在不改就网络结构与增加训练的情况下在测试阶段提升模型的性能,本质就是在测试时将dropout一直处于激活阶段。对网络进行多次前向传播,由于dropout每一次激活的神经元都不同,使得每次的结果都会不一样。将多次输出的结果取平均值,可以在一定程度上提升算法的准确性,但是会降低算法的推理速度。
在paddlepaddle 中是无法在测试阶段将dropout处于激活状态的,其根本原因是paddlepaddle中的dropout中的参数列表中无法指定train与eval,其参数列表如下所示:paddle.nn.Dropout(p=0.5, axis=None, mode="upscale_in_train”, name=None)
-
p (float): 将输入节点置为0的概率, 即丢弃概率。默认: 0.5。
-
axis (int|list): 指定对输入 Tensor 进行Dropout操作的轴。默认: None。
-
mode (str): 丢弃单元的方式,有两种'upscale_in_train'和'downscale_in_infer',默认: 'upscale_in_train'。计算方法如下:
-
upscale_in_train, 在训练时增大输出结果。
-
train: out = input * mask / ( 1.0 - p )
-
inference: out = input
-
-
downscale_in_infer, 在预测时减小输出结果
-
train: out = input * mask
-
inference: out = input * (1.0 - p)
-
-
-
name (str,可选): 操作的名称(可选,默认值为None)。更多信息请参见 Name
因此,要实现dropout的激活状态只能通过model.train()来使模型中的dropout处于激活状态,但是设置model.train()后,笔者发现模型前向传播过程中的gpu占用是无法被清除,无论batch_size调多小,只要测试的数据一多,就会导致显存不够用。因此,参考模型训练过程的显存清空方式,实现MC_Dropout。
import paddle #paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 1.0}) paddle.set_flags({'FLAGS_fast_eager_deletion_mode': True })#使用快速垃圾回收策略,实际中没有任何作用 def MC_Dropout(model,data,times=10):#蒙特卡罗Dropout model.train() #为了反向传播中的清空显存功能,learning_rate为0表示不让模型进行参数更新 optim = paddle.optimizer.Adam(learning_rate=0.0,parameters=model.parameters()) loss_fn = paddle.nn.CrossEntropyLoss(soft_label=True)#使用软标签 result=[] for i in range(times): out=model(x_data) preds=paddle.nn.functional.softmax(out) result.append(preds.numpy()) #借用反向传播释放内存 loss = loss_fn(out, out) loss.backward() optim.step() optim.clear_grad() result=np.array(result)#shape为t,b,c t:times,b:batch_size,c:class_probability result=np.transpose(result,(1,2,0))#shape为b,c,t result=result.sum(axis=-1)#shape为b,c result=result.argmax(axis=-1)#shape为b,c return result model=paddle.jit.load("model/ep125_loss0.400336_acc0.9306model") model.train() Imagetest=ImageClsTestDataset((256,256),"data/data10954/cat_12_test") BATCH_SIZE=12 # 如果要加载内置数据集,将 custom_dataset 换为 train_dataset 即可 train_loader = paddle.io.DataLoader(Imagetest, batch_size=BATCH_SIZE, shuffle=False) print('=============train model=============') count=0 results=[] name_list=[] for batch_id, data in enumerate(train_loader()): x_data = data[0] names = data[1] results+=MC_Dropout(model,x_data,times=1).tolist() name_list+=names count+=len(names) print(count) print(results)
显存清空方式说明: 基于paddle.optimizer和loss函数进行假反向传播(让学习率为0,loss为0),使优化器不会对模型的参数进行实质更新。
数据加载器ImageClsTestDataset的实现:传入图片路径,自动加载到list中,无需生成txt列表
import paddle from paddle.io import Dataset from paddle.vision import transforms from PIL import Image import numpy as np import os class ImageClsTestDataset(Dataset): def __init__(self,input_shape,root): super(ImageClsTestDataset, self).__init__() self.input_shape=input_shape #ToTensor将形状为 (H x W x C)的输入数据 PIL.Image 或 numpy.ndarray 转换为 (C x H x W),并进行归一化。如果想保持形状不变,可以将参数 data_format 设置为 'HWC' #在paddle模型中数据是CHW的格式 self.preprocess_image=transforms.Compose([ transforms.Resize(input_shape), transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),data_format='CHW') ]) self.root=root self.lists=os.listdir(root) self.length = len(self.lists) def __getitem__(self, index): name = self.root+'/'+self.lists[index] image = Image.open(name) np_img = np.array(image) if len(np_img.shape)==2:#防止数据中存在灰度图 tmp=np.ones((*np_img.shape,3)) tmp[:,:,0]=np_img tmp[:,:,1]=np_img tmp[:,:,2]=np_img np_img=tmp np_img=np_img.astype(np.uint8) image = self.preprocess_image(np_img[:,:,:3])#np_img[:,:,:3]防止数据中存在RGBA的四通道数据 return image, name def __len__(self): return self.length
这篇关于paddlepaddle 9 MC Dropout的使用的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2025-01-10Rakuten 乐天积分系统从 Cassandra 到 TiDB 的选型与实战
- 2025-01-09CMS内容管理系统是什么?如何选择适合你的平台?
- 2025-01-08CCPM如何缩短项目周期并降低风险?
- 2025-01-08Omnivore 替代品 Readeck 安装与使用教程
- 2025-01-07Cursor 收费太贵?3分钟教你接入超低价 DeepSeek-V3,代码质量逼近 Claude 3.5
- 2025-01-06PingCAP 连续两年入选 Gartner 云数据库管理系统魔力象限“荣誉提及”
- 2025-01-05Easysearch 可搜索快照功能,看这篇就够了
- 2025-01-04BOT+EPC模式在基础设施项目中的应用与优势
- 2025-01-03用LangChain构建会检索和搜索的智能聊天机器人指南
- 2025-01-03图像文字理解,OCR、大模型还是多模态模型?PalliGema2在QLoRA技术上的微调与应用