import re
import math
import os
import numpy as np
from pathlib import Path
from tinygrad import nn
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from extra.utils import get_child, download_file
from tinygrad.state import torch_load
from models.resnet import ResNet
from models.retinanet import nms as _box_nms
# FULL_TINYGRAD unset or "0": tensor_getitem gathers with numpy (npgather).
# Any other value: it uses the pure-tinygrad _gather instead.
USE_NP_GATHER = os.environ.get('FULL_TINYGRAD', '0') == '0'
def rint(tensor):
  """Round each element of ``tensor`` to the nearest integer, ties away from zero.

  Values are first snapped down (in magnitude) to the half-integer grid via
  an int32 cast of ``2 * tensor``, presumably truncating toward zero. They
  are then pushed outward: floor for negatives, ceil otherwise. Magnitudes
  beyond int32 range are not supported because of that cast.
  """
  doubled = (tensor * 2).cast(dtypes.int32).contiguous()
  halves = doubled.cast(dtypes.float32) / 2
  is_negative = halves < 0
  return is_negative.where(halves.floor(), halves.ceil())
def nearest_interpolate(tensor, scale_factor):
  """Nearest-neighbour upsample a 4-D (batch, channel, y, x) tensor by an integer factor.

  Every pixel is repeated ``scale_factor`` times along both spatial axes.
  This is done by inserting singleton axes, broadcasting them, and folding
  them back in.
  """
  bs, c, py, px = tensor.shape
  s = scale_factor
  spread = tensor.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, s, px, s)
  return spread.reshape(bs, c, py * s, px * s)
def meshgrid(x, y):
  """Build 'ij'-ordered coordinate columns from two (presumably 1-D) tensors.

  Returns ``(grid_x, grid_y)``, each of shape (len(x) * len(y), 1).
  ``grid_x`` repeats every element of ``x`` len(y) times. ``grid_y`` tiles
  the whole of ``y`` len(x) times.
  """
  n_rows = x.shape[0]
  rows = [x[i:i+1].expand(y.shape).unsqueeze(0) for i in range(n_rows)]
  grid_x = Tensor.cat(*rows)
  grid_y = Tensor.cat(*([y.unsqueeze(0)] * n_rows))
  return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)
def topk(input_, k, dim=-1, largest=True, sorted=False):
  """Select the ``k`` largest (or smallest) entries of ``input_`` along ``dim``.

  The selection runs on the host with numpy, so no gradient flows through it.

  Args:
    input_: tensor to select from (anything with a ``.numpy()`` method).
    k: number of entries to keep; clamped to ``[0, input_.shape[dim]]``, so
      asking for every element returns every element.
    dim: axis to select along.
    largest: pick the largest entries if True, else the smallest.
    sorted: if True, order the result best-first; otherwise order is arbitrary.

  Returns:
    ``(values, indices)``. ``values`` is a ``Tensor``; ``indices`` is a
    numpy integer array (not a Tensor) of positions along ``dim``.
  """
  arr = input_.numpy()
  k = max(0, min(k, arr.shape[dim]))
  # Negate into a new array instead of in place: the array handed back by
  # .numpy() may share memory with the tensor.
  keys = -arr if largest else arr
  # With kth = k-1, argpartition places exactly the k best entries in the
  # first k slots. Unlike kth = k, this is also valid for k == n.
  # argsort covers k == 0 (including empty input).
  ind = np.argpartition(keys, k - 1, axis=dim) if k > 0 else np.argsort(keys, axis=dim)
  ind = np.take(ind, np.arange(k), axis=dim)  # k non-sorted indices
  vals = np.take_along_axis(arr, ind, axis=dim)  # k non-sorted values
  if sorted:
    order = np.argsort(np.take_along_axis(keys, ind, axis=dim), axis=dim)
    ind = np.take_along_axis(ind, order, axis=dim)
    vals = np.take_along_axis(vals, order, axis=dim)
  return Tensor(vals), ind
# This is very slow for large arrays or many indices: it builds a one-hot
# comparison of every index against every position of the last axis.
def _gather(array, indices):
  """Gather from the last axis of ``array`` using only tinygrad ops.

  ``result[*I] = array[..., indices[*I]]``. The leading axes of ``array``
  broadcast against the trailing axes of ``indices``. The reshape/expand
  below assume ``indices.ndim == array.ndim``. Out-of-range indices match
  no position and yield 0.
  """
  indices = indices.float().to(array.device)
  num_pos = array.shape[-1]
  out_shape = (*indices.shape, num_pos)
  reshape_arg = [1]*array.ndim + [num_pos]
  one_hot = (
    indices.unsqueeze(indices.ndim).expand(*out_shape)
    == Tensor.arange(num_pos).reshape(*reshape_arg).expand(*out_shape)
  )
  # keep the matching element of each one-hot row, then collapse that row
  return Tensor.where(one_hot, array, 0).sum(indices.ndim)
# TODO: replace npgather with a faster gather using tinygrad only
# NOTE: this blocks the gradient
def npgather(array, indices):
  """Return ``array[indices]`` (integer indexing along axis 0) computed with numpy.

  Tensor inputs are pulled to the host with ``.numpy()``. A list of indices
  is converted with ``np.asarray``. The result is a fresh Tensor with no
  gradient connection to ``array``.
  """
  arr = array.numpy() if isinstance(array, Tensor) else array
  if isinstance(indices, Tensor):
    idx = indices.numpy()
  elif isinstance(indices, list):
    idx = np.asarray(indices)
  else:
    idx = indices
  return Tensor(arr[idx.astype(int)])
def get_strides(shape):
  """Row-major element strides of ``shape`` as a (1, ndim) int32 Tensor on the CPU.

  Example: shape (2, 3, 4) -> [[12, 4, 1]].
  """
  strides, running = [], 1
  for dim_size in reversed(shape):
    strides.append(running)
    running *= dim_size
  strides.reverse()
  # something about ints is broken with gpu, cuda -> keep these on the CPU
  return Tensor(strides, dtype=dtypes.int32).unsqueeze(0).cpu()
# with keys as integer array for all axes
def tensor_getitem(tensor, *keys):
  """Advanced-index ``tensor`` with one integer index tensor per axis.

  Behaves like numpy's ``tensor[key0, key1, ...]``. The keys are broadcast
  together, turned into flat offsets with the row-major strides of
  ``tensor``, and gathered from the flattened tensor. The result has the
  broadcast shape of the keys.
  """
  # Broadcast shape of all keys. Adding them lets tinygrad work it out;
  # computed once here instead of once per key.
  out_shape = sum(keys).shape
  # something about ints is broken with gpu, cuda -> do the index math on CPU
  flat_keys = Tensor.stack([key.expand(out_shape).reshape(-1) for key in keys], dim=1).cpu().cast(dtypes.int32)
  strides = get_strides(tensor.shape)
  idxs = (flat_keys * strides).sum(1)
  gatherer = npgather if USE_NP_GATHER else _gather
  return gatherer(tensor.reshape(-1), idxs).reshape(out_shape)
