#!/usr/bin/env python
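"""Numerical tests for tinygrad.nn layers, checked against torch reference implementations."""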
import unittest
import numpy as np
from tinygrad.helpers import CI
from tinygrad.jit import TinyJit
from tinygrad.tensor import Tensor, Device
from tinygrad.nn import (
    BatchNorm2d,
    Conv1d,
    ConvTranspose1d,
    Conv2d,
    ConvTranspose2d,
    Linear,
    GroupNorm,
    LayerNorm,
    LayerNorm2d,
    Embedding,
    InstanceNorm,
)
import torch
import pytest
pytestmark = [pytest.mark.exclude_cuda]
class TestNN(unittest.TestCase):
    def test_sparse_cat_cross_entropy(self):
        input = torch.randn(3, 5)
        target = torch.empty(3, dtype=torch.long).random_(5)
        loss_fun = torch.nn.CrossEntropyLoss(reduction="mean")
        loss = loss_fun(input, target)
        input_tiny = Tensor(input.detach().numpy())
        target_tiny = Tensor(target.detach().numpy())
        loss_tiny = input_tiny.sparse_categorical_crossentropy(target_tiny)
        np.testing.assert_allclose(
            loss_tiny.numpy(), loss.detach().numpy(), atol=1e-5, rtol=1e-6
        )
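    # BatchNorm2d: copies weights and running stats into a torch BatchNorm2d, then checks
    # the running statistics both before and after a forward pass (eval vs. training mode).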
    def test_batchnorm2d(self, training=False):
        szs = [4, 8, 16, 32]
        for sz in szs:
            # create in tinygrad
            Tensor.training = training
            bn = BatchNorm2d(sz, eps=1e-5, track_running_stats=training)
            bn.weight = Tensor.randn(sz)
            bn.bias = Tensor.randn(sz)
            bn.running_mean = Tensor.randn(sz)
            bn.running_var = Tensor.randn(sz)
            bn.running_var.numpy()[bn.running_var.numpy() < 0] = 0
            # create in torch
            with torch.no_grad():
                tbn = torch.nn.BatchNorm2d(sz).eval()
                tbn.training = training
                tbn.weight[:] = torch.tensor(bn.weight.numpy())
                tbn.bias[:] = torch.tensor(bn.bias.numpy())
                tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy())
                tbn.running_var[:] = torch.tensor(bn.running_var.numpy())
            np.testing.assert_allclose(
                bn.running_mean.numpy(),
                tbn.running_mean.detach().numpy(),
                rtol=1e-5,
                atol=1e-6,
            )
            np.testing.assert_allclose(
                bn.running_var.numpy(),
                tbn.running_var.detach().numpy(),
                rtol=1e-5,
                atol=1e-6,
            )
            # trial
            inn = Tensor.randn(2, sz, 3, 3)
            # in tinygrad
            outt = bn(inn)
            # in torch
            toutt = tbn(torch.tensor(inn.numpy()))
            # close
            np.testing.assert_allclose(
                outt.numpy(), toutt.detach().numpy(), rtol=5e-4, atol=1e-6
            )
            np.testing.assert_allclose(
                bn.running_mean.numpy(),
                tbn.running_mean.detach().numpy(),
                rtol=1e-5,
                atol=1e-6,
            )
            np.testing.assert_allclose(
                bn.running_var.numpy(),
                tbn.running_var.detach().numpy(),
                rtol=1e-5,
                atol=1e-6,
            )
    def test_batchnorm2d_training(self):
        self.test_batchnorm2d(True)
    def test_linear(self):
        def _test_linear(x):
            # create in tinygrad
            model = Linear(in_dim, out_dim)
            z = model(x)
            # create in torch
            with torch.no_grad():
                torch_layer = torch.nn.Linear(in_dim, out_dim).eval()
                torch_layer.weight[:] = torch.tensor(
                    model.weight.numpy(), dtype=torch.float32
                )
                torch_layer.bias[:] = torch.tensor(
                    model.bias.numpy(), dtype=torch.float32
                )
                torch_x = torch.tensor(x.numpy(), dtype=torch.float32)
                torch_z = torch_layer(torch_x)
            # test
            np.testing.assert_allclose(
                z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5
            )
        BS, T, in_dim, out_dim = 4, 2, 8, 16
        _test_linear(Tensor.randn(BS, in_dim))
        _test_linear(Tensor.randn(BS, T, in_dim))  # test with more dims
    def test_conv1d(self):
        BS, C1, W = 4, 16, 224 // 4
        C2, K, S, P = 64, 7, 2, 1
        # create in tinygrad
        layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P)
        # create in torch
        with torch.no_grad():
            torch_layer = torch.nn.Conv1d(
                C1, C2, kernel_size=K, stride=S, padding=P
            ).eval()
            torch_layer.weight[:] = torch.tensor(
                layer.weight.numpy(), dtype=torch.float32
            )
            torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
        # test
        x = Tensor.uniform(BS, C1, W)
        z = layer(x)
        torch_x = torch.tensor(x.numpy())
        torch_z = torch_layer(torch_x)
        np.testing.assert_allclose(
            z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5
        )
    def test_conv2d(self):
        BS, C1, H, W = 4, 16, 224 // 4, 224 // 4
        C2, K, S, P = 64, 7, 2, 1
        # create in tinygrad
        layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
        # create in torch
        with torch.no_grad():
            torch_layer = torch.nn.Conv2d(
                C1, C2, kernel_size=K, stride=S, padding=P
            ).eval()
            torch_layer.weight[:] = torch.tensor(
                layer.weight.numpy(), dtype=torch.float32
            )
            torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
        # test
        x = Tensor.uniform(BS, C1, H, W)
        z = layer(x)
        torch_x = torch.tensor(x.numpy())
        torch_z = torch_layer(torch_x)
        np.testing.assert_allclose(
            z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5
        )
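    # Temporarily enables Tensor.wino to exercise the Winograd convolution path, checking
    # both the forward output and the weight/bias/input gradients against torch.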
    @unittest.skipIf(
        Device.DEFAULT != "TORCH", "Takes too long to compile for Compiled backends"
    )
    def test_conv2d_winograd(self):
        BS, C1, H, W = 2, 8, 16, 16
        C2, K, S, P = 8, 3, 1, 1
        old_wino = Tensor.wino
        Tensor.wino = True
        # create in tinygrad
        layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
        layer.weight.requires_grad = True
        layer.bias.requires_grad = True
        # create in torch
        torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
        torch_layer.weight = torch.nn.Parameter(
            torch.tensor(layer.weight.numpy(), dtype=torch.float32)
        )
        torch_layer.bias = torch.nn.Parameter(
            torch.tensor(layer.bias.numpy(), dtype=torch.float32)
        )
        # test
        x = Tensor.uniform(BS, C1, H, W, requires_grad=True)
        z = layer(x)
        torch_x = torch.tensor(x.numpy(), requires_grad=True)
        torch_z = torch_layer(torch_x)
        np.testing.assert_allclose(
            z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5
        )
        m = z.mean()
        m.backward()
        gw = layer.weight.grad.realize()
        gb = layer.bias.grad.realize()
        gx = x.grad.realize()
        torch_z.mean().backward()
        np.testing.assert_allclose(
            gw.numpy(), torch_layer.weight.grad.numpy(), atol=5e-4, rtol=1e-5
        )
        np.testing.assert_allclose(
            gb.numpy(), torch_layer.bias.grad.numpy(), atol=5e-4, rtol=1e-5
        )
        np.testing.assert_allclose(
            gx.numpy(), torch_x.grad.numpy(), atol=5e-4, rtol=1e-5
        )
        Tensor.wino = old_wino
    @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "runs out of memory in CI")
    def test_conv_transpose1d(self):
        BS, C1, W = 4, 16, 224 // 4
        C2, K, S, P = 64, 7, 2, 1
        # create in tinygrad
        layer = ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P)
        # create in torch
        with torch.no_grad():
            torch_layer = torch.nn.ConvTranspose1d(
                C1, C2, kernel_size=K, stride=S, padding=P
            ).eval()
            torch_layer.weight[:] = torch.tensor(
                layer.weight.numpy(), dtype=torch.float32
            )
            torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
        # test
        x = Tensor.uniform(BS, C1, W)
        z = layer(x)
        torch_x = torch.tensor(x.numpy())
        torch_z = torch_layer(torch_x)
        np.testing.assert_allclose(
            z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5
        )
    @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "runs out of memory in CI")
    def test_conv_transpose2d(self):
        BS, C1, H, W = 4, 16, 224 // 4, 224 // 4
        C2, K, S, P = 64, 7, 2, 1
        # create in tinygrad
        layer = ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P)
        # create in torch
        with torch.no_grad():
            torch_layer = torch.nn.ConvTranspose2d(
                C1, C2, kernel_size=K, stride=S, padding=P
            ).eval()
            torch_layer.weight[:] = torch.tensor(
                layer.weight.numpy(), dtype=torch.float32
            )
            torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
        # test
        x = Tensor.uniform(BS, C1, H, W)
        z = layer(x)
        torch_x = torch.tensor(x.numpy())
        torch_z = torch_layer(torch_x)
        np.testing.assert_allclose(
            z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5
        )
    def test_groupnorm(self):
        BS, H, W, C, G = 20, 10, 10, 6, 3
        # create in tinygrad
        layer = GroupNorm(G, C)
        # create in torch
        with torch.no_grad():
            torch_layer = torch.nn.GroupNorm(G, C).eval()
            torch_layer.weight[:] = torch.tensor(
                layer.weight.numpy(), dtype=torch.float32
            )
            torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
        # test
        x = Tensor.randn(BS, C, H, W)
        z = layer(x)
        torch_x = torch.tensor(x.numpy())
        torch_z = torch_layer(torch_x)
        np.testing.assert_allclose(
            z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3
        )
    def test_layernorm(self):
        N, C, H, W = 20, 5, 10, 10
        # create in tinygrad
        layer = LayerNorm([H, W])
        # create in torch
        with torch.no_grad():
            torch_layer = torch.nn.LayerNorm([H, W]).eval()
            torch_layer.weight[:] = torch.tensor(
                layer.weight.numpy(), dtype=torch.float32
            )
            torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
        # test
        x = Tensor.randn(N, C, H, W)
        z = layer(x)
        torch_x = torch.tensor(x.numpy())
        torch_z = torch_layer(torch_x)
        np.testing.assert_allclose(
            z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3
        )
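    # LayerNorm2d normalizes over the channel dimension of NCHW input, so the torch
    # reference permutes to NHWC, applies LayerNorm([C]), and permutes back.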
    def test_layernorm_2d(self):
        N, C, H, W = 20, 5, 10, 10
        # create in tinygrad
        layer = LayerNorm2d(C)
        # create in torch
        with torch.no_grad():
            torch_layer = torch.nn.LayerNorm([C]).eval()
            torch_layer.weight[:] = torch.tensor(
                layer.weight.numpy(), dtype=torch.float32
            )
            torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
        # test
        x = Tensor.randn(N, C, H, W)
        z = layer(x)
        torch_x = torch.tensor(x.numpy())
        torch_z = torch_layer(torch_x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        np.testing.assert_allclose(
            z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3
        )
    def test_instancenorm_2d(self):
        N, C, H, W = 20, 5, 10, 10
        # create in tinygrad
        layer = InstanceNorm(C)
        # create in torch
        with torch.no_grad():
            torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval()
            torch_layer.weight[:] = torch.tensor(
                layer.weight.numpy(), dtype=torch.float32
            )
            torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
        # test
        x = Tensor.randn(N, C, H, W)
        z = layer(x)
        torch_x = torch.tensor(x.numpy())
        torch_z = torch_layer(torch_x)
        np.testing.assert_allclose(
            z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3
        )
    def test_instancenorm_3d(self):
        N, C, D, H, W = 20, 5, 3, 10, 10
        # create in tinygrad
        layer = InstanceNorm(C)
        # create in torch
        with torch.no_grad():
            torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval()
            torch_layer.weight[:] = torch.tensor(
                layer.weight.numpy(), dtype=torch.float32
            )
            torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
        # test
        x = Tensor.randn(N, C, D, H, W)
        z = layer(x)
        torch_x = torch.tensor(x.numpy())
        torch_z = torch_layer(torch_x)
        np.testing.assert_allclose(
            z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3
        )
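    # Embedding indices are passed to tinygrad as float tensors and cast to int32 for
    # torch; the TinyJit variant is called three times to exercise the cached kernel path.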
    def test_embedding(self):
        B, T, C, VS = 4, 10, 20, 28
        # create in tinygrad
        layer = Embedding(VS, C)
        with torch.no_grad():
            torch_layer = torch.nn.Embedding(VS, C).eval()
            torch_layer.weight[:] = torch.tensor(
                layer.weight.numpy(), dtype=torch.float32
            )
        # test
        x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32))
        z = layer(x)
        torch_x = torch.tensor(x.numpy().astype(np.int32))
        torch_z = torch_layer(torch_x)
        np.testing.assert_allclose(
            z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8
        )
        # test with jit enabled
        @TinyJit
        def layer_jit(x):
            return layer(x).realize()
        for _ in range(3):
            x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32))
            z = layer_jit(x)
            torch_x = torch.tensor(x.numpy().astype(np.int32))
            torch_z = torch_layer(torch_x)
            np.testing.assert_allclose(
                z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8
            )
if __name__ == "__main__":
    unittest.main()