# pytorch/test/functorch/attn_positional.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
from torch import nn

class BertSelfAttention(nn.Module):
    """BERT-style multi-head self-attention with optional relative position embeddings."""

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        attention_probs_dropout_prob,
        position_embedding_type=None,
        max_position_embeddings=None,
    ):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(attention_probs_dropout_prob)

        self.position_embedding_type = position_embedding_type
        if self.position_embedding_type is not None:
            assert max_position_embeddings is not None
            self.max_position_embeddings = max_position_embeddings
            # One embedding vector per relative distance in [-(max-1), max-1].
            self.distance_embedding = nn.Embedding(2 * max_position_embeddings - 1, self.attention_head_size)

    def transpose_for_scores(self, x):
        # (batch, seq, all_head_size) -> (batch, num_heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)

        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)

        # Prepend cached keys/values (batch, num_heads, past_seq, head_size) along the sequence dim.
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if self.position_embedding_type is not None:
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            # Shift distances from [-(L-1), L-1] to [0, 2L-2] to index the embedding table.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=q.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", q, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", q, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", k, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, v)

        # (batch, num_heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
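

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original test file). The model dimensions,
# the "relative_key" setting, and the cached key/value shapes below are
# illustrative assumptions chosen to exercise the module, not values taken
# from the original source.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # Plain scaled dot-product attention: 64-dim hidden state split over 8 heads.
    attn = BertSelfAttention(
        hidden_size=64, num_attention_heads=8, attention_probs_dropout_prob=0.1
    )
    hidden_states = torch.randn(2, 5, 64)  # (batch, seq, hidden)
    out = attn(hidden_states)
    assert out.shape == (2, 5, 64)

    # Relative position embeddings require max_position_embeddings to be set.
    rel_attn = BertSelfAttention(
        hidden_size=64,
        num_attention_heads=8,
        attention_probs_dropout_prob=0.0,
        position_embedding_type="relative_key",
        max_position_embeddings=16,
    )
    assert rel_attn(hidden_states).shape == (2, 5, 64)

    # A cached (key, value) pair of shape (batch, num_heads, past_seq, head_size)
    # is concatenated in front of the current keys/values, so the queries attend
    # over past_seq + seq positions.
    past_k = torch.randn(2, 8, 3, 8)
    past_v = torch.randn(2, 8, 3, 8)
    assert attn(hidden_states, past_key_value=(past_k, past_v)).shape == (2, 5, 64)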