# for gather with indices only on axis=0
def tensor_gather(tensor, indices):
  """Select rows of ``tensor`` along axis 0, i.e. ``tensor[indices]``.

  ``indices`` may be a Tensor or anything ``Tensor(...)`` accepts (list,
  numpy array). Tensors with more than two dims are flattened to 2-D for
  the gather and reshaped back to ``[len(indices), *tensor.shape[1:]]``.
  """
  if not isinstance(indices, Tensor):
    indices = Tensor(indices, requires_grad=False)
  rem_shape = None
  if len(tensor.shape) > 2:
    # flatten the trailing axes so the gather only deals with 1-D/2-D data
    rem_shape = list(tensor.shape)[1:]
    tensor = tensor.reshape(tensor.shape[0], -1)
  if len(tensor.shape) <= 1:
    return _gather(tensor, indices)
  # 2-D: gather along the last axis of the transpose, with one copy of the
  # indices per column -> result[m, c] = tensor[indices[m], c]
  tensor = tensor.T
  repeat_arg = [1]*(tensor.ndim-1) + [tensor.shape[-2]]
  indices = indices.unsqueeze(indices.ndim).repeat(repeat_arg)
  ret = _gather(tensor, indices)
  if rem_shape:
    ret = ret.reshape([indices.shape[0]] + rem_shape)
  return ret
class LastLevelMaxPool:
  """Extra FPN output level made by subsampling the coarsest map.

  Calls ``Tensor.max_pool2d(x, 1, 2)``, presumably kernel 1 / stride 2,
  i.e. keep every other row and column. Returns a one-element list.
  """
  def __call__(self, x):
    pooled = Tensor.max_pool2d(x, 1, 2)
    return [pooled]
# Reorder per-location head outputs so that every anchor becomes one row.
def permute_and_flatten(layer:Tensor, N, A, C, H, W):
  """Turn an (N, A*C, H, W) prediction map into (N, H*W*A, C).

  Rows are ordered by (y, x, anchor); each row holds the C values predicted
  for one anchor.
  """
  per_anchor = layer.reshape(N, -1, C, H, W)          # (N, A, C, H, W)
  spatial_first = per_anchor.permute(0, 3, 4, 1, 2)   # (N, H, W, A, C)
  return spatial_first.reshape(N, -1, C)
# Boxes use the "+1" pixel convention: width = x2 - x1 + TO_REMOVE.
# Same value BoxCoder uses for its local TO_REMOVE.
TO_REMOVE = 1

# Flip modes accepted by BoxList.transpose.
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1


