109 lines
3.2 KiB
Python
109 lines
3.2 KiB
Python
# load weights from
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
# a rough copy of
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
|
import sys
|
|
import ast
|
|
import time
|
|
import numpy as np
|
|
from PIL import Image
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.helpers import getenv, fetch, Timing
|
|
from tinygrad.jit import TinyJit
|
|
from extra.models.efficientnet import EfficientNet
|
|
|
|
# Print numpy floats in fixed-point notation rather than scientific notation.
np.set_printoptions(suppress=True)

# Per-channel constants used by `_infer`: after the pixels are divided by 255,
# `bias` is subtracted and the result is divided by `scale`.
# TODO: you should be able to put these in the jitted function
bias, scale = (
    Tensor([0.485, 0.456, 0.406]),
    Tensor([0.229, 0.224, 0.225]),
)
|
|
|
|
|
|
@TinyJit
def _infer(model, img):
    """Normalize an image tensor and run it through ``model``.

    The last axis of ``img`` is moved to the front (the caller passes a
    height x width x channel crop), values are divided by 255, then the
    module-level ``bias`` is subtracted and the result divided by the
    module-level ``scale``. Both constants are reshaped to (1, C, 1, 1), so
    the broadcast also adds a leading axis to the tensor.

    Returns the realized result of ``model.forward``.
    """
    channel_first = img.permute((2, 0, 1)) / 255.0
    shift = bias.reshape((1, -1, 1, 1))
    divisor = scale.reshape((1, -1, 1, 1))
    normalized = (channel_first - shift) / divisor
    return model.forward(normalized).realize()
|
|
|
|
|
|
def infer(model, img):
    """Preprocess a PIL image and classify it with ``model``.

    The image is converted to RGB if needed, resized so that its shorter side
    is 224 pixels (the longer side keeps the aspect ratio, truncated to an
    int), and center-cropped to 224x224 before being passed to ``_infer``.

    Args:
        model: network object forwarded unchanged to ``_infer``.
        img: PIL image (must provide ``mode``, ``size``, ``convert`` and
            ``resize``).

    Returns:
        A tuple ``(out, retimg)``: ``out`` is the network output as a numpy
        array, ``retimg`` is the 224x224 crop (numpy array) that was fed to
        the network.
    """
    # Grayscale, palette or RGBA inputs would otherwise produce a 2-D or
    # 4-channel array, which does not match the 3-channel normalization done
    # in `_infer`. Images that are already RGB are left untouched.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # preprocess image: shorter side -> 224, longer side scaled accordingly
    width, height = img.size
    aspect_ratio = width / height
    target_size = (
        int(224 * max(aspect_ratio, 1.0)),
        int(224 * max(1.0 / aspect_ratio, 1.0)),
    )
    pixels = np.array(img.resize(target_size))

    # center crop to 224x224 (array shape is rows, cols, channels)
    y0, x0 = (np.asarray(pixels.shape)[:2] - 224) // 2
    retimg = pixels[y0 : y0 + 224, x0 : x0 + 224]

    # if you want to look at the image:
    #   import matplotlib.pyplot as plt
    #   plt.imshow(retimg)
    #   plt.show()

    # run the net
    out = _infer(model, Tensor(retimg.astype("float32"))).numpy()

    # if you want to look at the outputs:
    #   import matplotlib.pyplot as plt
    #   plt.plot(out[0])
    #   plt.show()
    return out, retimg
|
|
|
|
|
|
if __name__ == "__main__":
    # Usage: pass an image URL/path as the first argument, or "webcam" to
    # classify live frames; with no argument a default showcase image is used.

    # instantiate my net
    model = EfficientNet(getenv("NUM", 0))
    model.load_from_pretrained()

    # category labels
    lbls = ast.literal_eval(
        fetch(
            "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
        ).read_text()
    )

    # load image and preprocess
    url = (
        sys.argv[1]
        if len(sys.argv) >= 2
        else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
    )
    if url == "webcam":
        import cv2

        SCALE = 3  # display magnification of the 224x224 crop
        cap = cv2.VideoCapture(0)
        cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
        try:
            while True:
                _ = cap.grab()  # discard one frame to circumvent capture buffering
                ret, frame = cap.read()
                if not ret:
                    # read() failed (no camera / stream ended): frame is not
                    # usable, so stop instead of crashing on the subscript.
                    print("failed to read a frame from the webcam", file=sys.stderr)
                    break
                # OpenCV delivers BGR; reverse the channel axis to get RGB
                img = Image.fromarray(frame[:, :, [2, 1, 0]])
                lt = time.monotonic_ns()
                out, retimg = infer(model, img)
                elapsed_ms = (time.monotonic_ns() - lt) * 1e-6
                best = np.argmax(out)
                print(f"{elapsed_ms:7.2f} ms", best, np.max(out), lbls[best])
                simg = cv2.resize(retimg, (224 * SCALE, 224 * SCALE))
                retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
                cv2.imshow("capture", retimg)
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
        finally:
            # always free the camera and close windows, even on Ctrl-C or error
            cap.release()
            cv2.destroyAllWindows()
    else:
        img = Image.open(fetch(url))
        with Timing("did inference in "):
            out, _ = infer(model, img)
        best = np.argmax(out)
        print(best, np.max(out), lbls[best])
|