# tinygrab/examples/beautiful_mnist.py (65 lines, 2.2 KiB, Python)

# Model based on https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
from extra.datasets import fetch_mnist
from tqdm import trange
class Model:
    """Convolutional digit classifier expressed as a flat list of layer callables.

    The network is two feature-extraction stages, each
    conv -> relu -> conv -> relu -> batchnorm -> max-pool, followed by a
    flatten and one linear layer that produces 10 outputs per sample.
    """

    def __init__(self):
        # Keep construction order identical to execution order: layer
        # constructors may consume RNG state when initialising weights
        # (TODO confirm), so reordering them could change initial weights.
        first_stage: List[Callable[[Tensor], Tensor]] = [
            nn.Conv2d(1, 32, 5),
            Tensor.relu,
            nn.Conv2d(32, 32, 5),
            Tensor.relu,
            nn.BatchNorm2d(32),
            Tensor.max_pool2d,
        ]
        second_stage: List[Callable[[Tensor], Tensor]] = [
            nn.Conv2d(32, 64, 3),
            Tensor.relu,
            nn.Conv2d(64, 64, 3),
            Tensor.relu,
            nn.BatchNorm2d(64),
            Tensor.max_pool2d,
        ]
        # 576 == 64 channels * 3 * 3. This presumably assumes 1x28x28 inputs
        # and 2x2 pooling -- TODO confirm against the dataset loader.
        classifier: List[Callable[[Tensor], Tensor]] = [
            lambda x: x.flatten(1),
            nn.Linear(576, 10),
        ]
        self.layers: List[Callable[[Tensor], Tensor]] = (
            first_stage + second_stage + classifier
        )

    def __call__(self, x: Tensor) -> Tensor:
        """Feed ``x`` through ``self.layers`` in list order via ``Tensor.sequential``."""
        return x.sequential(self.layers)
if __name__ == "__main__":
    X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)

    model = Model()
    opt = nn.optim.Adam(nn.state.get_parameters(model))

    # TODO: removing TinyJit here triggers a compiler error, because randint
    # isn't being realized and something odd happens with int.
    @TinyJit
    def train_step(samples: Tensor) -> Tensor:
        """Take one optimizer step on the training rows selected by ``samples``.

        Returns the realized loss tensor for that batch.
        """
        with Tensor.train():
            opt.zero_grad()
            # TODO: indexing X_train/Y_train by `samples` (a gather) is very
            # slow; the whole run should be under 5s once that is fixed.
            batch_images = X_train[samples]
            logits = model(batch_images)
            batch_labels = Y_train[samples]
            loss = logits.sparse_categorical_crossentropy(batch_labels).backward()
            opt.step()
            return loss.realize()

    @TinyJit
    def get_test_acc() -> Tensor:
        """Return the realized test-set accuracy, as a percentage, in a Tensor."""
        predicted = model(X_test).argmax(axis=1)
        hit_rate = (predicted == Y_test).mean()
        return (hit_rate * 100).realize()

    # Displayed as nan until the first evaluation runs.
    test_acc = float("nan")
    progress = trange(70)
    for step in progress:
        # NOTE: resetting counters every step keeps DEBUG=2 timing output readable.
        GlobalCounters.reset()
        # TODO: move sampling into the JIT once rand is fixed.
        samples = Tensor.randint(512, high=X_train.shape[0])
        loss = train_step(samples)
        # Evaluate on the test set every 10th step (steps 9, 19, ...).
        if step % 10 == 9:
            test_acc = get_test_acc().item()
        progress.set_description(
            f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%"
        )