class BoxList:
  """Bounding boxes of one image plus named per-box fields.

  Attributes:
    bbox: (N, 4) Tensor laid out according to ``mode``: "xyxy" is
      (x1, y1, x2, y2); "xywh" is (x1, y1, width, height).
    size: (image_width, image_height).
    mode: "xyxy" or "xywh".
    extra_fields: field name -> per-box data (e.g. "scores", "labels").
  """

  def __init__(self, bbox, image_size, mode="xyxy"):
    if not isinstance(bbox, Tensor):
      bbox = Tensor(bbox)
    if bbox.ndim != 2:
      raise ValueError(
        "bbox should have 2 dimensions, got {}".format(bbox.ndim)
      )
    if bbox.shape[-1] != 4:
      raise ValueError(
        "last dimenion of bbox should have a "
        "size of 4, got {}".format(bbox.shape[-1])
      )
    if mode not in ("xyxy", "xywh"):
      raise ValueError("mode should be 'xyxy' or 'xywh'")
    self.bbox = bbox
    self.size = image_size  # (image_width, image_height)
    self.mode = mode
    self.extra_fields = {}

  def __repr__(self):
    s = self.__class__.__name__ + "("
    s += "num_boxes={}, ".format(len(self))
    s += "image_width={}, ".format(self.size[0])
    s += "image_height={}, ".format(self.size[1])
    s += "mode={})".format(self.mode)
    return s

  def area(self):
    """Per-box area, shape (N,)."""
    box = self.bbox
    if self.mode == "xyxy":
      return (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
    if self.mode == "xywh":
      return box[:, 2] * box[:, 3]
    raise RuntimeError("Should not be here")

  def add_field(self, field, field_data):
    self.extra_fields[field] = field_data

  def get_field(self, field):
    return self.extra_fields[field]

  def has_field(self, field):
    return field in self.extra_fields

  def fields(self):
    return list(self.extra_fields.keys())

  def _copy_extra_fields(self, bbox):
    for k, v in bbox.extra_fields.items():
      self.extra_fields[k] = v

  def convert(self, mode):
    """Return these boxes in ``mode``, carrying all extra fields; ``self`` if already in it."""
    if mode == self.mode:
      return self
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    if mode == "xyxy":
      bbox = Tensor.cat(*(xmin, ymin, xmax, ymax), dim=-1)
    else:
      bbox = Tensor.cat(
        *(xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1
      )
    bbox = BoxList(bbox, self.size, mode=mode)
    bbox._copy_extra_fields(self)
    return bbox

  def _split_into_xyxy(self):
    """Return (xmin, ymin, xmax, ymax), each an (N, 1) column."""
    if self.mode == "xyxy":
      xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1)
      return xmin, ymin, xmax, ymax
    if self.mode == "xywh":
      xmin, ymin, w, h = self.bbox.chunk(4, dim=-1)
      # maximum() is the tinygrad op used elsewhere in this file for a lower bound
      return (
        xmin,
        ymin,
        xmin + (w - TO_REMOVE).maximum(0.0),
        ymin + (h - TO_REMOVE).maximum(0.0),
      )
    raise RuntimeError("Should not be here")

  def resize(self, size, *args, **kwargs):
    """Rescale boxes from ``self.size`` to ``size`` (width, height).

    Non-Tensor fields are resized through their own ``resize`` method.
    """
    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
    if ratios[0] == ratios[1]:
      ratio = ratios[0]
      scaled_box = self.bbox * ratio
      bbox = BoxList(scaled_box, size, mode=self.mode)
      for k, v in self.extra_fields.items():
        if not isinstance(v, Tensor):
          v = v.resize(size, *args, **kwargs)
        bbox.add_field(k, v)
      return bbox
    ratio_width, ratio_height = ratios
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    scaled_box = Tensor.cat(
      *(xmin * ratio_width, ymin * ratio_height, xmax * ratio_width, ymax * ratio_height), dim=-1
    )
    bbox = BoxList(scaled_box, size, mode="xyxy")
    for k, v in self.extra_fields.items():
      if not isinstance(v, Tensor):
        v = v.resize(size, *args, **kwargs)
      bbox.add_field(k, v)
    return bbox.convert(self.mode)

  def transpose(self, method):
    """Flip boxes horizontally (FLIP_LEFT_RIGHT) or vertically (FLIP_TOP_BOTTOM)."""
    image_width, image_height = self.size
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    if method == FLIP_LEFT_RIGHT:
      transposed_xmin = image_width - xmax - TO_REMOVE
      transposed_xmax = image_width - xmin - TO_REMOVE
      transposed_ymin = ymin
      transposed_ymax = ymax
    elif method == FLIP_TOP_BOTTOM:
      # NOTE(review): unlike the horizontal flip this does not subtract TO_REMOVE; kept as-is
      transposed_xmin = xmin
      transposed_xmax = xmax
      transposed_ymin = image_height - ymax
      transposed_ymax = image_height - ymin
    else:
      raise ValueError("Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM are supported")
    transposed_boxes = Tensor.cat(
      *(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
    )
    bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
    for k, v in self.extra_fields.items():
      if not isinstance(v, Tensor):
        v = v.transpose(method)
      bbox.add_field(k, v)
    return bbox.convert(self.mode)

  def clip_to_image(self, remove_empty=True):
    """Clamp (xyxy) coordinates into the image in place.

    With ``remove_empty``, returns only the boxes that still have positive
    width and height; otherwise returns ``self``.
    """
    max_x, max_y = self.size[0] - TO_REMOVE, self.size[1] - TO_REMOVE
    # clip each column on its own instead of clipping the whole box tensor four times
    self.bbox = Tensor.stack((
      self.bbox[:, 0].clip(min_=0, max_=max_x),
      self.bbox[:, 1].clip(min_=0, max_=max_y),
      self.bbox[:, 2].clip(min_=0, max_=max_x),
      self.bbox[:, 3].clip(min_=0, max_=max_y),
    ), dim=1)
    if remove_empty:
      box = self.bbox
      # __getitem__ treats a tensor as row indices, not a mask -> convert the
      # mask to indices (same pattern as remove_small_boxes)
      keep = ((box[:, 3] > box[:, 1]) * (box[:, 2] > box[:, 0])).numpy().nonzero()[0]
      return self[keep]
    return self

  def __getitem__(self, item):
    """Select boxes and every extra field by row indices.

    An empty list returns []; an all-True list of bools returns ``self``.
    """
    if isinstance(item, list):
      if len(item) == 0:
        return []
      if sum(item) == len(item) and isinstance(item[0], bool):
        return self
    bbox = BoxList(tensor_gather(self.bbox, item), self.size, self.mode)
    for k, v in self.extra_fields.items():
      bbox.add_field(k, tensor_gather(v, item))
    return bbox

  def __len__(self):
    return self.bbox.shape[0]
def cat_boxlist(bboxes):
  """Concatenate BoxLists of the same image size and mode into one BoxList.

  Empty inputs are skipped. If every input is empty, the first one's
  (empty) boxes and fields are reused. Only fields present on
  ``bboxes[0]`` are carried over.
  """
  size = bboxes[0].size
  mode = bboxes[0].mode
  fields = set(bboxes[0].fields())
  cat_box_list = [bbox.bbox for bbox in bboxes if bbox.bbox.shape[0] > 0]
  if cat_box_list:
    cat_boxes = BoxList(Tensor.cat(*cat_box_list, dim=0), size, mode)
  else:
    cat_boxes = BoxList(bboxes[0].bbox, size, mode)
  for field in fields:
    cat_field_list = [bbox.get_field(field) for bbox in bboxes if bbox.get_field(field).shape[0] > 0]
    if cat_field_list:
      data = Tensor.cat(*cat_field_list, dim=0)
    else:
      data = bboxes[0].get_field(field)
    cat_boxes.add_field(field, data)
  return cat_boxes
class FPN:
  """Feature Pyramid Network neck.

  Each input level gets a 1x1 "inner" (lateral) conv and a 3x3 "layer" conv.
  Starting from the last input, the running top-down map is 2x upsampled and
  added to the next lateral. One extra output level comes from
  LastLevelMaxPool applied to the last result.
  """
  def __init__(self, in_channels_list, out_channels):
    self.inner_blocks, self.layer_blocks = [], []
    for in_channels in in_channels_list:
      self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
      self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
    self.top_block = LastLevelMaxPool()

  def __call__(self, x: Tensor):
    # x: sequence of feature maps; the last one is the starting (top) level.
    last_inner = self.inner_blocks[-1](x[-1])
    results = [self.layer_blocks[-1](last_inner)]
    for feature, inner_block, layer_block in zip(
      x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
    ):
      if not inner_block:
        continue
      inner_top_down = nearest_interpolate(last_inner, scale_factor=2)
      inner_lateral = inner_block(feature)
      last_inner = inner_lateral + inner_top_down
      results.insert(0, layer_block(last_inner))
    results.extend(self.top_block(results[-1]))
    return tuple(results)
class ResNetFPN:
  """ResNet body followed by an FPN neck with ``out_channels`` channels per level."""
  def __init__(self, resnet, out_channels=256):
    self.out_channels = out_channels
    self.body = resnet
    in_channels_stage2 = 256
    # One entry per FPN input level (256, 512, 1024, 2048). Four levels plus
    # the max-pool level match the five anchor strides and four pooler scales
    # used in this file.
    in_channels_list = [
      in_channels_stage2,
      in_channels_stage2 * 2,
      in_channels_stage2 * 4,
      in_channels_stage2 * 8,
    ]
    self.fpn = FPN(in_channels_list, out_channels)

  def __call__(self, x):
    # assumes self.body returns the per-stage feature maps - TODO confirm in models.resnet
    x = self.body(x)
    return self.fpn(x)
class AnchorGenerator:
  """Produce anchor boxes for every location of every feature map.

  Args:
    sizes: anchor sizes; with several strides, one entry per level.
    aspect_ratios: height/width ratios used at every level.
    anchor_strides: stride of each level's feature map.
    straddle_thresh: anchors crossing the image border by more than this many
      pixels get a 0 "visibility" field; a negative value marks every anchor
      visible.
  """
  def __init__(
    self,
    sizes=(32, 64, 128, 256, 512),
    aspect_ratios=(0.5, 1.0, 2.0),
    anchor_strides=(4, 8, 16, 32, 64),
    straddle_thresh=0,
  ):
    if len(anchor_strides) == 1:
      anchor_stride = anchor_strides[0]
      cell_anchors = [generate_anchors(anchor_stride, sizes, aspect_ratios)]
    else:
      if len(anchor_strides) != len(sizes):
        raise RuntimeError("FPN should have #anchor_strides == #sizes")
      cell_anchors = [
        generate_anchors(
          anchor_stride,
          size if isinstance(size, (tuple, list)) else (size,),
          aspect_ratios,
        )
        for anchor_stride, size in zip(anchor_strides, sizes)
      ]
    self.strides = anchor_strides
    self.cell_anchors = cell_anchors
    self.straddle_thresh = straddle_thresh

  def num_anchors_per_location(self):
    return [cell_anchors.shape[0] for cell_anchors in self.cell_anchors]

  def grid_anchors(self, grid_sizes):
    """Shift each level's cell anchors to every grid position; one (H*W*A, 4) Tensor per level."""
    anchors = []
    for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors):
      grid_height, grid_width = size
      device = base_anchors.device
      shifts_x = Tensor.arange(
        start=0, stop=grid_width * stride, step=stride, dtype=dtypes.float32, device=device
      )
      shifts_y = Tensor.arange(
        start=0, stop=grid_height * stride, step=stride, dtype=dtypes.float32, device=device
      )
      shift_y, shift_x = meshgrid(shifts_y, shifts_x)
      shift_x = shift_x.reshape(-1)
      shift_y = shift_y.reshape(-1)
      shifts = Tensor.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
      anchors.append(
        (shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4)
      )
    return anchors

  def add_visibility_to(self, boxlist):
    """Attach a "visibility" field marking anchors inside the image (within straddle_thresh)."""
    image_width, image_height = boxlist.size
    anchors = boxlist.bbox
    if self.straddle_thresh >= 0:
      inds_inside = (
        (anchors[:, 0] >= -self.straddle_thresh)
        * (anchors[:, 1] >= -self.straddle_thresh)
        * (anchors[:, 2] < image_width + self.straddle_thresh)
        * (anchors[:, 3] < image_height + self.straddle_thresh)
      )
    else:
      device = anchors.device
      inds_inside = Tensor.ones(anchors.shape[0], dtype=dtypes.uint8, device=device)
    boxlist.add_field("visibility", inds_inside)

  def __call__(self, image_list, feature_maps):
    """Return anchors[image][level] as xyxy BoxLists sized to each image."""
    grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
    anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
    anchors = []
    for (image_height, image_width) in image_list.image_sizes:
      anchors_in_image = []
      for anchors_per_feature_map in anchors_over_all_feature_maps:
        boxlist = BoxList(
          anchors_per_feature_map, (image_width, image_height), mode="xyxy"
        )
        self.add_visibility_to(boxlist)
        anchors_in_image.append(boxlist)
      anchors.append(anchors_in_image)
    return anchors
def generate_anchors(
  stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)
):
  """Cell anchors for one level: (len(aspect_ratios) * len(sizes), 4) xyxy boxes.

  Anchors are centred on the ``stride`` x ``stride`` reference cell. Sizes
  are given in pixels and turned into scales relative to ``stride``.
  """
  return _generate_anchors(stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios)))
