1
0
Fork 0

Fix yolov3 example (#577)

pull/578/head
Mischa Untaga 2023-02-21 18:24:00 +01:00 committed by GitHub
parent 8550b3e168
commit 14bb2c40a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 4 deletions

View File

@ -7,9 +7,9 @@ import cv2
import numpy as np
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2d, Conv2d
from tinygrad.nn import BatchNorm2d, Conv2d, optim
from tinygrad.helpers import getenv
from extra.utils import fetch, get_parameters
from extra.utils import fetch
from examples.yolo.yolo_nn import Upsample, EmptyLayer, DetectionLayer, LeakyReLU, MaxPool2d
np.set_printoptions(suppress=True)
GPU = getenv("GPU")
@ -562,7 +562,7 @@ if __name__ == "__main__":
# model.load_weights('https://pjreddie.com/media/files/yolov3-tiny.weights') # tiny model
if GPU:
params = get_parameters(model)
params = optim.get_parameters(model)
[x.gpu_() for x in params]
if len(sys.argv) > 1:
@ -592,7 +592,7 @@ if __name__ == "__main__":
cv2.destroyAllWindows()
elif url.startswith('http'):
img_stream = io.BytesIO(fetch(url))
img = cv2.imdecode(np.fromstring(img_stream.read(), np.uint8), 1)
img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1)
else:
img = cv2.imread(url)