Transformer - 注意力机制 代码实现
flyfish
计算过程
flyfish
# -*- coding: utf-8 -*-
import math
import os  # NOTE(review): not used by this script; kept in case other code relies on it.

import torch
import torch.nn as nn
import torch.nn.functional as F


def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Returns ``softmax(Q K^T / sqrt(d_k)) V``, where ``d_k`` is the size of the
    last dimension of ``query`` (normally the word-embedding dimension).

    Args:
        query, key, value: tensors whose last two dimensions are
            (sequence, features); leading dimensions are combined by
            ``torch.matmul`` broadcasting.
        mask: optional tensor broadcastable to the score matrix. Positions
            where ``mask == 0`` get a score of -1e9 before the softmax, so
            they receive (almost) zero weight.
        dropout: optional callable (e.g. ``nn.Dropout``) applied to the
            attention weights.

    Returns:
        A tuple ``(output, p_attn)``: the attention-weighted sum of ``value``
        and the attention weight matrix.
    """
    d_k = query.size(-1)
    # Dividing by sqrt(d_k) is what makes this *scaled* dot-product
    # attention; without it this would be plain dot-product attention.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # Debug output kept from the original demo,
    # e.g. torch.Size([1, 12, 12]) for a (1, 12, 8) input.
    print("scores.shape:", scores.shape)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    Even feature columns hold ``sin(pos * 10000^(-2i/d_model))`` and odd
    columns hold the matching ``cos`` term. The table is precomputed for
    ``max_len`` positions and stored as a (non-trainable) buffer named ``pe``.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        # Column vector of positions, shape (max_len, 1).
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: exp(-2i * ln(10000) / d_model).
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 frequencies feed the cos terms.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Shape (1, max_len, d_model) so it broadcasts over the batch dimension.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        ``x`` is expected to be (batch, seq_len, d_model) with
        ``seq_len <= max_len``; the buffer never requires gradients.
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


if __name__ == "__main__":
    # The attention demo runs on positional encodings alone.
    d_model = 8  # word-embedding dimension
    dropout = 0.1  # dropout (zeroing) probability
    max_len = 12  # maximum sentence length

    x = torch.zeros(1, max_len, d_model)
    pe = PositionalEncoding(d_model, dropout, max_len)
    pe_result = pe(x)
    print("pe_result:", pe_result)
    # Q = K = V: this is self-attention.
    query = key = value = pe_result
    print("pe_result.shape:", pe_result.shape)  # torch.Size([1, 12, 8])

    # Output without a mask.
    attn, p_attn = attention(query, key, value)
    print("no mask\n")
    print("attn:", attn)
    print("p_attn:", p_attn)

    # Output with a mask. An all-zero mask blocks every position, so every
    # score becomes -1e9 and the softmax gives uniform weights (1 / max_len).
    print("mask\n")
    mask = torch.zeros(1, max_len, max_len)
    attn, p_attn = attention(query, key, value, mask=mask)
    print("attn:", attn)
    print("p_attn:", p_attn)
pe_result: tensor([[[ 0.0000e+00, 1.1111e+00, 0.0000e+00, 1.1111e+00, 0.0000e+00, 1.1111e+00, 0.0000e+00, 1.1111e+00], [ 9.3497e-01, 6.0034e-01, 1.1093e-01, 1.1056e+00, 1.1111e-02, 1.1111e+00, 1.1111e-03, 1.1111e+00], [ 1.0103e+00, -4.6239e-01, 2.2074e-01, 1.0890e+00, 2.2221e-02, 0.0000e+00, 2.2222e-03, 1.1111e+00], [ 1.5680e-01, -1.1000e+00, 0.0000e+00, 1.0615e+00, 3.3328e-02, 0.0000e+00, 3.3333e-03, 1.1111e+00], [-8.4089e-01, -7.2627e-01, 4.3269e-01, 1.0234e+00, 4.4433e-02, 1.1102e+00, 4.4444e-03, 1.1111e+00], [-1.0655e+00, 3.1518e-01, 5.3270e-01, 0.0000e+00, 5.5532e-02, 1.1097e+00, 5.5555e-03, 1.1111e+00], [-3.1046e-01, 1.0669e+00, 6.2738e-01, 9.1704e-01, 0.0000e+00, 1.1091e+00, 6.6666e-03, 0.0000e+00], [ 7.2999e-01, 8.3767e-01, 7.1580e-01, 8.4982e-01, 7.7714e-02, 1.1084e+00, 7.7777e-03, 1.1111e+00], [ 1.0993e+00, -1.6167e-01, 7.9706e-01, 7.7412e-01, 8.8794e-02, 1.1076e+00, 8.8888e-03, 1.1111e+00], [ 4.5791e-01, -0.0000e+00, 8.7036e-01, 6.9068e-01, 9.9865e-02, 1.1066e+00, 9.9999e-03, 1.1111e+00], [-6.0447e-01, -9.3230e-01, 9.3497e-01, 6.0034e-01, 1.1093e-01, 1.1056e+00, 1.1111e-02, 1.1111e+00], [-1.1111e+00, 4.9174e-03, 9.9023e-01, 5.0400e-01, 1.2198e-01, 1.1044e+00, 1.2222e-02, 1.1110e+00]]]) pe_result.shape: torch.Size([1, 12, 8]) scores.shape: torch.Size([1, 12, 12]) no mask attn: tensor([[[ 1.0590e-01, 2.7361e-01, 4.9333e-01, 8.3999e-01, 5.0599e-02, 1.0079e+00, 5.6491e-03, 1.0138e+00], [ 2.7554e-01, 2.0916e-01, 4.9203e-01, 8.6593e-01, 5.2177e-02, 9.7066e-01, 5.6513e-03, 1.0398e+00], [ 2.8765e-01, -3.8825e-02, 4.7812e-01, 8.7535e-01, 5.4246e-02, 8.4157e-01, 5.7015e-03, 1.0659e+00], [ 9.3666e-02, -1.8286e-01, 4.8727e-01, 8.5124e-01, 5.7070e-02, 8.2547e-01, 5.9523e-03, 1.0712e+00], [-1.6747e-01, -1.0274e-01, 5.6960e-01, 7.7584e-01, 6.3699e-02, 9.6958e-01, 6.7169e-03, 1.0546e+00], [-2.2646e-01, 6.8462e-02, 5.8668e-01, 7.2227e-01, 6.3119e-02, 1.0233e+00, 6.8004e-03, 1.0310e+00], [ 8.8945e-04, 2.7654e-01, 5.3750e-01, 8.0958e-01, 5.2289e-02, 1.0259e+00, 6.1360e-03, 
9.6094e-01], [ 2.2231e-01, 2.2832e-01, 5.2263e-01, 8.4111e-01, 5.4828e-02, 9.9655e-01, 5.9765e-03, 1.0298e+00], [ 2.6388e-01, 7.2239e-02, 5.3800e-01, 8.4070e-01, 5.8958e-02, 9.5033e-01, 6.2306e-03, 1.0564e+00], [ 1.2822e-01, 7.4518e-02, 5.5305e-01, 8.1381e-01, 6.0125e-02, 9.7442e-01, 6.4089e-03, 1.0462e+00], [-1.5757e-01, -1.3194e-01, 5.9562e-01, 7.6069e-01, 6.7079e-02, 9.7264e-01, 7.0187e-03, 1.0607e+00], [-2.3505e-01, 5.6245e-03, 6.0160e-01, 7.3040e-01, 6.5491e-02, 1.0176e+00, 7.0038e-03, 1.0367e+00]]]) p_attn: tensor([[[0.1488, 0.1215, 0.0514, 0.0396, 0.0698, 0.0703, 0.0875, 0.1205, 0.0790, 0.0814, 0.0544, 0.0757], [0.1170, 0.1434, 0.0757, 0.0489, 0.0590, 0.0460, 0.0642, 0.1304, 0.1161, 0.0943, 0.0527, 0.0524], [0.0716, 0.1094, 0.1341, 0.1067, 0.0716, 0.0379, 0.0407, 0.0930, 0.1221, 0.0921, 0.0713, 0.0494], [0.0597, 0.0765, 0.1155, 0.1397, 0.1127, 0.0506, 0.0359, 0.0627, 0.0918, 0.0806, 0.1056, 0.0688], [0.0692, 0.0607, 0.0509, 0.0740, 0.1475, 0.0846, 0.0509, 0.0607, 0.0692, 0.0788, 0.1342, 0.1194], [0.0887, 0.0601, 0.0343, 0.0423, 0.1076, 0.1341, 0.0721, 0.0748, 0.0591, 0.0777, 0.1057, 0.1435], [0.1232, 0.0938, 0.0411, 0.0335, 0.0722, 0.0804, 0.1351, 0.1103, 0.0722, 0.0814, 0.0633, 0.0935], [0.1124, 0.1263, 0.0623, 0.0388, 0.0571, 0.0553, 0.0731, 0.1388, 0.1134, 0.1001, 0.0571, 0.0652], [0.0758, 0.1157, 0.0841, 0.0584, 0.0670, 0.0450, 0.0492, 0.1166, 0.1429, 0.1101, 0.0763, 0.0588], [0.0822, 0.0989, 0.0668, 0.0540, 0.0803, 0.0622, 0.0584, 0.1084, 0.1158, 0.1046, 0.0879, 0.0804], [0.0548, 0.0551, 0.0515, 0.0705, 0.1364, 0.0845, 0.0454, 0.0617, 0.0801, 0.0877, 0.1499, 0.1224], [0.0763, 0.0548, 0.0357, 0.0459, 0.1213, 0.1146, 0.0669, 0.0703, 0.0616, 0.0802, 0.1224, 0.1499]]]) mask scores.shape: torch.Size([1, 12, 12]) attn: tensor([[[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185], [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185], [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185], [0.0381, 0.0461, 0.5194, 0.8105, 
0.0555, 0.9236, 0.0061, 1.0185], [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185], [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185], [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185], [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185], [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185], [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185], [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185], [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185]]]) p_attn: tensor([[[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833], [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833], [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833], [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833], [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833], [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833], [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833], [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833], [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833], [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833], [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833], [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833]]])