def _generate_anchors(base_size, scales, aspect_ratios):
  """Enumerate aspect ratios, then scales, around the box [0, 0, base_size-1, base_size-1].

  Returns a (len(aspect_ratios) * len(scales), 4) Tensor, ratio-major.
  """
  anchor = Tensor([1, 1, base_size, base_size]) - 1
  anchors = _ratio_enum(anchor, aspect_ratios)
  return Tensor.cat(
    *[_scale_enum(anchors[i, :], scales).reshape(-1, 4) for i in range(anchors.shape[0])]
  )
def _whctrs(anchor):
w = anchor[2] - anchor[0] + 1
h = anchor[3] - anchor[1] + 1
x_ctr = anchor[0] + 0.5 * (w - 1)
y_ctr = anchor[1] + 0.5 * (h - 1)
return w, h, x_ctr, y_ctr
def _mkanchors(ws, hs, x_ctr, y_ctr):
  """Build xyxy anchors of widths ``ws`` and heights ``hs`` around one centre.

  ``ws``/``hs`` hold one entry per anchor; the result is (len(ws), 4).
  """
  half_w = 0.5 * (ws[:, None] - 1)
  half_h = 0.5 * (hs[:, None] - 1)
  return Tensor.cat(
    x_ctr - half_w,
    y_ctr - half_h,
    x_ctr + half_w,
    y_ctr + half_h,
    dim=1,
  )
def _ratio_enum(anchor, ratios):
  """One anchor per aspect ratio (height / width), keeping about the area of ``anchor``.

  Widths and heights are rounded with ``rint``; the centre is unchanged.
  """
  w, h, x_ctr, y_ctr = _whctrs(anchor)
  widths = rint(Tensor.sqrt((w * h) / ratios))
  heights = rint(widths * ratios)
  return _mkanchors(widths, heights, x_ctr, y_ctr)
def _scale_enum(anchor, scales):
  """One anchor per scale: ``anchor``'s width and height times each scale, same centre."""
  w, h, x_ctr, y_ctr = _whctrs(anchor)
  return _mkanchors(w * scales, h * scales, x_ctr, y_ctr)
class RPNHead:
  """RPN head shared by all levels.

  A 3x3 conv + ReLU feeds two 1x1 convs: per-anchor objectness logits and
  per-anchor box deltas (4 per anchor).
  """
  def __init__(self, in_channels, num_anchors):
    self.conv = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
    self.cls_logits = nn.Conv2d(256, num_anchors, kernel_size=1)
    self.bbox_pred = nn.Conv2d(256, num_anchors * 4, kernel_size=1)

  def __call__(self, x):
    """Return (logits, bbox_reg): two lists with one entry per input feature map."""
    logits = []
    bbox_reg = []
    for feature in x:
      t = Tensor.relu(self.conv(feature))
      logits.append(self.cls_logits(t))
      bbox_reg.append(self.bbox_pred(t))
    return logits, bbox_reg
class BoxCoder(object):
  """Convert between boxes and (dx, dy, dw, dh) regression deltas.

  ``weights`` = (wx, wy, ww, wh) scale the four deltas. ``bbox_xform_clip``
  caps dw/dh before they are exponentiated in ``decode``.
  """
  def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
    self.weights = weights
    self.bbox_xform_clip = bbox_xform_clip

  def encode(self, reference_boxes, proposals):
    """Deltas that move each ``proposals`` box onto the matching ``reference_boxes`` box.

    Both inputs are indexed as (N, 4) x1, y1, x2, y2 rows; the result is (N, 4).
    """
    TO_REMOVE = 1  # TODO remove

    def size_and_centre(b):
      widths = b[:, 2] - b[:, 0] + TO_REMOVE
      heights = b[:, 3] - b[:, 1] + TO_REMOVE
      return widths, heights, b[:, 0] + 0.5 * widths, b[:, 1] + 0.5 * heights

    ex_widths, ex_heights, ex_ctr_x, ex_ctr_y = size_and_centre(proposals)
    gt_widths, gt_heights, gt_ctr_x, gt_ctr_y = size_and_centre(reference_boxes)
    wx, wy, ww, wh = self.weights
    return Tensor.stack((
      wx * (gt_ctr_x - ex_ctr_x) / ex_widths,
      wy * (gt_ctr_y - ex_ctr_y) / ex_heights,
      ww * Tensor.log(gt_widths / ex_widths),
      wh * Tensor.log(gt_heights / ex_heights),
    ), dim=1)

  def decode(self, rel_codes, boxes):
    """Apply deltas ``rel_codes`` (N, 4*K) to ``boxes`` (N, 4); returns (N, 4*K) xyxy boxes."""
    boxes = boxes.cast(rel_codes.dtype)
    TO_REMOVE = 1  # TODO remove
    widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE
    heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights
    wx, wy, ww, wh = self.weights
    dx = rel_codes[:, 0::4] / wx
    dy = rel_codes[:, 1::4] / wy
    raw_dw = rel_codes[:, 2::4] / ww
    raw_dh = rel_codes[:, 3::4] / wh
    # Prevent sending too large values into Tensor.exp(). clip() takes both
    # bounds, so the lower one is each tensor's own minimum (a no-op).
    dw = raw_dw.clip(min_=raw_dw.min(), max_=self.bbox_xform_clip)
    dh = raw_dh.clip(min_=raw_dh.min(), max_=self.bbox_xform_clip)
    pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
    pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
    pred_w = dw.exp() * widths[:, None]
    pred_h = dh.exp() * heights[:, None]
    x1 = pred_ctr_x - 0.5 * pred_w
    y1 = pred_ctr_y - 0.5 * pred_h
    x2 = pred_ctr_x + 0.5 * pred_w - 1
    y2 = pred_ctr_y + 0.5 * pred_h - 1
    # (4, N, K) -> (N, K, 4) -> (N, 4*K): interleave coordinates per box
    stacked = Tensor.stack([x1, y1, x2, y2]).permute(1, 2, 0)
    return stacked.reshape(rel_codes.shape[0], rel_codes.shape[1])
def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
  """Non-maximum suppression over a BoxList.

  A non-positive ``nms_thresh`` disables suppression and returns ``boxlist``
  itself. Otherwise the indices returned by ``_box_nms`` (on the xyxy boxes
  and the ``score_field`` scores) select the kept boxes. They are truncated
  to ``max_proposals`` when that is positive. The result is converted back
  to the input's mode.
  """
  if nms_thresh <= 0:
    return boxlist
  original_mode = boxlist.mode
  xyxy = boxlist.convert("xyxy")
  keep = _box_nms(xyxy.bbox.numpy(), xyxy.get_field(score_field).numpy(), nms_thresh)
  if max_proposals > 0:
    keep = keep[:max_proposals]
  return xyxy[keep].convert(original_mode)
