1
0
Fork 0
tinygrab/test/models/test_bert.py

78 lines
2.4 KiB
Python

#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import Tensor
import torch
def get_question_samp(bsz, seq_len, vocab_size, seed, num_segment_types=1):
    """Build a deterministic random question-answering input batch.

    Args:
        bsz: batch size.
        seq_len: number of tokens per sample.
        vocab_size: token ids are drawn uniformly from ``[0, vocab_size)``.
        seed: seed for a private NumPy ``RandomState``. The global NumPy RNG
            is left untouched, yet the values drawn are identical to those
            produced after ``np.random.seed(seed)``.
        num_segment_types: segment (token-type) ids are drawn uniformly from
            ``[0, num_segment_types)``. The default of 1 makes every segment
            id 0, matching the historical behaviour of this helper.

    Returns:
        ``(in_ids, mask, seg_ids)``: three ``(bsz, seq_len)`` arrays holding
        integer token ids, a boolean attention mask, and integer segment ids.
    """
    # A local legacy RandomState replays exactly the stream np.random.seed()
    # would produce, without mutating global state other code may rely on.
    rng = np.random.RandomState(seed)
    shape = (bsz, seq_len)
    in_ids = rng.randint(vocab_size, size=shape)
    mask = rng.choice([True, False], size=shape)
    seg_ids = rng.randint(num_segment_types, size=shape)
    return in_ids, mask, seg_ids
def set_equal_weights(mdl, torch_mdl):
    """Copy every parameter of tinygrad ``mdl`` into ``torch_mdl``, then eval it.

    The two state dicts must have exactly the same keys. This is verified
    before anything is copied. Each tensor's shape is also verified before it
    is copied, because ``torch.Tensor.copy_`` broadcasts its source and would
    otherwise accept a mismatched-but-broadcastable tensor silently.

    Args:
        mdl: tinygrad model whose weights are the source of truth.
        torch_mdl: PyTorch module whose ``state_dict()`` tensors are
            overwritten in place. It is switched to ``eval()`` mode afterwards.

    Raises:
        AssertionError: if the key sets differ or any tensor shape differs.
    """
    from tinygrad.nn.state import get_state_dict

    state, torch_state = get_state_dict(mdl), torch_mdl.state_dict()
    only_tiny = sorted(state.keys() - torch_state.keys())
    only_torch = sorted(torch_state.keys() - state.keys())
    assert not only_tiny and not only_torch, (
        f"state dict keys differ: only in tinygrad={only_tiny}, "
        f"only in torch={only_torch}"
    )
    for name, tensor in state.items():
        arr = tensor.numpy()
        target = torch_state[name]
        assert tuple(target.shape) == arr.shape, (
            f"shape mismatch for {name}: tinygrad {arr.shape} "
            f"vs torch {tuple(target.shape)}"
        )
        target.copy_(torch.from_numpy(arr))
    torch_mdl.eval()
class TestBert(unittest.TestCase):
    """Numerical parity tests between tinygrad's BERT and Hugging Face's."""

    def test_questions(self):
        """Check tinygrad and PyTorch BertForQuestionAnswering agree.

        A tiny BERT is built in both frameworks and the tinygrad weights are
        copied into the PyTorch model. Both models then run on two seeded
        random batches, and the concatenated start/end logits must match
        within ``atol=rtol=5e-4``.
        """
        from extra.models.bert import BertForQuestionAnswering
        from transformers import (
            BertForQuestionAnswering as TorchBertForQuestionAnswering,
        )
        from transformers import BertConfig

        # Deliberately tiny configuration so the comparison runs quickly.
        config = {
            "vocab_size": 24,
            "hidden_size": 2,
            "num_hidden_layers": 2,
            "num_attention_heads": 2,
            "intermediate_size": 32,
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 512,
            "type_vocab_size": 2,
        }

        # Create in tinygrad.
        Tensor.manual_seed(1337)
        mdl = BertForQuestionAnswering(**config)
        # Create in torch, then overwrite its weights with tinygrad's.
        with torch.no_grad():
            torch_mdl = TorchBertForQuestionAnswering(BertConfig(**config))
            set_equal_weights(mdl, torch_mdl)

        seeds = (1337, 3141)
        # NOTE(review): torch.cat along dim 0 followed by unsqueeze(2) is
        # presumably laid out to match the tinygrad output only for bsz == 1.
        # Confirm against the tinygrad model before raising bsz.
        bsz, seq_len = 1, 16
        for seed in seeds:
            in_ids, mask, seg_ids = get_question_samp(
                bsz, seq_len, config["vocab_size"], seed
            )
            out = mdl(Tensor(in_ids), Tensor(mask), Tensor(seg_ids))
            # Inference only: no autograd graph is needed on the torch side.
            # Calling the module (not .forward) keeps any registered hooks.
            with torch.no_grad():
                start_logits, end_logits = torch_mdl(
                    input_ids=torch.from_numpy(in_ids).long(),
                    attention_mask=torch.from_numpy(mask),
                    token_type_ids=torch.from_numpy(seg_ids).long(),
                )[:2]
            torch_out = torch.cat((start_logits, end_logits)).unsqueeze(2)
            np.testing.assert_allclose(
                out.numpy(), torch_out.numpy(), atol=5e-4, rtol=5e-4
            )
if __name__ == "__main__":
    # Executed directly as a script (not imported or collected by a test
    # runner): hand control to unittest's command-line entry point.
    unittest.main()