# pytorch/test/functorch/attn_ft.py
#
# 141 lines
# 7.3 KiB
# Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from functorch.dim import dims, dimlists, softmax, cat
import math
class Linear(nn.Linear):
    """``nn.Linear`` whose forward pass is written with first-class dimensions.

    The contraction between the input and the weight is spelled out as an
    explicit elementwise product summed over the input-feature dimension.
    """

    def forward(self, input):
        """Apply the affine map ``input @ weight.T (+ bias)`` to the last dimension.

        Args:
            input: tensor whose last positional dimension holds the
                ``in_features`` values; all leading dimensions are captured by
                a dimension list and passed through unchanged.

        Returns:
            A tensor with the same leading dimensions as ``input`` followed by
            a dimension of size ``out_features``.
        """
        # ci: input-feature dimension, co: output-feature dimension; their sizes
        # are bound by indexing the input and the weight below.
        ci, co = dims()
        # b captures every leading dimension of ``input``.
        b = dimlists()
        result = (input[b, ci] * self.weight[co, ci]).sum(ci)
        # nn.Linear(..., bias=False) registers ``bias`` as None; only add the
        # bias when it exists instead of failing on ``None[co]``.
        if self.bias is not None:
            result = result + self.bias[co]
        return result.order(b, co)
class BertSelfAttention(nn.Module):
    """BERT-style multi-head self-attention written with first-class dimensions.

    Args:
        hidden_size: width of the input features; must be divisible by
            ``num_attention_heads``.
        num_attention_heads: number of attention heads.
        attention_probs_dropout_prob: dropout probability applied to the
            attention probabilities (only while the module is in training mode).
        position_embedding_type: ``None`` for no relative position scores, or
            ``"relative_key"`` / ``"relative_key_query"`` to add learned
            relative-distance scores.
        max_position_embeddings: required when ``position_embedding_type`` is
            not ``None``; the largest key-sequence length that is supported.
        linear: factory ``linear(in_features, out_features)`` used to build the
            query/key/value projections.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, hidden_size, num_attention_heads,
                 attention_probs_dropout_prob, position_embedding_type=None,
                 max_position_embeddings=None, linear=Linear):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )
        self.num_attention_heads = num_attention_heads
        # Exact integer division; the divisibility check above guarantees there
        # is no remainder, so there is no need to go through float division.
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = linear(hidden_size, self.all_head_size)
        self.key = linear(hidden_size, self.all_head_size)
        self.value = linear(hidden_size, self.all_head_size)
        self.dropout_prob = attention_probs_dropout_prob
        self.position_embedding_type = position_embedding_type
        if self.position_embedding_type is not None:
            assert max_position_embeddings is not None
            self.max_position_embeddings = max_position_embeddings
            # One embedding row per signed distance in
            # [-(max_position_embeddings - 1), max_position_embeddings - 1].
            self.distance_embedding = nn.Embedding(2 * max_position_embeddings - 1, self.attention_head_size)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
        """Compute self-attention over ``hidden_states``.

        Args:
            hidden_states: 3-D tensor laid out as (batch, sequence, hidden_size).
            past_key_value: optional pair ``(past_keys, past_values)``, each laid
                out as (batch, heads, sequence, head_size). The past entries are
                bound to the same ``key_sequence`` dimension as the current keys,
                so their sequence length must equal the current one.

        Returns:
            Tensor laid out as (batch, query_sequence, all_head_size).
        """
        # First run the q, k, v projections normally: the meaning of a linear
        # layer is well understood, so there is no need for explicit dimensions.
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)

        # Introduce values that represent each dimension. Dimensions are
        # "first class" because they are actual Python values created here.
        batch, query_sequence, key_sequence, heads, features = dims()
        heads.size = self.num_attention_heads

        # Bind the positional dimensions of q, k and v to these values. The size
        # of each dimension is determined by this binding, and when a dimension
        # is used more than once (e.g. batch) its sizes are checked for
        # consistency. The group [heads, features] splits a single positional
        # dimension into two; since heads.size * features.size == q.size(2) and
        # heads.size was set above, features.size is inferred here.
        q = q[batch, query_sequence, [heads, features]]
        k = k[batch, key_sequence, [heads, features]]
        v = v[batch, key_sequence, [heads, features]]

        # past_key_value lets the model attend not only to the elements of the
        # current sequence but also to previously computed keys and values.
        if past_key_value is not None:
            extended_key_sequence = dims()
            key_past = past_key_value[0][batch, heads, key_sequence, features]
            value_past = past_key_value[1][batch, heads, key_sequence, features]
            # cat introduces the new dimension extended_key_sequence. Both inputs
            # are bound to key_sequence, so the result is twice as long as it.
            k = cat([key_past, k], key_sequence, extended_key_sequence)
            v = cat([value_past, v], key_sequence, extended_key_sequence)
            # From here on, extended_key_sequence is used in place of key_sequence.
            key_sequence = extended_key_sequence

        # The dot product between query and key gives the raw attention scores.
        # The elementwise product and the summation are written out explicitly
        # and, like einsum, are pattern-matched to an efficient matrix multiply.
        attention_scores = (q * k).sum(features) / math.sqrt(features.size)

        # Relative positional embeddings assign an embedding to each signed
        # distance (query position - key position), e.g. for 4 tokens
        # (rows: query position, columns: key position):
        #    0 -1 -2 -3
        #    1  0 -1 -2
        #    2  1  0 -1
        #    3  2  1  0
        if self.position_embedding_type is not None:
            # Used as a tensor, a dimension object evaluates to the indices along
            # it, so subtracting the two dimensions directly yields a 2-D
            # (query_sequence x key_sequence) tensor of distances.
            distance = query_sequence - key_sequence
            assert key_sequence.size <= self.max_position_embeddings
            # Shift the distances to be non-negative and use them as an indirect
            # index into the embedding table. This is a plain gather; the result
            # has the dimensions of the index (query_sequence x key_sequence)
            # plus ``features``. This form of indirect indexing is more
            # straightforward than advanced indexing or torch.gather, both of
            # which depend heavily on the positions of the indexing tensors.
            positional_embedding = self.distance_embedding.weight[self.max_position_embeddings - 1 + distance, features]
            if self.position_embedding_type == "relative_key":
                # In a conventional implementation these are einsum calls, since
                # they do not fit existing matmul operators even though they are
                # degenerate matmuls.
                relative_position_scores = (q * positional_embedding).sum(features)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = (q * positional_embedding).sum(features)
                relative_position_scores_key = (k * positional_embedding).sum(features)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities over the keys.
        attention_probs = softmax(attention_scores, dim=key_sequence)
        # This drops out entire tokens to attend to, which might seem a bit
        # unusual, but is taken from the original Transformer paper.
        # F.dropout defaults to training=True, so pass the module's flag to get
        # nn.Dropout semantics: no dropout after model.eval().
        attention_probs = torch.nn.functional.dropout(
            attention_probs, p=self.dropout_prob, training=self.training
        )
        # The second matmul is likewise written as an explicit product + sum,
        # making it clear that the values v are weighted across all keys by the
        # attention probabilities.
        context_layer = (attention_probs * v).sum(key_sequence)
        # Finally, convert back to a standard tensor by describing the layout of
        # dimensions. Mirroring the binding at the top of this method, the
        # [heads, features] group flattens the two dimensions into one.
        return context_layer.order(batch, query_sequence, [heads, features])