def remove_small_boxes(boxlist, min_size):
  """Drop boxes whose xywh width or height is below ``min_size``.

  Returns ``boxlist`` itself when every box is large enough.
  """
  _, _, ws, hs = boxlist.convert("xywh").bbox.chunk(4, dim=1)
  big_enough = (((ws >= min_size) * (hs >= min_size)) > 0).reshape(-1)
  if big_enough.sum().numpy() != len(boxlist):
    return boxlist[big_enough.numpy().nonzero()[0]]
  return boxlist
class RPNPostProcessor:
  """Turn RPN outputs into per-image proposal BoxLists (inference only).

  For each level:
    - keep the ``pre_nms_top_n`` anchors with the highest objectness;
    - decode their boxes and clip them to the image;
    - drop boxes smaller than ``min_size``;
    - run NMS at ``nms_thresh``, keeping ``post_nms_top_n``.
  Levels are then concatenated per image. With several levels, each image
  is trimmed to ``fpn_post_nms_top_n`` (defaults to ``post_nms_top_n``).
  """
  # Not used in Loss calculation
  def __init__(
    self,
    pre_nms_top_n,
    post_nms_top_n,
    nms_thresh,
    min_size,
    box_coder=None,
    fpn_post_nms_top_n=None,
  ):
    self.pre_nms_top_n = pre_nms_top_n
    self.post_nms_top_n = post_nms_top_n
    self.nms_thresh = nms_thresh
    self.min_size = min_size
    if box_coder is None:
      box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
    self.box_coder = box_coder
    if fpn_post_nms_top_n is None:
      fpn_post_nms_top_n = post_nms_top_n
    self.fpn_post_nms_top_n = fpn_post_nms_top_n

  def forward_for_single_feature_map(self, anchors, objectness, box_regression):
    """Proposals for one level: one BoxList (with an "objectness" field) per image.

    ``objectness`` is (N, A, H, W); ``box_regression`` holds 4 deltas per anchor.
    """
    N, A, H, W = objectness.shape
    objectness = permute_and_flatten(objectness, N, A, 1, H, W).reshape(N, -1)
    objectness = objectness.sigmoid()
    box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)

    num_anchors = A * H * W
    pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
    objectness, topk_idx = topk(objectness, pre_nms_top_n, dim=1, sorted=False)
    concat_anchors = Tensor.cat(*[a.bbox for a in anchors], dim=0).reshape(N, -1, 4)
    image_shapes = [box.size for box in anchors]

    box_regression_list = []
    concat_anchors_list = []
    for batch_idx in range(N):
      box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx]))
      concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx]))
    box_regression = Tensor.stack(box_regression_list)
    concat_anchors = Tensor.stack(concat_anchors_list)

    proposals = self.box_coder.decode(
      box_regression.reshape(-1, 4), concat_anchors.reshape(-1, 4)
    )
    proposals = proposals.reshape(N, -1, 4)

    result = []
    for proposal, score, im_shape in zip(proposals, objectness, image_shapes):
      boxlist = BoxList(proposal, im_shape, mode="xyxy")
      boxlist.add_field("objectness", score)
      boxlist = boxlist.clip_to_image(remove_empty=False)
      boxlist = remove_small_boxes(boxlist, self.min_size)
      boxlist = boxlist_nms(
        boxlist,
        self.nms_thresh,
        max_proposals=self.post_nms_top_n,
        score_field="objectness",
      )
      result.append(boxlist)
    return result

  def __call__(self, anchors, objectness, box_regression):
    """anchors[image][level] plus per-level head outputs -> one BoxList per image."""
    sampled_boxes = []
    num_levels = len(objectness)
    anchors = list(zip(*anchors))  # -> anchors[level][image]
    for a, o, b in zip(anchors, objectness, box_regression):
      sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))
    boxlists = list(zip(*sampled_boxes))  # -> boxlists[image][level]
    boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
    if num_levels > 1:
      boxlists = self.select_over_all_levels(boxlists)
    return boxlists

  def select_over_all_levels(self, boxlists):
    """Keep the ``fpn_post_nms_top_n`` most objective proposals of each image."""
    num_images = len(boxlists)
    for i in range(num_images):
      objectness = boxlists[i].get_field("objectness")
      post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0])
      _, inds_sorted = topk(objectness, post_nms_top_n, dim=0, sorted=False)
      boxlists[i] = boxlists[i][inds_sorted]
    return boxlists
class RPN:
  """Region Proposal Network, inference path: head -> anchors -> proposal selection.

  Calling it returns ``(proposals, {})`` where ``proposals`` holds one
  BoxList per image.
  """
  def __init__(self, in_channels):
    self.anchor_generator = AnchorGenerator()
    # in_channels must match the FPN output width (MaskRCNN passes backbone.out_channels)
    self.head = RPNHead(
      in_channels, self.anchor_generator.num_anchors_per_location()[0]
    )
    rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
    self.box_selector_test = RPNPostProcessor(
      pre_nms_top_n=1000,
      post_nms_top_n=1000,
      nms_thresh=0.7,
      min_size=0,
      box_coder=rpn_box_coder,
      fpn_post_nms_top_n=1000,
    )

  def __call__(self, images, features, targets=None):
    objectness, rpn_box_regression = self.head(features)
    anchors = self.anchor_generator(images, features)
    boxes = self.box_selector_test(anchors, objectness, rpn_box_regression)
    return boxes, {}
def make_conv3x3(
  in_channels,
  out_channels,
  dilation=1,
  stride=1,
  use_gn=False,
):
  """3x3 conv whose padding equals ``dilation``; the bias is omitted when ``use_gn`` is true."""
  conv = nn.Conv2d(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=stride,
    padding=dilation,
    dilation=dilation,
    bias=False if use_gn else True,
  )
  return conv
class MaskRCNNFPNFeatureExtractor:
  """Mask-head features: 14x14 multi-level RoIAlign, then four 3x3 conv + ReLU layers (256 channels)."""
  def __init__(self):
    resolution = 14
    scales = (0.25, 0.125, 0.0625, 0.03125)
    sampling_ratio = 2
    pooler = Pooler(
      output_size=(resolution, resolution),
      scales=scales,
      sampling_ratio=sampling_ratio,
    )
    input_size = 256
    self.pooler = pooler
    use_gn = False
    layers = (256, 256, 256, 256)
    dilation = 1
    self.mask_fcn1 = make_conv3x3(input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn)
    self.mask_fcn2 = make_conv3x3(layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn)
    self.mask_fcn3 = make_conv3x3(layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn)
    self.mask_fcn4 = make_conv3x3(layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn)
    self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4]

  def __call__(self, x, proposals):
    """Pool ``x`` for ``proposals`` and run the conv stack; None when the pooler returns None."""
    x = self.pooler(x, proposals)
    for layer in self.blocks:
      if x is not None:
        x = Tensor.relu(layer(x))
    return x
class MaskRCNNC4Predictor:
  """Mask predictor: transposed conv (kernel 2, stride 2) + ReLU, then a 1x1 conv.

  The 1x1 conv outputs one mask-logit map per class (81 classes).
  """
  def __init__(self):
    num_classes = 81
    dim_reduced = 256
    # the input width equals dim_reduced (256)
    self.conv5_mask = nn.ConvTranspose2d(dim_reduced, dim_reduced, 2, 2, 0)
    self.mask_fcn_logits = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)

  def __call__(self, x):
    upsampled = Tensor.relu(self.conv5_mask(x))
    return self.mask_fcn_logits(upsampled)
