Implementing Multi-Head Attention (MHA) by Hand
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        assert (
            self.head_dim * num_heads == embed_size
        ), "Embedding size needs to be divisible by num_heads"
        # per-head projections: each head's slice is mapped head_dim -> head_dim
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # output projection that merges the heads back to embed_size
        self.fc_out = nn.Linear(num_heads * self.head_dim, embed_size)
    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        # split values, keys and queries into num_heads heads
        values = values.reshape(N, value_len, self.num_heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.num_heads, self.head_dim)
        queries = query.reshape(N, query_len, self.num_heads, self.head_dim)
        # apply the per-head linear projections
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        # Scaled dot-product attention, written with Einstein summation (einsum)
        # queries: (N, query_len, heads, head_dim), keys: (N, key_len, heads, head_dim)
        # -> attention scores of shape (N, heads, query_len, key_len), i.e. Q @ K^T per head
        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            # positions where mask == 0 get a large negative score, so softmax gives them ~0 weight
            attention = attention.masked_fill(mask == 0, float("-1e20"))
        # normalize into attention weights with softmax; dividing by sqrt(head_dim)
        # (the sqrt(d_k) of scaled dot-product attention) keeps the scores from growing too large
        attention = torch.softmax(attention / (self.head_dim ** (1 / 2)), dim=3)
        # weighted sum of the values with the attention weights
        # attention: (N, heads, query_len, key_len), values: (N, value_len, heads, head_dim)
        # -> (N, query_len, heads, head_dim), then the heads are concatenated
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.num_heads * self.head_dim
        )
        out = self.fc_out(out)
        return out
embed_size = 256
num_heads = 8
values = torch.randn(64, 10, embed_size)
keys = torch.randn(64, 10, embed_size)
query = torch.randn(64, 10, embed_size)
mask = None  # optional mask
multihead_attention = MultiHeadAttention(embed_size, num_heads)
output = multihead_attention(values, keys, query, mask)
print(output.shape)  # expected output: torch.Size([64, 10, 256])
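The example above passes mask=None. The only requirement the code places on a mask is that it broadcasts against the score tensor of shape (N, num_heads, query_len, key_len) and marks blocked positions with 0. A minimal causal-mask sketch (my own addition, not part of the original snippet):

# lower-triangular causal mask: query position i may only attend to key positions <= i
causal_mask = torch.tril(torch.ones(10, 10)).reshape(1, 1, 10, 10)
output = multihead_attention(values, keys, query, causal_mask)
print(output.shape)  # still torch.Size([64, 10, 256])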
Experimental result: running the snippet prints torch.Size([64, 10, 256]), matching the expected output shape.
Einsum notation is an elegant way to express complex operations on tensors; it is essentially a small domain-specific language. Once you understand and get comfortable with einsum, it helps you write more concise and efficient code faster.
The basic rule (the Einstein summation convention): when two operands share the same index, that index is iterated over and summed, and in that case the summation sign can be omitted.
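A quick illustration of that convention (my own example, not from the original post): matrix multiplication sums over the shared index j, so C[i, k] = sum_j A[i, j] * B[j, k] becomes a one-liner:

A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.einsum("ij,jk->ik", A, B)  # the repeated index j is summed over
print(torch.allclose(C, A @ B))  # True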
When implementing an algorithm whose mathematical expression has already been worked out and just needs to be turned into code, einsum makes the translation much simpler, but the resulting code is not very readable.
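As an example of that trade-off, the score computation above, torch.einsum("nqhd,nkhd->nhqk", [queries, keys]), can be spelled out without einsum; this sketch is equivalent for per-head tensors queries and keys of shape (N, seq_len, num_heads, head_dim):

q = queries.permute(0, 2, 1, 3)                # (N, num_heads, query_len, head_dim)
k = keys.permute(0, 2, 1, 3)                   # (N, num_heads, key_len, head_dim)
scores = torch.matmul(q, k.transpose(-2, -1))  # (N, num_heads, query_len, key_len)

The einsum version is a single line that mirrors the index notation of the math, while the explicit version makes the permutes and the batched matrix multiply visible to readers who do not know einsum.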