1
0
Fork 0
tinygrab/extra/training.py

79 lines
2.4 KiB
Python

import numpy as np
from tqdm import trange
from tinygrad.tensor import Tensor
from tinygrad.helpers import CI
from tinygrad.jit import TinyJit
def train(
model,
X_train,
Y_train,
optim,
steps,
BS=128,
lossfn=lambda out, y: out.sparse_categorical_crossentropy(y),
transform=lambda x: x,
target_transform=lambda x: x,
noloss=False,
):
@TinyJit
def train_step(x, y):
# network
out = model.forward(x) if hasattr(model, "forward") else model(x)
loss = lossfn(out, y)
optim.zero_grad()
loss.backward()
if noloss:
del loss
optim.step()
if noloss:
return (None, None)
cat = out.argmax(axis=-1)
accuracy = (cat == y).mean()
return loss.realize(), accuracy.realize()
with Tensor.train():
losses, accuracies = [], []
for i in (t := trange(steps, disable=CI)):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
x = Tensor(transform(X_train[samp]), requires_grad=False)
y = Tensor(target_transform(Y_train[samp]))
loss, accuracy = train_step(x, y)
# printing
if not noloss:
loss, accuracy = loss.numpy(), accuracy.numpy()
losses.append(loss)
accuracies.append(accuracy)
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
return [losses, accuracies]
def evaluate(
model,
X_test,
Y_test,
num_classes=None,
BS=128,
return_predict=False,
transform=lambda x: x,
target_transform=lambda y: y,
):
Tensor.training = False
def numpy_eval(Y_test, num_classes):
Y_test_preds_out = np.zeros(list(Y_test.shape) + [num_classes])
for i in trange((len(Y_test) - 1) // BS + 1, disable=CI):
x = Tensor(transform(X_test[i * BS : (i + 1) * BS]))
out = model.forward(x) if hasattr(model, "forward") else model(x)
Y_test_preds_out[i * BS : (i + 1) * BS] = out.numpy()
Y_test_preds = np.argmax(Y_test_preds_out, axis=-1)
Y_test = target_transform(Y_test)
return (Y_test == Y_test_preds).mean(), Y_test_preds
if num_classes is None:
num_classes = Y_test.max().astype(int) + 1
acc, Y_test_pred = numpy_eval(Y_test, num_classes)
print("test set accuracy is %f" % acc)
return (acc, Y_test_pred) if return_predict else acc