class FPN2MLPFeatureExtractor:
  """Box-head features: 7x7 multi-level RoIAlign, flattened, then two 1024-d Linear + ReLU layers.

  ``cfg`` is accepted but unused.
  """
  def __init__(self, cfg):
    resolution = 7
    scales = (0.25, 0.125, 0.0625, 0.03125)
    sampling_ratio = 2
    pooler = Pooler(
      output_size=(resolution, resolution),
      scales=scales,
      sampling_ratio=sampling_ratio,
    )
    input_size = 256 * resolution ** 2
    representation_size = 1024
    self.pooler = pooler
    self.fc6 = nn.Linear(input_size, representation_size)
    self.fc7 = nn.Linear(representation_size, representation_size)

  def __call__(self, x, proposals):
    x = self.pooler(x, proposals)
    x = x.reshape(x.shape[0], -1)
    x = Tensor.relu(self.fc6(x))
    x = Tensor.relu(self.fc7(x))
    return x
def _bilinear_interpolate(
  input,  # [N, C, H, W]
  roi_batch_ind,  # [K]
  y,  # [K, PH, IY]
  x,  # [K, PW, IX]
  ymask,  # [K, IY]
  xmask,  # [K, IX]
):
  """Bilinearly sample ``input`` at every (y, x) sample point of every RoI.

  Coordinates are clamped to the feature map. When ``ymask``/``xmask`` are
  given, masked-out coordinates are replaced by 0 before indexing, so the
  lookup stays in range. Returns values of shape [K, C, PH, PW, IY, IX].
  """
  _, channels, height, width = input.shape
  y = y.clip(min_=0.0, max_=float(height-1))
  x = x.clip(min_=0.0, max_=float(width-1))

  # Tensor.where doesnt work well with int32 data so cast to float32
  y_low = y.cast(dtypes.int32).contiguous().float().contiguous()
  x_low = x.cast(dtypes.int32).contiguous().float().contiguous()

  # at the last row/column, both neighbours collapse onto the border pixel
  y_high = Tensor.where(y_low >= height - 1, float(height - 1), y_low + 1)
  y_low = Tensor.where(y_low >= height - 1, float(height - 1), y_low)
  x_high = Tensor.where(x_low >= width - 1, float(width - 1), x_low + 1)
  x_low = Tensor.where(x_low >= width - 1, float(width - 1), x_low)

  ly = y - y_low
  lx = x - x_low
  hy = 1.0 - ly
  hx = 1.0 - lx

  def masked_index(
    y,  # [K, PH, IY]
    x,  # [K, PW, IX]
  ):
    if ymask is not None:
      assert xmask is not None
      y = Tensor.where(ymask[:, None, :], y, 0)
      x = Tensor.where(xmask[:, None, :], x, 0)
    key1 = roi_batch_ind[:, None, None, None, None, None]
    key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None]
    key3 = y[:, None, :, None, :, None]
    key4 = x[:, None, None, :, None, :]
    return tensor_getitem(input, key1, key2, key3, key4)  # [K, C, PH, PW, IY, IX]

  v1 = masked_index(y_low, x_low)
  v2 = masked_index(y_low, x_high)
  v3 = masked_index(y_high, x_low)
  v4 = masked_index(y_high, x_high)

  # all ws preemptively [K, C, PH, PW, IY, IX]
  def outer_prod(y, x):
    return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

  w1 = outer_prod(hy, hx)
  w2 = outer_prod(hy, lx)
  w3 = outer_prod(ly, hx)
  w4 = outer_prod(ly, lx)

  val = w1*v1 + w2*v2 + w3*v3 + w4*v4
  return val
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
  """RoIAlign: average of bilinear samples in each output bin of each RoI.

  Args:
    input: [N, C, H, W] feature maps.
    rois: [K, 5] rows of (batch_index, x1, y1, x2, y2) before scaling.
    spatial_scale: factor applied to the RoI coordinates to reach ``input``'s grid.
    pooled_height, pooled_width: number of output bins.
    sampling_ratio: samples per bin along each axis; <= 0 picks
      ceil(roi_size / pooled_size) per RoI.
    aligned: shift coordinates by -0.5. When False, RoI width/height are
      clamped to at least 1.

  Returns:
    [K, C, pooled_height, pooled_width], cast back to ``input.dtype``.
  """
  orig_dtype = input.dtype
  _, _, height, width = input.shape
  ph = Tensor.arange(pooled_height, device=input.device)
  pw = Tensor.arange(pooled_width, device=input.device)

  roi_batch_ind = rois[:, 0].cast(dtypes.int32).contiguous()
  offset = 0.5 if aligned else 0.0
  roi_start_w = rois[:, 1] * spatial_scale - offset
  roi_start_h = rois[:, 2] * spatial_scale - offset
  roi_end_w = rois[:, 3] * spatial_scale - offset
  roi_end_h = rois[:, 4] * spatial_scale - offset

  roi_width = roi_end_w - roi_start_w
  roi_height = roi_end_h - roi_start_h
  if not aligned:
    roi_width = roi_width.maximum(1.0)
    roi_height = roi_height.maximum(1.0)

  bin_size_h = roi_height / pooled_height
  bin_size_w = roi_width / pooled_width

  exact_sampling = sampling_ratio > 0
  roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
  roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()

  if exact_sampling:
    # same sample grid for every RoI -> plain Python count
    count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
    iy = Tensor.arange(roi_bin_grid_h, device=input.device)
    ix = Tensor.arange(roi_bin_grid_w, device=input.device)
    ymask = None
    xmask = None
  else:
    # per-RoI grid sizes: sample a full-size grid and mask the extra points
    count = (roi_bin_grid_h * roi_bin_grid_w).maximum(1)
    iy = Tensor.arange(height, device=input.device)
    ix = Tensor.arange(width, device=input.device)
    ymask = iy[None, :] < roi_bin_grid_h[:, None]
    xmask = ix[None, :] < roi_bin_grid_w[:, None]

  def from_K(t):
    return t[:, None, None]

  # sample coordinates: y is [K, PH, IY], x is [K, PW, IX]
  y = (
    from_K(roi_start_h)
    + ph[None, :, None] * from_K(bin_size_h)
    + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
  )
  x = (
    from_K(roi_start_w)
    + pw[None, :, None] * from_K(bin_size_w)
    + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
  )
  val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)

  if not exact_sampling:
    val = ymask[:, None, None, None, :, None].where(val, 0)
    val = xmask[:, None, None, None, None, :].where(val, 0)

  output = val.sum((-1, -2))
  if isinstance(count, Tensor):
    output /= count[:, None, None, None]
  else:
    output /= count

  return output.cast(orig_dtype)
class ROIAlign:
  """RoIAlign for one feature level.

  Wraps ``_roi_align`` with ``aligned=False``. ``rois`` are (batch_index,
  x1, y1, x2, y2) rows.
  """
  def __init__(self, output_size, spatial_scale, sampling_ratio):
    self.output_size = output_size
    self.spatial_scale = spatial_scale
    self.sampling_ratio = sampling_ratio

  def __call__(self, input, rois):
    output = _roi_align(
      input, rois, self.spatial_scale, self.output_size[0], self.output_size[1], self.sampling_ratio, aligned=False
    )
    return output
class LevelMapper:
  """Pick an FPN level for each box from its size.

  level = floor(lvl0 + log2(sqrt(area) / s0 + eps)). The level is clamped
  to [k_min, k_max] and returned relative to k_min (0-based).
  """
  def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
    self.k_min = k_min
    self.k_max = k_max
    self.s0 = canonical_scale
    self.lvl0 = canonical_level
    self.eps = eps

  def __call__(self, boxlists):
    """One level index per box of all ``boxlists`` (concatenated in order)."""
    areas = Tensor.cat(*[boxlist.area() for boxlist in boxlists])
    box_scale = Tensor.sqrt(areas)
    raw_levels = (self.lvl0 + Tensor.log2(box_scale / self.s0 + self.eps)).floor()
    return raw_levels.clip(min_=self.k_min, max_=self.k_max) - self.k_min
