小黑算法成长日记 23: Self-Attention and Multi-Head Attention
2022/1/2 12:37:34
The SelfAttention operation
From the perspective of a single token:
$q_i = h_i W_Q,\quad k_j = h_j W_K,\quad v_j = h_j W_V$

$e_{ij} = q_i k_j^{T}$

$\alpha_i = \mathrm{Softmax}([e_{i,1}, \ldots, e_{i,T}])$

$h'_i = \Big(\sum_{j=1}^{T} \alpha_{i,j} v_j\Big) W_O$
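These per-token formulas can be checked directly with a few tensors. Below is a minimal sketch for a single token $i$; the sequence length T, the dimensions, and the randomly initialized projection matrices are illustrative assumptions, not values from the post.

```python
import torch

# Illustrative sizes (assumptions for this sketch)
T, d_model, d_head = 5, 8, 4
h = torch.randn(T, d_model)          # hidden vectors h_1 ... h_T
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)
W_V = torch.randn(d_model, d_head)
W_O = torch.randn(d_head, d_model)

i = 0                                 # build the new representation of token i
q_i = h[i] @ W_Q                      # q_i = h_i W_Q
K = h @ W_K                           # rows are k_j = h_j W_K
V = h @ W_V                           # rows are v_j = h_j W_V
e_i = q_i @ K.T                       # e_{ij} = q_i k_j^T for j = 1..T
alpha_i = torch.softmax(e_i, dim=-1)  # alpha_i = Softmax([e_{i,1}, ..., e_{i,T}])
h_i_new = (alpha_i @ V) @ W_O         # h'_i = (sum_j alpha_{i,j} v_j) W_O
print(h_i_new.shape)                  # torch.Size([8]) -- back to d_model
```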
In matrix form:
$Q = H W_Q,\quad K = H W_K,\quad V = H W_V$

$E = Q K^{T}$

$E' = \mathrm{Softmax}(E)$

$H' = E' V$
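The matrix form is just the per-token computation stacked for all query tokens at once. A minimal sketch (same illustrative sizes as above, with a random H and random projections assumed) that checks the two forms agree:

```python
import torch

T, d_model, d_head = 5, 8, 4
H = torch.randn(T, d_model)
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)
W_V = torch.randn(d_model, d_head)

Q, K, V = H @ W_Q, H @ W_K, H @ W_V   # Q = HW_Q, K = HW_K, V = HW_V
E = Q @ K.T                           # E = QK^T, shape [T, T]
E_prime = torch.softmax(E, dim=-1)    # E' = Softmax(E), applied row-wise
H_prime = E_prime @ V                 # H' = E'V, shape [T, d_head]

# The same result computed one query token at a time, as in the previous section
rows = [torch.softmax(Q[i] @ K.T, dim=-1) @ V for i in range(T)]
print(torch.allclose(H_prime, torch.stack(rows), atol=1e-6))  # True
```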
Single-head selfAttention
```python
import math
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_model, d_head):
        super(SelfAttention, self).__init__()
        self.w_q = nn.Linear(d_model, d_head)
        self.w_k = nn.Linear(d_model, d_head)
        self.w_v = nn.Linear(d_model, d_head)
        self.w_o = nn.Linear(d_head, d_model)

    def forward(self, x):
        # x: [batch_size, max_len, d_model]
        # q, k, v: [batch_size, max_len, d_head]
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        # Swap the last two dims of k with permute (note: not reshape!)
        attn_score = torch.matmul(q, k.permute(0, 2, 1))
        attn_score = torch.softmax(attn_score, dim=-1)  # [batch_size, max_len, max_len]
        output = torch.matmul(attn_score, v)            # [batch_size, max_len, d_head]
        return self.w_o(output)


x = torch.randn(3, 9, 100)
model = SelfAttention(100, 80)
print(model(x).shape)  # torch.Size([3, 9, 100])
```
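One detail worth noting: the formulas and the module above use the raw dot-product scores, while the Transformer paper additionally divides them by sqrt(d_head) before the softmax to keep the logits well-scaled. A minimal standalone sketch of that scaled variant (the function name and tensor sizes are illustrative, not from the post):

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    # q, k, v: [batch_size, max_len, d_head]
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    weights = torch.softmax(scores, dim=-1)        # [batch_size, max_len, max_len]
    return torch.matmul(weights, v)                # [batch_size, max_len, d_head]

q = torch.randn(3, 9, 80)
k = torch.randn(3, 9, 80)
v = torch.randn(3, 9, 80)
print(scaled_dot_product_attention(q, k, v).shape)  # torch.Size([3, 9, 80])
```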
Multi-head selfAttention
```python
# Multi-head self-attention
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model=768, d_head=64):
        super(MultiHeadSelfAttention, self).__init__()
        assert d_model % d_head == 0
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.n_heads = d_model // d_head
        self.d_model = d_model
        self.d_head = d_head

    def forward(self, x, mask=None):
        batch_size = x.shape[0]
        max_len = x.shape[1]
        # Project, then split the last dimension into (n_heads, d_head)
        q = self.w_q(x).view(batch_size, max_len, self.n_heads, self.d_head)
        k = self.w_k(x).view(batch_size, max_len, self.n_heads, self.d_head)
        v = self.w_v(x).view(batch_size, max_len, self.n_heads, self.d_head)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)  # [batch_size, n_heads, max_len, d_head]
        attn_score = torch.matmul(q, k.permute(0, 1, 3, 2))  # [batch_size, n_heads, max_len, max_len]
        if mask is not None:
            # Broadcast the padding mask over heads and query positions so that
            # padded *key* positions are ignored: [batch_size, 1, 1, max_len]
            mask = mask.unsqueeze(1).unsqueeze(2)
            # Fill masked scores with a large negative value so they get ~0 weight after softmax
            attn_score = attn_score.masked_fill(mask == 0, -1e9)
        attn_score = torch.softmax(attn_score, -1)  # [batch_size, n_heads, max_len, max_len]
        out = torch.matmul(attn_score, v).permute(0, 2, 1, 3)
        out = out.contiguous().view(batch_size, max_len, -1)
        return self.w_o(out)


if __name__ == "__main__":
    x = torch.randn(2, 9, 768)
    mask = torch.tensor([
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0],
    ]).bool()
    model = MultiHeadSelfAttention()
    print(model(x, mask).shape)  # torch.Size([2, 9, 768])
```
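For comparison, PyTorch ships an equivalent built-in layer, nn.MultiheadAttention. A minimal usage sketch is below; note that the built-in layer applies the 1/sqrt(d_head) scaling and uses its own weight layout, so its outputs will not match the hand-written module numerically, and its key_padding_mask expects True at positions to ignore, the opposite of the 1-means-valid mask used above.

```python
import torch
import torch.nn as nn

# d_model = 768 with d_head = 64 corresponds to 12 heads
mha = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)

x = torch.randn(2, 9, 768)
mask = torch.tensor([
    [1, 1, 1, 0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0, 0],
]).bool()

# key_padding_mask: True = ignore this key position, so invert the valid-token mask
out, attn_weights = mha(x, x, x, key_padding_mask=~mask)
print(out.shape)  # torch.Size([2, 9, 768])
```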