1
0
Fork 0
tinygrab/test/graph_batchnorm.py

62 lines
1.8 KiB
Python
Raw Normal View History

2022-07-02 18:11:12 -06:00
from tinygrad.tensor import Tensor
2022-08-20 08:40:56 -06:00
from tinygrad.nn import Conv2d, BatchNorm2D, optim
2022-07-02 18:11:12 -06:00
from extra.utils import get_parameters # TODO: move to optim
import unittest
2022-07-02 18:54:04 -06:00
def model_step(lm):
  """Run one SGD training step on `lm` using a constant all-ones input batch.

  Builds an SGD optimizer (lr=0.001) over ``get_parameters(lm)``. The output of
  ``lm.forward`` is summed into a scalar loss, which is backpropagated before one
  optimizer update is applied. Nothing is returned or asserted, so callers use this
  as a smoke test that the forward/backward/step graph executes.

  ``Tensor.training`` is switched on for the duration of the step. Afterwards it is
  restored to the value it had on entry, even if the step raises, so a failing test
  cannot leak training mode into later tests.

  Args:
    lm: any object exposing ``forward(x)`` whose parameters ``get_parameters`` can find.
  """
  prev_training = Tensor.training
  Tensor.training = True
  try:
    # NOTE(review): shape presumably (batch, channels, H, W); the test models
    # in this file all take 12 input channels — confirm before changing.
    x = Tensor.ones(8, 12, 128, 256, requires_grad=False)
    optimizer = optim.SGD(get_parameters(lm), lr=0.001)
    loss = lm.forward(x).sum()

    optimizer.zero_grad()
    loss.backward()
    # Drop references to the input and loss before the update, presumably so
    # their buffers can be released before optimizer.step() runs.
    del x, loss
    optimizer.step()
  finally:
    # Always undo the global mode switch, even when the step fails.
    Tensor.training = prev_training
2022-07-02 18:11:12 -06:00
class TestBatchnorm(unittest.TestCase):
  """Smoke tests: one training step through small conv (+ batchnorm) stacks must run.

  Each test defines a throwaway model that exposes ``forward`` and passes it to
  ``model_step``. Nothing is asserted beyond the step completing without error.
  Every model consumes 12 input channels, matching the input built in ``model_step``.
  """

  def test_conv(self):
    # A single bias-free 3x3 conv (12 -> 32 channels), then ReLU.
    class LilModel:
      def __init__(self):
        self.c = Conv2d(12, 32, 3, padding=1, bias=False)
      def forward(self, x):
        out = self.c(x)
        return out.relu()
    model_step(LilModel())

  def test_two_conv(self):
    # Two stacked bias-free 3x3 convs (12 -> 32 -> 32), with ReLU only at the end.
    class LilModel:
      def __init__(self):
        self.c = Conv2d(12, 32, 3, padding=1, bias=False)
        self.c2 = Conv2d(32, 32, 3, padding=1, bias=False)
      def forward(self, x):
        h = self.c(x)
        return self.c2(h).relu()
    model_step(LilModel())

  def test_two_conv_bn(self):
    # Two conv -> BatchNorm2D(track_running_stats=False) -> ReLU stages (12 -> 24 -> 32).
    class LilModel:
      def __init__(self):
        self.c = Conv2d(12, 24, 3, padding=1, bias=False)
        self.bn = BatchNorm2D(24, track_running_stats=False)
        self.c2 = Conv2d(24, 32, 3, padding=1, bias=False)
        self.bn2 = BatchNorm2D(32, track_running_stats=False)
      def forward(self, x):
        h = self.bn(self.c(x)).relu()
        h = self.c2(h)
        return self.bn2(h).relu()
    model_step(LilModel())

  def test_conv_bn(self):
    # A single conv -> BatchNorm2D(track_running_stats=False) -> ReLU stage (12 -> 32).
    class LilModel:
      def __init__(self):
        self.c = Conv2d(12, 32, 3, padding=1, bias=False)
        self.bn = BatchNorm2D(32, track_running_stats=False)
      def forward(self, x):
        h = self.c(x)
        return self.bn(h).relu()
    model_step(LilModel())
2022-07-02 18:11:12 -06:00
if __name__ == "__main__":
  # Allow running this test module directly as a script.
  unittest.main()