class Pooler:
  """Multi-level RoIAlign: route every box to one feature level and pool it there.

  ``scales`` are the levels' spatial scales (0.25, 0.125, ...). Level
  numbers are ``-log2(scale)``, which LevelMapper uses for routing. Output
  rows follow the order of the concatenated input boxes.
  """
  def __init__(self, output_size, scales, sampling_ratio):
    self.output_size = output_size
    self.scales = scales
    self.sampling_ratio = sampling_ratio
    self.poolers = [
      ROIAlign(output_size, spatial_scale=scale, sampling_ratio=sampling_ratio)
      for scale in scales
    ]
    lvl_min = -math.log2(scales[0])
    lvl_max = -math.log2(scales[-1])
    self.map_levels = LevelMapper(lvl_min, lvl_max)

  def convert_to_roi_format(self, boxes):
    """Stack every image's boxes into (K, 5) rows [image_index, x1, y1, x2, y2].

    Returns None when there are no boxes at all.
    """
    concat_boxes = Tensor.cat(*[b.bbox for b in boxes], dim=0)
    device, dtype = concat_boxes.device, concat_boxes.dtype
    ids = Tensor.cat(
      *[Tensor.full((len(b), 1), i, dtype=dtype, device=device) for i, b in enumerate(boxes)],
      dim=0,
    )
    if concat_boxes.shape[0] == 0:
      return None
    return Tensor.cat(ids, concat_boxes, dim=1)

  def __call__(self, x, boxes):
    """Pool ``x`` (one feature map per level) for all ``boxes``; None if there are no boxes."""
    rois = self.convert_to_roi_format(boxes)
    if rois is None:
      return None
    if len(self.poolers) == 1:
      return self.poolers[0](x[0], rois)
    # this is fine because no grad will flow through index
    levels = self.map_levels(boxes).numpy()
    results, all_idxs = [], []
    for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
      idx_in_level = (levels == level).nonzero()[0]
      if len(idx_in_level) > 0:
        rois_per_level = tensor_gather(rois, idx_in_level)
        results.append(pooler(per_level_feature, rois_per_level))
        all_idxs.extend(idx_in_level)
    # results are grouped by level; argsort of the original box indices puts
    # the rows back into input order
    return tensor_gather(Tensor.cat(*results), np.argsort(all_idxs))
class FPNPredictor:
  """Box predictor: class scores and per-class box deltas (81 classes) from 1024-d RoI features."""
  def __init__(self):
    num_classes = 81
    representation_size = 1024
    self.cls_score = nn.Linear(representation_size, num_classes)
    self.bbox_pred = nn.Linear(representation_size, num_classes * 4)

  def __call__(self, x):
    """Return (scores, bbox_deltas)."""
    return self.cls_score(x), self.bbox_pred(x)
class PostProcessor:
  """Box-head post-processing (inference only).

  The pipeline:
    - softmax the class logits and decode per-class boxes;
    - clip the boxes to the image;
    - for every class except 0 (presumably background), keep boxes scoring
      above ``score_thresh`` and run NMS at ``nms``;
    - finally keep the ``detections_per_img`` best detections, plus ties at
      the cut-off score.
  """
  # Not used in training
  def __init__(
    self,
    score_thresh=0.05,
    nms=0.5,
    detections_per_img=100,
    box_coder=None,
    cls_agnostic_bbox_reg=False,
  ):
    self.score_thresh = score_thresh
    self.nms = nms
    self.detections_per_img = detections_per_img
    if box_coder is None:
      box_coder = BoxCoder(weights=(10., 10., 5., 5.))
    self.box_coder = box_coder
    self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg

  def __call__(self, x, boxes):
    """``x`` = (class_logits, box_regression); ``boxes`` = proposal BoxLists. Returns detection BoxLists."""
    class_logits, box_regression = x
    class_prob = Tensor.softmax(class_logits, -1)
    image_shapes = [box.size for box in boxes]
    boxes_per_image = [len(box) for box in boxes]
    concat_boxes = Tensor.cat(*[a.bbox for a in boxes], dim=0)
    if self.cls_agnostic_bbox_reg:
      box_regression = box_regression[:, -4:]
    proposals = self.box_coder.decode(
      box_regression.reshape(sum(boxes_per_image), -1), concat_boxes
    )
    if self.cls_agnostic_bbox_reg:
      proposals = proposals.repeat([1, class_prob.shape[1]])
    num_classes = class_prob.shape[1]
    # NOTE(review): a leading axis of size 1 means all proposals are paired
    # with the first image only -> looks like a batch-size-1 assumption.
    proposals = proposals.unsqueeze(0)
    class_prob = class_prob.unsqueeze(0)
    results = []
    for prob, boxes_per_img, image_shape in zip(class_prob, proposals, image_shapes):
      boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape)
      boxlist = boxlist.clip_to_image(remove_empty=False)
      boxlist = self.filter_results(boxlist, num_classes)
      results.append(boxlist)
    return results

  def prepare_boxlist(self, boxes, scores, image_shape):
    """Flatten per-class boxes/scores into a single xyxy BoxList with a "scores" field."""
    boxes = boxes.reshape(-1, 4)
    scores = scores.reshape(-1)
    boxlist = BoxList(boxes, image_shape, mode="xyxy")
    boxlist.add_field("scores", scores)
    return boxlist

  def filter_results(self, boxlist, num_classes):
    """Threshold, per-class NMS and top-k; returns one BoxList with "scores" and "labels"."""
    boxes = boxlist.bbox.reshape(-1, num_classes * 4)
    scores = boxlist.get_field("scores").reshape(-1, num_classes)
    device = scores.device
    result = []
    scores = scores.numpy()
    boxes = boxes.numpy()
    inds_all = scores > self.score_thresh
    for j in range(1, num_classes):
      inds = inds_all[:, j].nonzero()[0]
      # This needs to be done in numpy because it can create empty arrays
      scores_j = scores[inds, j]
      boxes_j = boxes[inds, j * 4: (j + 1) * 4]
      boxes_j = Tensor(boxes_j)
      scores_j = Tensor(scores_j)
      boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
      boxlist_for_class.add_field("scores", scores_j)
      if len(boxlist_for_class):
        boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms)
      num_labels = len(boxlist_for_class)
      boxlist_for_class.add_field(
        "labels", Tensor.full((num_labels,), j, device=device)
      )
      result.append(boxlist_for_class)

    result = cat_boxlist(result)
    number_of_detections = len(result)
    if number_of_detections > self.detections_per_img > 0:
      cls_scores = result.get_field("scores")
      image_thresh, _ = topk(cls_scores, k=self.detections_per_img)
      # topk values are unordered: their minimum is the k-th highest score
      image_thresh = image_thresh.numpy().min()
      keep = (cls_scores.numpy() >= image_thresh).nonzero()[0]
      result = result[keep]
    return result
