#!/usr/bin/env python
# tinygrab/examples/serious_mnist.py
# inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
import sys
import numpy as np
from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2d, optim
from tinygrad.helpers import getenv
from extra.datasets import fetch_mnist
from extra.augment import augment_img
from extra.training import train, evaluate
GPU = getenv("GPU")
QUICK = getenv("QUICK")
DEBUG = getenv("DEBUG")
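
# Environment flags (read above via getenv): GPU=1 moves parameters to the GPU,
# QUICK=1 shrinks the run to a smoke test, DEBUG=1 prints parameter shapes and counts.
# Assumed invocation, presumably from the repo root so the extra.* imports resolve:
#   GPU=1 python examples/serious_mnist.py                      # train, save checkpoints under examples/
#   python examples/serious_mnist.py examples/checkpointNNNNNN  # load saved weights and evaluate first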
class SqueezeExciteBlock2D:
    def __init__(self, filters):
        self.filters = filters
        self.weight1 = Tensor.scaled_uniform(self.filters, self.filters // 32)
        self.bias1 = Tensor.scaled_uniform(1, self.filters // 32)
        self.weight2 = Tensor.scaled_uniform(self.filters // 32, self.filters)
        self.bias2 = Tensor.scaled_uniform(1, self.filters)

    def __call__(self, input):
        se = input.avg_pool2d(
            kernel_size=(input.shape[2], input.shape[3])
        )  # GlobalAveragePool2D
        se = se.reshape(shape=(-1, self.filters))
        se = se.dot(self.weight1) + self.bias1
        se = se.relu()
        se = se.dot(self.weight2) + self.bias2
        se = se.sigmoid().reshape(shape=(-1, self.filters, 1, 1))  # for broadcasting
        se = input.mul(se)
        return se
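
# ConvBlock: three 3x3 convolutions with padding 1 (spatial size is preserved),
# each followed by ReLU, then BatchNorm and a squeeze-excite recalibration.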
class ConvBlock:
    def __init__(self, h, w, inp, filters=128, conv=3):
        self.h, self.w = h, w
        self.inp = inp
        # init weights
        self.cweights = [
            Tensor.scaled_uniform(filters, inp if i == 0 else filters, conv, conv)
            for i in range(3)
        ]
        self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)]
        # init layers
        self._bn = BatchNorm2d(filters)
        self._seb = SqueezeExciteBlock2D(filters)

    def __call__(self, input):
        x = input.reshape(shape=(-1, self.inp, self.w, self.h))
        for cweight, cbias in zip(self.cweights, self.cbiases):
            x = x.pad2d(padding=[1, 1, 1, 1]).conv2d(cweight).add(cbias).relu()
        x = self._bn(x)
        x = self._seb(x)
        return x
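
# BigConvNet: two ConvBlocks at 28x28, a 2x2 average pool down to 14x14, one more
# ConvBlock, then global average pooling and global max pooling feed two separate
# 128x10 linear heads whose logits are summed.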
class BigConvNet:
    def __init__(self):
        self.conv = [
            ConvBlock(28, 28, 1),
            ConvBlock(28, 28, 128),
            ConvBlock(14, 14, 128),
        ]
        self.weight1 = Tensor.scaled_uniform(128, 10)
        self.weight2 = Tensor.scaled_uniform(128, 10)

    def parameters(self):
        if DEBUG:  # keeping this for a moment
            pars = [par for par in get_parameters(self) if par.requires_grad]
            no_pars = 0
            for par in pars:
                print(par.shape)
                no_pars += np.prod(par.shape)
            print("no of parameters", no_pars)
            return pars
        else:
            return get_parameters(self)

    def save(self, filename):
        with open(filename + ".npy", "wb") as f:
            for par in get_parameters(self):
                # if par.requires_grad:
                np.save(f, par.numpy())

    def load(self, filename):
        with open(filename + ".npy", "rb") as f:
            for par in get_parameters(self):
                # if par.requires_grad:
                try:
                    par.numpy()[:] = np.load(f)
                    if GPU:
                        par.gpu()
                except Exception:
                    print("Could not load parameter")

    def forward(self, x):
        x = self.conv[0](x)
        x = self.conv[1](x)
        x = x.avg_pool2d(kernel_size=(2, 2))
        x = self.conv[2](x)
        x1 = x.avg_pool2d(kernel_size=(14, 14)).reshape(shape=(-1, 128))  # global
        x2 = x.max_pool2d(kernel_size=(14, 14)).reshape(shape=(-1, 128))  # global
        xo = x1.dot(self.weight1) + x2.dot(self.weight2)
        return xo
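
# Training schedule: Adam with a staged learning-rate decay, data augmentation from
# the second epoch onward, and an L1 penalty (lmbd) on the two output weight matrices.
# Checkpoints are written under examples/ and named by test accuracy.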
if __name__ == "__main__":
    lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5]
    epochss = [2, 1] if QUICK else [13, 3, 3, 1]
    BS = 32
    lmbd = 0.00025
    lossfn = (
        lambda out, y: out.sparse_categorical_crossentropy(y)
        + lmbd * (model.weight1.abs() + model.weight2.abs()).sum()
    )
    X_train, Y_train, X_test, Y_test = fetch_mnist()
    X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
    X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
    steps = len(X_train) // BS
    np.random.seed(1337)
    if QUICK:
        steps = 1
        X_test, Y_test = X_test[:BS], Y_test[:BS]

    model = BigConvNet()

    if len(sys.argv) > 1:
        try:
            model.load(sys.argv[1])
            print('Loaded weights "' + sys.argv[1] + '", evaluating...')
            evaluate(model, X_test, Y_test, BS=BS)
        except Exception:
            print('could not load weights "' + sys.argv[1] + '".')

    if GPU:
        params = get_parameters(model)
        for x in params:
            x.gpu_()

    for lr, epochs in zip(lrs, epochss):
        optimizer = optim.Adam(model.parameters(), lr=lr)
        for epoch in range(1, epochs + 1):
            # first epoch without augmentation
            X_aug = X_train if epoch == 1 else augment_img(X_train)
            train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
            accuracy = evaluate(model, X_test, Y_test, BS=BS)
            model.save(f"examples/checkpoint{accuracy * 1e6:.0f}")