Multi-head self-attention:
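Both the multi-head version here and the simplified single-head version further down compute standard scaled dot-product attention; for reference (d_k is the per-head dimension, head_dim in the code):

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V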
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        assert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"
        
        # Linear projections for Q, K, V
        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        
        # Output projection layer
        self.fc_out = nn.Linear(embed_size, embed_size)
    
    def forward(self, x):
        # x shape: (N, seq_len, embed_size)
        N = x.shape[0]
        seq_len = x.shape[1]
        
        # Linear projections to obtain Q, K, V
        values = self.values(x)  # (N, seq_len, embed_size)
        keys = self.keys(x)      # (N, seq_len, embed_size)
        queries = self.queries(x) # (N, seq_len, embed_size)
        
        # Split the embedding into multiple heads
        values = values.reshape(N, seq_len, self.heads, self.head_dim)
        keys = keys.reshape(N, seq_len, self.heads, self.head_dim)
        queries = queries.reshape(N, seq_len, self.heads, self.head_dim)
        
        # Compute attention scores (query-key dot products)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, seq_len, heads, head_dim)
        # keys shape: (N, seq_len, heads, head_dim)
        # energy shape: (N, heads, seq_len, seq_len)
        
        # Scaled dot-product attention: scale scores by sqrt(d_k) = sqrt(head_dim)
        attention = F.softmax(energy / (self.head_dim ** 0.5), dim=3)
        
        # Apply the attention weights to the values
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
        # attention shape: (N, heads, seq_len, seq_len)
        # values shape: (N, seq_len, heads, head_dim)
        # out shape: (N, seq_len, heads, head_dim)
        
        # Merge the heads back into a single embedding
        out = out.reshape(N, seq_len, self.embed_size)
        
        # Final linear output projection
        out = self.fc_out(out)
        
        return out
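A minimal usage sketch (the batch size, sequence length, embed_size=256 and heads=8 below are arbitrary example values, not from the original code) to check that the module preserves the input shape:

# Quick shape check with made-up sizes: batch 2, seq_len 10, embed_size 256, 8 heads
attn = SelfAttention(embed_size=256, heads=8)
x = torch.randn(2, 10, 256)   # (N, seq_len, embed_size)
print(attn(x).shape)          # torch.Size([2, 10, 256])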
Simplified to a single head:
class SimpleSelfAttention(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.q = nn.Linear(embed_size, embed_size)
        self.k = nn.Linear(embed_size, embed_size)
        self.v = nn.Linear(embed_size, embed_size)
        
    def forward(self, x):
        # x shape: (N, seq_len, embed_size)
        Q = self.q(x)
        K = self.k(x)
        V = self.v(x)
        
        # Compute scaled attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (x.size(-1) ** 0.5)
        attention = F.softmax(scores, dim=-1)
        
        # Apply the attention weights
        out = torch.matmul(attention, V)
        return out
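The same kind of sanity check for the single-head module (again with made-up sizes); with a single head, the two classes differ only in that SelfAttention adds the head split/merge and a final output projection:

# Quick shape check with made-up sizes: batch 2, seq_len 10, embed_size 256
simple_attn = SimpleSelfAttention(embed_size=256)
x = torch.randn(2, 10, 256)   # (N, seq_len, embed_size)
print(simple_attn(x).shape)   # torch.Size([2, 10, 256])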