class RoIBoxHead:
  """Box head: RoI features -> class scores / box deltas -> post-processed detections."""
  def __init__(self, in_channels):
    self.feature_extractor = FPN2MLPFeatureExtractor(in_channels)
    self.predictor = FPNPredictor()
    self.post_processor = PostProcessor(
      score_thresh=0.05,
      nms=0.5,
      detections_per_img=100,
      box_coder=BoxCoder(weights=(10., 10., 5., 5.)),
      cls_agnostic_bbox_reg=False,
    )

  def __call__(self, features, proposals, targets=None):
    """Return (box features, detection BoxLists, {}). Inference only."""
    x = self.feature_extractor(features, proposals)
    class_logits, box_regression = self.predictor(x)
    if not Tensor.training:
      result = self.post_processor((class_logits, box_regression), proposals)
      return x, result, {}
    # Only the inference path exists; fail loudly instead of returning None.
    raise NotImplementedError("RoIBoxHead only supports inference (Tensor.training is set)")
class MaskPostProcessor:
  """Attach each detection's mask probabilities to copies of the detection BoxLists."""
  # Not used in loss calculation
  def __call__(self, x, boxes):
    """``x``: mask logits, one row per detection across all images; ``boxes``: detections with "labels".

    Returns new BoxLists (same boxes and fields) with an extra "mask" field.
    """
    mask_prob = x.sigmoid().numpy()
    num_masks = x.shape[0]
    labels = [bbox.get_field("labels") for bbox in boxes]
    labels = Tensor.cat(*labels).numpy().astype(np.int32)
    index = np.arange(num_masks)
    # keep only the channel of each detection's own class, re-adding a
    # channel axis of size 1
    mask_prob = mask_prob[index, labels][:, None]
    # running box counts mark where each image's masks end
    boxes_per_image, cumsum = [], 0
    for box in boxes:
      cumsum += len(box)
      boxes_per_image.append(cumsum)
    # using numpy here as Tensor.chunk doesnt have custom chunk sizes;
    # the last cumulative count is the total, so it is not a split point
    mask_prob = np.split(mask_prob, boxes_per_image[:-1], axis=0)
    results = []
    for prob, box in zip(mask_prob, boxes):
      bbox = BoxList(box.bbox, box.size, mode="xyxy")
      for field in box.fields():
        bbox.add_field(field, box.get_field(field))
      bbox.add_field("mask", Tensor(prob))
      results.append(bbox)
    return results
class Mask:
  """Mask head: pooled features -> mask logits -> per-detection mask probabilities."""
  def __init__(self):
    self.feature_extractor = MaskRCNNFPNFeatureExtractor()
    self.predictor = MaskRCNNC4Predictor()
    self.post_processor = MaskPostProcessor()

  def __call__(self, features, proposals, targets=None):
    """Return (mask features, detections with "mask", {}) at inference.

    Returns (x, [], {}) when there is nothing to process or in training.
    """
    x = self.feature_extractor(features, proposals)
    # The extractor yields None when there are no boxes. Test for that
    # explicitly (as the extractor itself does) instead of relying on Tensor
    # truthiness.
    if x is not None:
      mask_logits = self.predictor(x)
      if not Tensor.training:
        result = self.post_processor(mask_logits, proposals)
        return x, result, {}
    return x, [], {}
class RoIHeads:
  """Box head followed by mask head; the mask head runs on the box head's detections."""
  def __init__(self, in_channels):
    self.box = RoIBoxHead(in_channels)
    self.mask = Mask()

  def __call__(self, features, proposals, targets=None):
    """Return (mask features, final detections, {})."""
    _, detections, _ = self.box(features, proposals, targets)
    mask_features, detections, _ = self.mask(features, detections, targets)
    return mask_features, detections, {}
class ImageList(object):
  """A batched image tensor together with each image's size before padding."""
  def __init__(self, tensors, image_sizes):
    self.tensors = tensors
    self.image_sizes = image_sizes

  def to(self, *args, **kwargs):
    """Return a new ImageList holding ``self.tensors.to(*args, **kwargs)`` and the same sizes."""
    moved = self.tensors.to(*args, **kwargs)
    return ImageList(moved, self.image_sizes)
def to_image_list(tensors, size_divisible=32):
  """Batch images into an ImageList, zero-padding them to a common size.

  Accepted inputs:
    - ImageList: returned unchanged.
    - list/tuple of 3-D image tensors (presumably C, H, W). A single Tensor
      is wrapped in a list when ``size_divisible > 0``. Images are padded at
      the end of every axis up to the per-axis maximum. When
      ``size_divisible > 0``, axes 1 and 2 are first rounded up to a
      multiple of it.
    - a 4-D Tensor when ``size_divisible <= 0``: used as the batch directly.

  Raises:
    TypeError: for any other input type.
  """
  # Preprocessing
  if isinstance(tensors, Tensor) and size_divisible > 0:
    tensors = [tensors]
  if isinstance(tensors, ImageList):
    return tensors
  if isinstance(tensors, Tensor):
    # single tensor shape can be inferred
    assert tensors.ndim == 4
    return ImageList(tensors, [tensor.shape[-2:] for tensor in tensors])
  if not isinstance(tensors, (tuple, list)):
    raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors)))

  max_size = [max(dims) for dims in zip(*[img.shape for img in tensors])]
  if size_divisible > 0:
    for axis in (1, 2):
      max_size[axis] = int(math.ceil(max_size[axis] / size_divisible) * size_divisible)
  batched_imgs = np.zeros((len(tensors), *max_size), dtype=tensors[0].dtype.np)
  for img, pad_img in zip(tensors, batched_imgs):
    pad_img[: img.shape[0], : img.shape[1], : img.shape[2]] += img.numpy()
  return ImageList(Tensor(batched_imgs), [im.shape[-2:] for im in tensors])
class MaskRCNN:
  """Mask R-CNN with a ResNet-FPN backbone, inference only.

  Calling the model returns the RoI heads' detections: BoxLists whose fields
  include "scores", "labels" and, when masks were computed, "mask".
  """
  def __init__(self, backbone: ResNet):
    self.backbone = ResNetFPN(backbone, out_channels=256)
    self.rpn = RPN(self.backbone.out_channels)
    self.roi_heads = RoIHeads(self.backbone.out_channels)

  def load_from_pretrained(self):
    """Download the R-50-FPN 1x Mask R-CNN checkpoint and copy its weights into this model.

    Checkpoint keys are renamed to this module's attribute paths:
      - "module." and "stem." are dropped;
      - fpn_innerN becomes inner_blocks.{N-1};
      - fpn_layerN becomes layer_blocks.{N-1}.
    Returns the renamed keys that were loaded.
    """
    fn = Path('./') / "weights/maskrcnn.pt"
    download_file("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn)
    state_dict = torch_load(fn)['model']
    loaded_keys = []
    for k, v in state_dict.items():
      if "module." in k:
        k = k.replace("module.", "")
      if "stem." in k:
        k = k.replace("stem.", "")
      if "fpn_inner" in k:
        block_index = int(re.search(r"fpn_inner(\d+)", k).group(1))
        k = re.sub(r"fpn_inner\d+", f"inner_blocks.{block_index - 1}", k)
      if "fpn_layer" in k:
        block_index = int(re.search(r"fpn_layer(\d+)", k).group(1))
        k = re.sub(r"fpn_layer\d+", f"layer_blocks.{block_index - 1}", k)
      get_child(self, k).assign(v.numpy()).realize()
      loaded_keys.append(k)
    return loaded_keys

  def __call__(self, images):
    images = to_image_list(images)
    features = self.backbone(images.tensors)
    proposals, _ = self.rpn(images, features)
    x, result, _ = self.roi_heads(features, proposals)
    return result
if __name__ == '__main__':
  # Build a ResNet-50 backbone and wrap it in Mask R-CNN; no weights are loaded here.
  resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
  model = MaskRCNN(backbone=resnet)