From 0f58c4c64869d44a74a08a6d5d1362d509740d1a Mon Sep 17 00:00:00 2001 From: Jacky Lee <39754370+jla524@users.noreply.github.com> Date: Sun, 26 Feb 2023 16:55:21 -0800 Subject: [PATCH] Cleanup yolo and remove stateless classes (#604) * Add AvgPool2d as a layer * Clean up a bit * Remove stateless layers in yolo_nn * More cleanup * Save label for test * Add test for YOLO * Test without cv2 * Don't fail if cv2 not installed * Better import * Fix image read * Use opencv :) * Don't download the file * Fix errors * Use same version * Set higher confidence * Why is the confidence so low? * Start over * Remove stateless layers * Remove extra lines * Revert changes * Save a few more lines --- examples/yolo/yolo_nn.py | 58 ------- examples/yolov3.py | 335 ++++++++------------------------------- setup.py | 1 + test/test_yolo.py | 36 +++++ 4 files changed, 103 insertions(+), 327 deletions(-) delete mode 100644 examples/yolo/yolo_nn.py create mode 100644 test/test_yolo.py diff --git a/examples/yolo/yolo_nn.py b/examples/yolo/yolo_nn.py deleted file mode 100644 index 12dc5d24c..000000000 --- a/examples/yolo/yolo_nn.py +++ /dev/null @@ -1,58 +0,0 @@ -from tinygrad.tensor import Tensor - -# PyTorch style layers for tinygrad. These layers are here because of tinygrads -# line limit. - -class MaxPool2d: - def __init__(self, kernel_size, stride): - if isinstance(kernel_size, int): self.kernel_size = (kernel_size, kernel_size) - else: self.kernel_size = kernel_size - self.stride = stride if (stride is not None) else kernel_size - - def __repr__(self): - return f"MaxPool2d(kernel_size={self.kernel_size!r}, stride={self.stride!r})" - - def __call__(self, input): - # TODO: Implement strided max_pool2d, and maxpool2d for 3d inputs - return input.max_pool2d(kernel_size=self.kernel_size) - - -class DetectionLayer: - def __init__(self, anchors): - self.anchors = anchors - - def __call__(self, input): - return input - -class EmptyLayer: - def __init__(self): - pass - - def __call__(self, input): - return input - -class Upsample: - def __init__(self, scale_factor = 2, mode = "nearest"): - self.scale_factor, self.mode = scale_factor, mode - - def upsampleNearest(self, input): - # TODO: Implement actual interpolation function - # inspired: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/functional/upsampling.h - return input.cpu().numpy().repeat(self.scale_factor, axis=len(input.shape)-2).repeat(self.scale_factor, axis=len(input.shape)-1) - - def __repr__(self): - return f"Upsample(scale_factor={self.scale_factor!r}, mode={self.mode!r})" - - def __call__(self, input): - return Tensor(self.upsampleNearest(input)) - -class LeakyReLU: - def __init__(self, neg_slope): - self.neg_slope = neg_slope - - def __repr__(self): - return f"LeakyReLU({self.neg_slope!r})" - - def __call__(self, input): - return input.leakyrelu(self.neg_slope) - diff --git a/examples/yolov3.py b/examples/yolov3.py index 9d25298c4..534ba1041 100755 --- a/examples/yolov3.py +++ b/examples/yolov3.py @@ -1,66 +1,39 @@ # https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg -# running import sys import io import time +import math import cv2 import numpy as np from PIL import Image from tinygrad.tensor import Tensor -from tinygrad.nn import BatchNorm2d, Conv2d, optim -from tinygrad.helpers import getenv +from tinygrad.nn import BatchNorm2d, Conv2d from extra.utils import fetch -from examples.yolo.yolo_nn import Upsample, EmptyLayer, DetectionLayer, LeakyReLU, MaxPool2d -np.set_printoptions(suppress=True) -GPU = getenv("GPU") -def show_labels(prediction, confidence = 0.5, num_classes = 80): +def show_labels(prediction, confidence=0.5, num_classes=80): coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names') coco_labels = coco_labels.decode('utf-8').split('\n') - prediction = prediction.detach().cpu().numpy() - conf_mask = (prediction[:,:,4] > confidence) - conf_mask = np.expand_dims(conf_mask, 2) - prediction = prediction * conf_mask - - def numpy_max(input, dim): - # Input -> tensor (10x8) - return np.amax(input, axis=dim), np.argmax(input, axis=dim) - + prediction *= np.expand_dims(conf_mask, 2) + labels = [] # Iterate over batches - for i in range(prediction.shape[0]): - img_pred = prediction[i] - max_conf, max_conf_score = numpy_max(img_pred[:,5:5 + num_classes], 1) + for img_pred in prediction: + max_conf = np.amax(img_pred[:,5:5+num_classes], axis=1) + max_conf_score = np.argmax(img_pred[:,5:5+num_classes], axis=1) max_conf_score = np.expand_dims(max_conf_score, axis=1) max_conf = np.expand_dims(max_conf, axis=1) seq = (img_pred[:,:5], max_conf, max_conf_score) image_pred = np.concatenate(seq, axis=1) - non_zero_ind = np.nonzero(image_pred[:,4])[0] assert all(image_pred[non_zero_ind,0] > 0) - image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7)) - try: - image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7)) - except: - print("No detections found!") - pass classes, indexes = np.unique(image_pred_[:, -1], return_index=True) for index, coco_class in enumerate(classes): - probability = image_pred_[indexes[index]][4] * 100 - print("Detected", coco_labels[int(coco_class)], "{:.2f}%".format(probability)) - -def letterbox_image(img, inp_dim=608): - img_w, img_h = img.shape[1], img.shape[0] - w, h = inp_dim - new_w = int(img_w * min(w/img_w, h/img_h)) - new_h = int(img_h * min(w/img_w, h/img_h)) - resized_image = cv2.resize(img, (new_w,new_h), interpolation = cv2.INTER_CUBIC) - - canvas = np.full((inp_dim[1], inp_dim[0], 3), 128) - canvas[(h-new_h)//2:(h-new_h)//2 + new_h,(w-new_w)//2:(w-new_w)//2 + new_w, :] = resized_image - return canvas + label, probability = coco_labels[int(coco_class)], image_pred_[indexes[index]][4] * 100 + print(f"Detected {label} {probability:.2f}") + labels.append(label) + return labels def add_boxes(img, prediction): if isinstance(prediction, int): # no predictions @@ -69,12 +42,9 @@ def add_boxes(img, prediction): coco_labels = coco_labels.decode('utf-8').split('\n') height, width = img.shape[0:2] scale_factor = 608 / width - prediction[:,[1,3]] -= (608 - scale_factor * width) / 2 prediction[:,[2,4]] -= (608 - scale_factor * height) / 2 - - for i in range(prediction.shape[0]): - pred = prediction[i] + for pred in prediction: corner1 = tuple(pred[1:3].astype(int)) corner2 = tuple(pred[3:5].astype(int)) w = corner2[0] - corner1[0] @@ -93,37 +63,29 @@ def bbox_iou(box1, box2): Returns the IoU of two bounding boxes IoU: IoU = Area Of Overlap / Area of Union -> How close the predicted bounding box is to the ground truth bounding box. Higher IoU = Better accuracy - In training, used to track accuracy. with inference, using to remove duplicate bounding boxes """ # Get the coordinates of bounding boxes b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,0], box1[:,1], box1[:,2], box1[:,3] b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3] - # get the coordinates of the intersection rectangle inter_rect_x1 = np.maximum(b1_x1, b2_x1) inter_rect_y1 = np.maximum(b1_y1, b2_y1) inter_rect_x2 = np.maximum(b1_x2, b2_x2) inter_rect_y2 = np.maximum(b1_y2, b2_y2) - #Intersection area inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, 99999) * np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, 99999) - #Union Area b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1) b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1) - iou = inter_area / (b1_area + b2_area - inter_area) - return iou - -def process_results(prediction, confidence = 0.9, num_classes = 80, nms_conf = 0.4): +def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4): prediction = prediction.detach().cpu().numpy() conf_mask = (prediction[:,:,4] > confidence) conf_mask = np.expand_dims(conf_mask, 2) prediction = prediction * conf_mask - # Non max suppression box_corner = prediction box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2) @@ -131,109 +93,56 @@ def process_results(prediction, confidence = 0.9, num_classes = 80, nms_conf = 0 box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2) box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2) prediction[:,:,:4] = box_corner[:,:,:4] - - batch_size = prediction.shape[0] write = False - # Process img img_pred = prediction[0] - - def numpy_max(input, dim): - # Input -> tensor (10x8) - return np.amax(input, axis=dim), np.argmax(input, axis=dim) - - max_conf, max_conf_score = numpy_max(img_pred[:,5:5 + num_classes], 1) + max_conf = np.amax(img_pred[:,5:5+num_classes], axis=1) + max_conf_score = np.argmax(img_pred[:,5:5+num_classes], axis=1) max_conf_score = np.expand_dims(max_conf_score, axis=1) max_conf = np.expand_dims(max_conf, axis=1) seq = (img_pred[:,:5], max_conf, max_conf_score) image_pred = np.concatenate(seq, axis=1) - non_zero_ind = np.nonzero(image_pred[:,4])[0] assert all(image_pred[non_zero_ind,0] > 0) image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7)) - try: - image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7)) - except: - print("No detections found!") - return 0 - if image_pred_.shape[0] == 0: print("No detections found!") return 0 - - def unique(tensor): - tensor_np = tensor - unique_np = np.unique(tensor_np) - return unique_np - - img_classes = unique(image_pred_[:, -1]) - - for cls in img_classes: + for cls in np.unique(image_pred_[:, -1]): # perform NMS, get the detections with one particular class cls_mask = image_pred_*np.expand_dims(image_pred_[:, -1] == cls, axis=1) class_mask_ind = np.squeeze(np.nonzero(cls_mask[:,-2])) # class_mask_ind = np.nonzero() image_pred_class = np.reshape(image_pred_[class_mask_ind], (-1, 7)) - # sort the detections such that the entry with the maximum objectness # confidence is at the top conf_sort_index = np.argsort(image_pred_class[:,4]) image_pred_class = image_pred_class[conf_sort_index] - idx = image_pred_class.shape[0] #Number of detections - - for i in range(idx): + for i in range(image_pred_class.shape[0]): # Get the IOUs of all boxes that come after the one we are looking at in the loop try: ious = bbox_iou(np.expand_dims(image_pred_class[i], axis=0), image_pred_class[i+1:]) - except ValueError: + except: break - - except IndexError: - break - # Zero out all the detections that have IoU > threshold iou_mask = np.expand_dims((ious < nms_conf), axis=1) image_pred_class[i+1:] *= iou_mask - # Remove the non-zero entries non_zero_ind = np.squeeze(np.nonzero(image_pred_class[:,4])) image_pred_class = np.reshape(image_pred_class[non_zero_ind], (-1, 7)) - batch_ind = np.array([[0]]) seq = (batch_ind, image_pred_class) - if not write: - output = np.concatenate(seq, 1) - write = True + output, write = np.concatenate(seq, axis=1), True else: out = np.concatenate(seq, axis=1) output = np.concatenate((output,out)) - try: - return output - except: - return 0 - -def imresize(img, w, h): - return np.array(Image.fromarray(img).resize((w, h))) - -def resize(img, inp_dim=(608, 608)): - img_w, img_h = img.shape[1], img.shape[0] - w, h = inp_dim - new_w = int(img_w * min(w/img_w, h/img_h)) - new_h = int(img_h * min(w/img_w, h/img_h)) - resized_image = cv2.resize(img, (new_w,new_h), interpolation = cv2.INTER_CUBIC) - - canvas = np.full((inp_dim[1], inp_dim[0], 3), 128) - canvas[(h-new_h)//2:(h-new_h)//2 + new_h,(w-new_w)//2:(w-new_w)//2 + new_w, :] = resized_image - return canvas + return output def infer(model, img): - img = np.array(img) - img = imresize(img, 608, 608) - # img = resize(img) + img = np.array(Image.fromarray(img).resize((608, 608))) img = img[:,:,::-1].transpose((2,0,1)) img = img[np.newaxis,:,:,:]/255.0 - prediction = model.forward(Tensor(img.astype(np.float32))) return prediction @@ -244,10 +153,7 @@ def parse_cfg(cfg): lines = [x for x in lines if len(x) > 0] lines = [x for x in lines if x[0] != '#'] lines = [x.rstrip().lstrip() for x in lines] - - block = {} - blocks = [] - + block, blocks = {}, [] for line in lines: if line[0] == "[": if len(block) != 0: @@ -258,7 +164,6 @@ def parse_cfg(cfg): key,value = line.split("=") block[key.rstrip()] = value.lstrip() blocks.append(block) - return blocks # TODO: Speed up this function, avoid copying stuff from GPU to CPU @@ -268,47 +173,29 @@ def predict_transform(prediction, inp_dim, anchors, num_classes): grid_size = inp_dim // stride bbox_attrs = 5 + num_classes num_anchors = len(anchors) - prediction = prediction.reshape(shape=(batch_size, bbox_attrs*num_anchors, grid_size*grid_size)) # Original PyTorch: transpose(1, 2) -> For some reason numpy.transpose order has to be reversed? - prediction = prediction.transpose(order=(0, 2, 1)) + prediction = prediction.transpose(order=(0,2,1)) prediction = prediction.reshape(shape=(batch_size, grid_size*grid_size*num_anchors, bbox_attrs)) - - # st = time.time() prediction_cpu = prediction.cpu().numpy() - # print('put on CPU in %.2f s' % (time.time() - st)) - - anchors = [(a[0]/stride, a[1]/stride) for a in anchors] - #Sigmoid the centre_X, centre_Y. and object confidence - # TODO: Fix this - def dsigmoid(data): - return 1/(1+np.exp(-data)) - - prediction_cpu[:,:,0] = dsigmoid(prediction_cpu[:,:,0]) - prediction_cpu[:,:,1] = dsigmoid(prediction_cpu[:,:,1]) - prediction_cpu[:,:,4] = dsigmoid(prediction_cpu[:,:,4]) - + for i in (0, 1, 4): + prediction_cpu[:,:,i] = 1 / (1 + np.exp(-prediction_cpu[:,:,i])) # Add the center offsets grid = np.arange(grid_size) a, b = np.meshgrid(grid, grid) - x_offset = a.reshape((-1, 1)) y_offset = b.reshape((-1, 1)) - x_y_offset = np.concatenate((x_offset, y_offset), 1) x_y_offset = np.tile(x_y_offset, (1, num_anchors)) x_y_offset = x_y_offset.reshape((-1,2)) x_y_offset = np.expand_dims(x_y_offset, 0) - - prediction_cpu[:,:,:2] += x_y_offset - + anchors = [(a[0]/stride, a[1]/stride) for a in anchors] anchors = np.tile(anchors, (grid_size*grid_size, 1)) anchors = np.expand_dims(anchors, 0) - + prediction_cpu[:,:,:2] += x_y_offset prediction_cpu[:,:,2:4] = np.exp(prediction_cpu[:,:,2:4])*anchors - prediction_cpu[:,:,5: 5 + num_classes] = dsigmoid((prediction_cpu[:,:, 5 : 5 + num_classes])) + prediction_cpu[:,:,5:5+num_classes] = 1 / (1 + np.exp(-prediction_cpu[:,:,5:5+num_classes])) prediction_cpu[:,:,:4] *= stride - return Tensor(prediction_cpu) @@ -320,54 +207,34 @@ class Darknet: def create_modules(self, blocks): net_info = blocks[0] # Info about model hyperparameters - prev_filters = 3 - filters = None - output_filters = [] - module_list = [] + prev_filters, filters = 3, None + output_filters, module_list = [], [] ## module for index, x in enumerate(blocks[1:]): module_type = x["type"] module = [] if module_type == "convolutional": try: - batch_normalize = int(x["batch_normalize"]) - bias = False + batch_normalize, bias = int(x["batch_normalize"]), False except: - batch_normalize = 0 - bias = True - + batch_normalize, bias = 0, True # layer activation = x["activation"] filters = int(x["filters"]) padding = int(x["pad"]) - if padding: - pad = (int(x["size"]) - 1) // 2 - else: - pad = 0 - - conv = Conv2d(prev_filters, filters, int(x["size"]), int(x["stride"]), pad, bias = bias) - module.append(conv) - + pad = (int(x["size"]) - 1) // 2 if padding else 0 + module.append(Conv2d(prev_filters, filters, int(x["size"]), int(x["stride"]), pad, bias=bias)) # BatchNorm2d if batch_normalize: - bn = BatchNorm2d(filters, eps=1e-05, track_running_stats=True) - module.append(bn) - + module.append(BatchNorm2d(filters, eps=1e-05, track_running_stats=True)) # LeakyReLU activation if activation == "leaky": - module.append(LeakyReLU(0.1)) - - # TODO: Add tiny model + module.append(lambda x: x.leakyrelu(0.1)) elif module_type == "maxpool": - size = int(x["size"]) - stride = int(x["stride"]) - maxpool = MaxPool2d(size, stride) - module.append(maxpool) - + size, stride = int(x["size"]), int(x["stride"]) + module.append(lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride)) elif module_type == "upsample": - upsample = Upsample(scale_factor = 2, mode = "nearest") - module.append(upsample) - + module.append(lambda x: Tensor(x.cpu().numpy().repeat(2, axis=-2).repeat(2, axis=-1))) elif module_type == "route": x["layers"] = x["layers"].split(",") # Start of route @@ -377,37 +244,26 @@ class Darknet: end = int(x["layers"][1]) except: end = 0 - if start > 0: start = start - index - if end > 0: end = end - index - route = EmptyLayer() - module.append(route) + if start > 0: start -= index + if end > 0: end -= index + module.append(lambda x: x) if end < 0: filters = output_filters[index + start] + output_filters[index + end] else: filters = output_filters[index + start] - # Shortcut corresponds to skip connection elif module_type == "shortcut": - module.append(EmptyLayer()) - + module.append(lambda x: x) elif module_type == "yolo": - mask = x["mask"].split(",") - mask = [int(x) for x in mask] - - anchors = x["anchors"].split(",") - anchors = [int(a) for a in anchors] + mask = list(map(int, x["mask"].split(","))) + anchors = [int(a) for a in x["anchors"].split(",")] anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors), 2)] - anchors = [anchors[i] for i in mask] - - detection = DetectionLayer(anchors) - module.append(detection) - + module.append([anchors[i] for i in mask]) # Append to module_list module_list.append(module) if filters is not None: prev_filters = filters output_filters.append(filters) - return (net_info, module_list) def dump_weights(self): @@ -426,56 +282,35 @@ class Darknet: print("None biases for layer", i) def load_weights(self, url): - weights = fetch(url) - # First 5 values (major, minor, subversion, Images seen) - header = np.frombuffer(weights, dtype=np.int32, count = 5) - self.seen = header[3] - - def numel(tensor): - from functools import reduce - return reduce(lambda x, y: x*y, tensor.shape) - - weights = np.frombuffer(weights, dtype=np.float32) - weights = weights[5:] - + weights = np.frombuffer(fetch(url), dtype=np.float32)[5:] ptr = 0 for i in range(len(self.module_list)): module_type = self.blocks[i + 1]["type"] - if module_type == "convolutional": model = self.module_list[i] try: # we have batchnorm, load conv weights without biases, and batchnorm values - batch_normalize = int(self.blocks[i + 1]["batch_normalize"]) + batch_normalize = int(self.blocks[i+1]["batch_normalize"]) except: # no batchnorm, load conv weights + biases batch_normalize = 0 - conv = model[0] - if batch_normalize: bn = model[1] - # Get the number of weights of batchnorm - num_bn_biases = numel(bn.bias) - + num_bn_biases = math.prod(bn.bias.shape) # Load weights bn_biases = Tensor(weights[ptr:ptr + num_bn_biases]) ptr += num_bn_biases - bn_weights = Tensor(weights[ptr:ptr+num_bn_biases]) ptr += num_bn_biases - bn_running_mean = Tensor(weights[ptr:ptr+num_bn_biases]) ptr += num_bn_biases - bn_running_var = Tensor(weights[ptr:ptr+num_bn_biases]) ptr += num_bn_biases - # Cast the loaded weights into dims of model weights bn_biases = bn_biases.reshape(shape=tuple(bn.bias.shape)) bn_weights = bn_weights.reshape(shape=tuple(bn.weight.shape)) bn_running_mean = bn_running_mean.reshape(shape=tuple(bn.running_mean.shape)) bn_running_var = bn_running_var.reshape(shape=tuple(bn.running_var.shape)) - # Copy data bn.bias = bn_biases bn.weight = bn_weights @@ -483,32 +318,25 @@ class Darknet: bn.running_var = bn_running_var else: # load biases of the conv layer - num_biases = numel(conv.bias) - + num_biases = math.prod(conv.bias.shape) # Load weights conv_biases = Tensor(weights[ptr: ptr+num_biases]) ptr += num_biases - # Reshape conv_biases = conv_biases.reshape(shape=tuple(conv.bias.shape)) - # Copy conv.bias = conv_biases - # Load weighys for conv layers - num_weights = numel(conv.weight) - + num_weights = math.prod(conv.weight.shape) conv_weights = Tensor(weights[ptr:ptr+num_weights]) ptr += num_weights - conv_weights = conv_weights.reshape(shape=tuple(conv.weight.shape)) conv.weight = conv_weights def forward(self, x): modules = self.blocks[1:] outputs = {} # Cached outputs for route layer - write = 0 - + detections, write = None, False for i, module in enumerate(modules): module_type = (module["type"]) if module_type == "convolutional" or module_type == "upsample": @@ -525,53 +353,30 @@ class Darknet: if (layers[1]) > 0: layers[1] = layers[1] - i map1 = outputs[i + layers[0]] map2 = outputs[i + layers[1]] - x = Tensor(np.concatenate((map1.cpu().numpy(), map2.cpu().numpy()), 1)) + x = Tensor(np.concatenate((map1.cpu().numpy(), map2.cpu().numpy()), axis=1)) elif module_type == "shortcut": from_ = int(module["from"]) x = outputs[i - 1] + outputs[i + from_] elif module_type == "yolo": - anchors = self.module_list[i][0].anchors - inp_dim = int(self.net_info["height"]) - # inp_dim = 416 + anchors = self.module_list[i][0] + inp_dim = int(self.net_info["height"]) # 416 num_classes = int(module["classes"]) - # Transform x = predict_transform(x, inp_dim, anchors, num_classes) if not write: - detections = x - write = 1 + detections, write = x, True else: - detections = Tensor(np.concatenate((detections.cpu().numpy(), x.cpu().numpy()), 1)) - - # print(module_type, 'layer took %.2f s' % (time.time() - st)) + detections = Tensor(np.concatenate((detections.cpu().numpy(), x.cpu().numpy()), axis=1)) outputs[i] = x - - return detections # Return detections + return detections if __name__ == "__main__": - cfg = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg') # normal model - # cfg = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3-tiny.cfg') # tiny model - - # Make deterministic - np.random.seed(1337) - - # Start model - model = Darknet(cfg) - + model = Darknet(fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg')) print("Loading weights file (237MB). This might take a while…") - model.load_weights('https://pjreddie.com/media/files/yolov3.weights') # normal model - # model.load_weights('https://pjreddie.com/media/files/yolov3-tiny.weights') # tiny model - - if GPU: - params = optim.get_parameters(model) - [x.gpu_() for x in params] - + model.load_weights('https://pjreddie.com/media/files/yolov3.weights') if len(sys.argv) > 1: url = sys.argv[1] else: url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png" - - img = None - # We use cv2 because for some reason, cv2 imread produces better results? if url == 'webcam': cap = cv2.VideoCapture(0) cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) @@ -579,11 +384,8 @@ if __name__ == "__main__": _ = cap.grab() # discard one frame to circumvent capture buffering ret, frame = cap.read() img = Image.fromarray(frame[:, :, [2,1,0]]) - - prediction = infer(model, img) - prediction = process_results(prediction) - - boxes = add_boxes(imresize(np.array(img), 608, 608), prediction) + prediction = process_results(infer(model, img)) + boxes = add_boxes(np.array(img.resize((608, 608))), prediction) boxes = cv2.cvtColor(boxes, cv2.COLOR_RGB2BGR) cv2.imshow('yolo', boxes) if cv2.waitKey(1) & 0xFF == ord('q'): @@ -595,16 +397,11 @@ if __name__ == "__main__": img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1) else: img = cv2.imread(url) - - # Predict st = time.time() print('running inference…') prediction = infer(model, img) - print(f'did inference in {(time.time() - st):2f} s') - - labels = show_labels(prediction) + print(f'did inference in {(time.time() - st):2f}s') + show_labels(prediction) prediction = process_results(prediction) - # print(prediction) - boxes = add_boxes(imresize(img, 608, 608), prediction) - # Save img + boxes = add_boxes(np.array(Image.fromarray(img).resize((608, 608))), prediction) cv2.imwrite('boxes.jpg', boxes) diff --git a/setup.py b/setup.py index 969d7638a..56d7bf689 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ setup(name='tinygrad', "pytest-xdist", "onnx~=1.13.0", "onnx2torch", + "opencv-python", ], }, include_package_data=True) diff --git a/test/test_yolo.py b/test/test_yolo.py new file mode 100644 index 000000000..3631cabaa --- /dev/null +++ b/test/test_yolo.py @@ -0,0 +1,36 @@ +import io +import unittest +from pathlib import Path + +import cv2 +import requests # type: ignore +import numpy as np + +from tinygrad.tensor import Tensor +from examples.yolov3 import Darknet, infer, show_labels +from extra.utils import fetch + +chicken_img = cv2.imread(str(Path(__file__).parent / 'efficientnet/Chicken.jpg')) +car_img = cv2.imread(str(Path(__file__).parent / 'efficientnet/car.jpg')) + +class TestYOLO(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = Darknet(fetch("https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg")) + print("Loading weights file (237MB). This might take a while…") + cls.model.load_weights("https://pjreddie.com/media/files/yolov3.weights") + + @classmethod + def tearDownClass(cls): + del cls.model + + def test_chicken(self): + labels = show_labels(infer(self.model, chicken_img), confidence=0.56) + self.assertEqual(labels, ["bird"]) + + def test_car(self): + labels = show_labels(infer(self.model, car_img)) + self.assertEqual(labels, ["car"]) + +if __name__ == '__main__': + unittest.main()