# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
from torch import nn


class BertSelfAttention(nn.Module):
    """Multi-head self-attention as used in BERT.

    Projects the input into query/key/value spaces, computes scaled
    dot-product attention (optionally augmented with learned relative
    position scores), and returns the re-assembled context vectors.

    Args:
        hidden_size: Dimensionality of the input/output features. Must be
            divisible by ``num_attention_heads``.
        num_attention_heads: Number of parallel attention heads.
        attention_probs_dropout_prob: Dropout probability applied to the
            attention weights.
        position_embedding_type: ``None`` to skip relative positions, or
            ``"relative_key"`` / ``"relative_key_query"`` to add learned
            relative-position scores to the attention logits.
        max_position_embeddings: Required when ``position_embedding_type``
            is set; bounds the relative distances that can be embedded.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of
            ``num_attention_heads``, or if ``position_embedding_type`` is set
            without ``max_position_embeddings``.
    """

    def __init__(self, hidden_size, num_attention_heads,
                 attention_probs_dropout_prob,
                 position_embedding_type=None, max_position_embeddings=None):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type

        if self.position_embedding_type is not None:
            # Validate eagerly with a real exception (an `assert` here would be
            # stripped under `python -O` and fail later with an opaque error).
            if max_position_embeddings is None:
                raise ValueError(
                    "max_position_embeddings must be provided when "
                    "position_embedding_type is set"
                )
            self.max_position_embeddings = max_position_embeddings
            # Relative distances range over [-(max-1), max-1]: 2*max - 1 entries.
            self.distance_embedding = nn.Embedding(2 * max_position_embeddings - 1, self.attention_head_size)

    def transpose_for_scores(self, x):
        """Reshape (batch, seq, all_head_size) -> (batch, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Float tensor of shape (batch, seq_len, hidden_size).
            past_key_value: Optional pair of cached (key, value) tensors of
                shape (batch, heads, past_len, head_size); when given, the
                current keys/values are appended to them along the sequence
                dimension (incremental decoding).

        Returns:
            Context tensor of shape (batch, seq_len, hidden_size).
        """
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        # Take the dot product between "query" and "key" to get the raw
        # attention scores, scaled by sqrt(head_size) to keep the softmax
        # in a well-conditioned range.
        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if self.position_embedding_type is not None:
            # NOTE(review): this path builds distances from the *current*
            # seq_length only; with a non-None past_key_value the key length
            # exceeds seq_length and the einsum over k below would fail —
            # confirm callers never combine relative positions with a cache.
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            # Shift distances from [-(max-1), max-1] into valid indices [0, 2*max-2].
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=q.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", q, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", q, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", k, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, v)

        # (batch, heads, seq, head_size) -> (batch, seq, hidden_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer