对抗训练fgm和pgd原理和源码分析
2021/11/24 17:12:23
本文主要是介绍对抗训练fgm和pgd原理和源码分析,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
当前,在各大NLP竞赛中,对抗训练已然成为上分神器,尤其是fgm和pgd使用较多,下面来说说吧。对抗训练是一种引入噪声的训练方式,可以对参数进行正则化,提升模型鲁棒性和泛化能力。
fgm
FGM的全称是Fast Gradient Method, 出现于Adversarial Training Methods for Semi-supervised Text Classification这篇论文,FGM是根据具体的梯度进行scale,得到更好的对抗样本:
整个对抗训练的过程如下,伪代码如下:
- 1.计算x的前向loss、反向传播得到梯度;
- 2.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r;
- 3.计算x+r的前向loss,反向传播得到对抗的梯度,累加到(1)的梯度上;
- 4.将embedding恢复为(1)时的值;
- 5.根据(3)的梯度对参数进行更新。
fgm代码实现如下:
class FGM: def __init__(self, model: nn.Module, eps=1.): self.model = ( model.module if hasattr(model, "module") else model ) self.eps = eps self.backup = {} # only attack word embedding def attack(self, emb_name='word_embeddings'): for name, param in self.model.named_parameters(): if param.requires_grad and emb_name in name: self.backup[name] = param.data.clone() norm = torch.norm(param.grad) if norm and not torch.isnan(norm): r_at = self.eps * param.grad / norm param.data.add_(r_at) def restore(self, emb_name='word_embeddings'): for name, para in self.model.named_parameters(): if para.requires_grad and emb_name in name: assert name in self.backup para.data = self.backup[name] self.backup = {}
fgm应用代码如下:
##对应第一步 loss = model(**batch_data)[0] loss.backward() ##对应第二步 fgm.attack() #对应第三步 loss_adv = model(**batch_data)[0] loss_adv.backward() #对应第四步 fgm.restore() #对应第五步 optimizer.step()
pgd
FGM直接通过epsilon参数一下子算出了对抗扰动,这样得到的可能不是最优的。因此PGD进行了改进,多迭代几次,慢慢找到最优的扰动。
引用:
FGM简单粗暴的“一步到位”,可能走不到约束内的最优点。PGD则是“小步走,多走几步”,如果走出了扰动半径为epsilon的空间,就映射回“球面”上,以保证扰动不要过大
并且
pgd整个对抗训练的过程如下,伪代码如下:
- 1.计算x的前向loss、反向传播得到梯度并备份;
- 2.对于每步t:
-
a.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r(超出范围则投影回epsilon内);
-
if t 不是最后一步,则进行b步骤:将模型梯度归0,根据a的x+r计算前后向并得到梯度,继续a步骤;if t 是最后一步,则进行c步骤:恢复(1)的梯度,根据a的x+r计算前后向得到梯度并将梯度累加到(1)的梯度上,跳出循环;
- 3.将embedding恢复为(1)时的值;
- 4.根据2c的梯度对参数进行更新。
可以看到,在循环中r是逐渐累加的,要注意的是最后更新参数只使用最后一个x+r算出来的梯度。
pgd代码实现如下:
class PGD: def __init__(self, model, eps=1., alpha=0.3): self.model = ( model.module if hasattr(model, "module") else model ) self.eps = eps self.alpha = alpha self.emb_backup = {} self.grad_backup = {} def attack(self, emb_name='word_embeddings', is_first_attack=False): for name, param in self.model.named_parameters(): if param.requires_grad and emb_name in name: if is_first_attack: self.emb_backup[name] = param.data.clone() norm = torch.norm(param.grad) if norm != 0 and not torch.isnan(norm): r_at = self.alpha * param.grad / norm param.data.add_(r_at) param.data = self.project(name, param.data) def restore(self, emb_name='word_embeddings'): for name, param in self.model.named_parameters(): if param.requires_grad and emb_name in name: assert name in self.emb_backup param.data = self.emb_backup[name] self.emb_backup = {} def project(self, param_name, param_data): r = param_data - self.emb_backup[param_name] if torch.norm(r) > self.eps: r = self.eps * r / torch.norm(r) return self.emb_backup[param_name] + r def backup_grad(self): for name, param in self.model.named_parameters(): if param.requires_grad and param.grad is not None: self.grad_backup[name] = param.grad.clone() def restore_grad(self): for name, param in self.model.named_parameters(): if param.requires_grad and param.grad is not None: param.grad = self.grad_backup[name]
pgd应用代码如下:
loss = model(**batch_data)[0] loss.backward() pgd.backup_grad() for _t in range(pgd_k): pgd.attack(is_first_attack=(_t == 0)) if _t != pgd_k - 1: model.zero_grad() else: pgd.restore_grad() loss_adv = model(**batch_data)[0] loss_adv.backward() pgd.restore() optimizer.step()
注:在torch中,每次迭代时,如果不把模型的梯度清零,会默认将模型每次迭代的梯度累加的。
这篇关于对抗训练fgm和pgd原理和源码分析的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2025-01-11cursor试用出现:Too many free trial accounts used on this machine 的解决方法
- 2025-01-11百万架构师第十四课:源码分析:Spring 源码分析:深入分析IOC那些鲜为人知的细节|JavaGuide
- 2025-01-11不得不了解的高效AI办公工具API
- 2025-01-102025 蛇年,J 人直播带货内容审核团队必备的办公软件有哪 6 款?
- 2025-01-10高效运营背后的支柱:文档管理优化指南
- 2025-01-10年末压力山大?试试优化你的文档管理
- 2025-01-10跨部门协作中的进度追踪重要性解析
- 2025-01-10总结 JavaScript 中的变体函数调用方式
- 2025-01-10HR团队如何通过数据驱动提升管理效率?6个策略
- 2025-01-10WBS实战指南:如何一步步构建高效项目管理框架?