手写多头注意力(MHA)的实现

·2713·6 分钟·
AI摘要: 本文介绍了手写多头注意力(MHA)的实现。文章首先定义了MultiHeadAttention类,该类用于处理多维输入数据并生成输出。在实现过程中,作者详细描述了如何通过线性变换、Scaled Dot-Product Attention和Softmax操作来构建注意力机制。实验结果表明,使用einsum表示法可以简化代码编写,提高可读性。

import torch

import torch.nn as nn

import torch.nn.functional as F



class MultiHeadAttention(nn.Module):

    def __init__(self, embed_size, num_heads):

        super(MultiHeadAttention, self).__init__()

        self.embed_size = embed_size

        self.num_heads = num_heads

        self.head_dim = embed_size // num_heads



        assert (

            self.head_dim * num_heads == embed_size

        ), "Embedding size needs to be divisible by num_heads"



        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)

        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)

        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)

        self.fc_out = nn.Linear(num_heads * self.head_dim, embed_size)



    def forward(self, values, keys, query, mask):

        N = query.shape[0]

        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]



        # 将values, keys, queries分成多个头

        values = values.reshape(N, value_len, self.num_heads, self.head_dim)

        keys = keys.reshape(N, key_len, self.num_heads, self.head_dim)

        queries = query.reshape(N, query_len, self.num_heads, self.head_dim)

	   # 进行线性变换

        values = self.values(values)

        keys = self.keys(keys)

        queries = self.queries(queries)



        # Scaled dot-product attention(使用了爱因斯坦求和)

        # 这里直接求出kq矩阵

        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:

            attention = attention.masked_fill(mask == 0, float("-1e20"))

		# 使用softmax来归一化成注意力分数,分母是为了防止注意力分数相差过大

        attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

		# 注意力分数和对应的value求加权和

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(

            N, query_len, self.num_heads * self.head_dim

        )



        out = self.fc_out(out)

        return out





embed_size = 256

num_heads = 8

values = torch.randn(64, 10, embed_size)

keys = torch.randn(64, 10, embed_size)

query = torch.randn(64, 10, embed_size)

mask = None  # 可选的mask



multihead_attention = MultiHeadAttention(embed_size, num_heads)

output = multihead_attention(values, keys, query, mask)

print(output.shape)  # 期望输出: torch.Size([64, 10, 256])

实验结果:

image-20240714122957257

Einsum 表示法是对张量的复杂操作的一种优雅方式,本质上是使用特定领域的语言。 一旦理解并掌握了 einsum,可以帮助我们更快地编写更简洁高效的代码。

基本内容:当两个变量具有相同的角标时,则遍历求和。在此情况下,求和号可以省略。

在实现一些算法时,数学表达式已经求出来了,需要将之转换为代码实现,使用einsum更加简单,但是可读性很差

Kaggle学习赛初探