03 Transformer 中的多头注意力(Multi-Head Attention)Pytorch代码实现
2022/7/27 23:22:57
本文主要是介绍03 Transformer 中的多头注意力(Multi-Head Attention)Pytorch代码实现,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
3:20 来个赞
24:43 弹幕,是否懂了
QKV 相乘(QKV 同源),QK 相乘得到相似度A,AV 相乘得到注意力值 Z
- 第一步实现一个自注意力机制
自注意力计算
def self_attention(query, key, value, dropout=None, mask=None): d_k = query.size(-1) scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # mask的操作在QK之后,softmax之前 if mask is not None: mask.cuda() scores = scores.masked_fill(mask == 0, -1e9) self_attn = F.softmax(scores, dim=-1) if dropout is not None: self_attn = dropout(self_attn) return torch.matmul(self_attn, value), self_attn
多头注意力
# PYthon/PYtorch/你看的这个模型的理论 class MultiHeadAttention(nn.Module): def __init__(self): super(MultiHeadAttention, self).__init__() def forward(self, head, d_model, query, key, value, dropout=0.1,mask=None): """ :param head: 头数,默认 8 :param d_model: 输入的维度 512 :param query: Q :param key: K :param value: V :param dropout: :param mask: :return: """ assert (d_model % head == 0) self.d_k = d_model // head self.head = head self.d_model = d_model self.linear_query = nn.Linear(d_model, d_model) self.linear_key = nn.Linear(d_model, d_model) self.linear_value = nn.Linear(d_model, d_model) # 自注意力机制的 QKV 同源,线性变换 self.linear_out = nn.Linear(d_model, d_model) self.dropout = nn.Dropout(p=dropout) self.attn = None # if mask is not None: # # 多头注意力机制的线性变换层是4维,是把query[batch, frame_num, d_model]变成[batch, -1, head, d_k] # # 再1,2维交换变成[batch, head, -1, d_k], 所以mask要在第一维添加一维,与后面的self attention计算维度一样 # mask = mask.unsqueeze(1) n_batch = query.size(0) # 多头需要对这个 X 切分成多头 # query==key==value # [b,1,512] # [b,8,1,64] # [b,32,512] # [b,8,32,64] query = self.linear_query(query).view(n_batch, -1, self.head, self.d_k).transpose(1, 2) # [b, 8, 32, 64] key = self.linear_key(key).view(n_batch, -1, self.head, self.d_k).transpose(1, 2) # [b, 8, 32, 64] value = self.linear_value(value).view(n_batch, -1, self.head, self.d_k).transpose(1, 2) # [b, 8, 32, 64] x, self.attn = self_attention(query, key, value, dropout=self.dropout, mask=mask) # [b,8,32,64] # [b,32,512] # 变为三维, 或者说是concat head x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.head * self.d_k) return self.linear_out(x)
这篇关于03 Transformer 中的多头注意力(Multi-Head Attention)Pytorch代码实现的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-11-23增量更新怎么做?-icode9专业技术文章分享
- 2024-11-23压缩包加密方案有哪些?-icode9专业技术文章分享
- 2024-11-23用shell怎么写一个开机时自动同步远程仓库的代码?-icode9专业技术文章分享
- 2024-11-23webman可以同步自己的仓库吗?-icode9专业技术文章分享
- 2024-11-23在 Webman 中怎么判断是否有某命令进程正在运行?-icode9专业技术文章分享
- 2024-11-23如何重置new Swiper?-icode9专业技术文章分享
- 2024-11-23oss直传有什么好处?-icode9专业技术文章分享
- 2024-11-23如何将oss直传封装成一个组件在其他页面调用时都可以使用?-icode9专业技术文章分享
- 2024-11-23怎么使用laravel 11在代码里获取路由列表?-icode9专业技术文章分享
- 2024-11-22怎么实现ansible playbook 备份代码中命名包含时间戳功能?-icode9专业技术文章分享