1
0
Fork 0
tinygrab/test/test_efficientnet.py

96 lines
2.5 KiB
Python
Raw Normal View History

import ast
import pathlib
import sys
import unittest
import numpy as np
from PIL import Image
from models.efficientnet import EfficientNet
2022-06-05 18:12:43 -06:00
from models.vit import ViT
from tinygrad.tensor import Tensor
2021-10-30 17:52:40 -06:00
def _load_labels():
  """Load the ImageNet class-index -> label table stored next to this test.

  The file contains a Python literal that is indexed by the predicted class id.
  ``ast.literal_eval`` parses it without executing arbitrary code.
  """
  labels_filename = pathlib.Path(__file__).parent / 'efficientnet/imagenet1000_clsidx_to_labels.txt'
  # Decode explicitly; a bare read_text() uses the locale's preferred encoding,
  # which makes parsing platform-dependent.
  return ast.literal_eval(labels_filename.read_text(encoding='utf-8'))

_LABELS = _load_labels()
2022-06-11 15:30:26 -06:00
def preprocess(img, new=False):
  """Resize, center-crop and normalize a PIL image into a batch of one.

  The shorter side is resized to 224 px (aspect ratio preserved). Then the
  central 224x224 region is cropped.

  Args:
    img: a PIL image. It is converted to RGB first, so grayscale, palette,
      RGBA and CMYK inputs all produce exactly three colour channels.
    new: selects the normalization scheme.
      - False (default): returns a (1, 3, 224, 224) array scaled to [0, 1],
        then standardized with per-channel mean (0.485, 0.456, 0.406) and
        std (0.229, 0.224, 0.225).
      - True: returns a (1, 224, 224, 3) array mapped via (x - 127) / 128.

  Returns:
    A float32 numpy array laid out as described for ``new``.
  """
  # Normalize the colour mode first. Otherwise a grayscale/palette image becomes
  # a 2-D array (np.moveaxis below fails), and an RGBA image has 4 channels
  # (the 3-element broadcast in the new=True branch fails).
  if img.mode != "RGB":
    img = img.convert("RGB")

  # Resize so the shorter side is 224 while keeping the aspect ratio.
  aspect_ratio = img.size[0] / img.size[1]  # PIL size is (width, height)
  img = img.resize((int(224*max(aspect_ratio, 1.0)), int(224*max(1.0/aspect_ratio, 1.0))))

  # Center crop to 224x224 on the (H, W, C) array.
  img = np.array(img)
  y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
  img = img[y0:y0 + 224, x0:x0 + 224]

  # Low-level normalization.
  if new:
    img = img.astype(np.float32)
    img -= [127.0, 127.0, 127.0]
    img /= [128.0, 128.0, 128.0]
    return img[None]  # add batch dim -> (1, 224, 224, 3)

  img = np.moveaxis(img, [2, 0, 1], [0, 1, 2])  # (H, W, C) -> (C, H, W)
  img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224)
  img /= 255.0
  img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1))
  img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1))
  return img
2022-06-11 15:30:26 -06:00
2022-06-11 14:17:15 -06:00
def _infer(model: EfficientNet, img, bs=1):
  """Classify ``img`` with ``model`` and return the predicted label string.

  Args:
    model: a network whose ``forward`` accepts a Tensor built from
      ``preprocess`` output. It is annotated as EfficientNet, but TestViT
      passes a ViT here as well.
    img: a PIL image.
    bs: batch size. The single preprocessed image is repeated ``bs`` times
      along axis 0 to exercise batched inference.

  Returns:
    The ``_LABELS`` entry for the argmax of the first batch row.

  Raises:
    AssertionError: if the rows of the (identical) batch disagree on the
      predicted class. Without this check, a batching bug affecting rows
      1..bs-1 would go unnoticed.
  """
  x = preprocess(img)
  if bs > 1:
    x = x.repeat(bs, axis=0)

  # Run the net.
  out = model.forward(Tensor(x)).cpu()

  # Every row was fed the same image, so every row must predict the same class.
  preds = [int(np.argmax(out.data[i])) for i in range(x.shape[0])]
  if len(set(preds)) > 1:
    raise AssertionError(f"batch rows disagree: {[_LABELS[p] for p in preds]}")
  return _LABELS[preds[0]]
2022-06-11 15:30:26 -06:00
def _load_image(name):
  """Read a fixture image from the ``efficientnet/`` directory fully into memory.

  ``Image.open`` is lazy and would otherwise keep the file handle open for the
  whole test run. Copying inside the ``with`` block decodes the pixels into a
  standalone image and lets the handle be closed immediately.
  """
  with Image.open(pathlib.Path(__file__).parent / 'efficientnet' / name) as im:
    return im.copy()

chicken_img = _load_image('Chicken.jpg')
car_img = _load_image('car.jpg')
class TestEfficientNet(unittest.TestCase):
  """End-to-end label checks for EfficientNet(number=0) with pretrained weights."""

  @classmethod
  def setUpClass(cls):
    # Build the model once and share it across every test in this class.
    net = cls.model = EfficientNet(number=0)
    net.load_from_pretrained()

  @classmethod
  def tearDownClass(cls):
    # Release the class-level reference to the shared model.
    del cls.model

  def test_chicken(self):
    self.assertEqual(_infer(self.model, chicken_img), "hen")

  def test_chicken_bigbatch(self):
    # Same image, replicated 4 times in a single batch.
    self.assertEqual(_infer(self.model, chicken_img, bs=4), "hen")

  def test_car(self):
    self.assertEqual(_infer(self.model, car_img), "sports car, sport car")
class TestViT(unittest.TestCase):
  """End-to-end label checks for ViT() with pretrained weights."""

  @classmethod
  def setUpClass(cls):
    # Build the model once and share it across every test in this class.
    net = cls.model = ViT()
    net.load_from_pretrained()

  @classmethod
  def tearDownClass(cls):
    # Release the class-level reference to the shared model.
    del cls.model

  def test_chicken(self):
    self.assertEqual(_infer(self.model, chicken_img), "cock")

  def test_car(self):
    self.assertEqual(_infer(self.model, car_img), "racer, race car, racing car")
if __name__ == "__main__":
  # Allow running this file directly; unittest collects the TestCase classes above.
  unittest.main()