1
0
Fork 0

remove numpy from pooler

pull/1072/head
Kunwar Raj Singh 2023-06-10 23:56:21 +05:30
parent 70be81aebc
commit 98f2b1fa2e
2 changed files with 8 additions and 13 deletions

View File

@ -828,7 +828,6 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling
bin_size_w = roi_width / pooled_width
exact_sampling = sampling_ratio > 0
roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()
@ -907,7 +906,6 @@ class LevelMapper:
self.eps = eps
def __call__(self, boxlists):
# TODO: remove numpy
s = Tensor.sqrt(Tensor.cat(*[boxlist.area() for boxlist in boxlists]))
target_lvls = (self.lvl0 + Tensor.log2(s / self.s0 + self.eps)).floor()
target_lvls = target_lvls.clip(min_=self.k_min, max_=self.k_max)
@ -954,21 +952,18 @@ class Pooler:
return self.poolers[0](x[0], rois)
levels = self.map_levels(boxes)
num_rois = rois.shape[0]
num_channels = x[0].shape[1]
output_size = self.output_size[0]
result = np.zeros(
(num_rois, num_channels, output_size, output_size), dtype=x[0].dtype.np
)
results = []
all_idxs = []
for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
# this is fine because no grad will flow through index
idx_in_level = [idx for idx, x in enumerate((levels.numpy() == level)) if x != 0]
if len(idx_in_level) > 0:
rois_per_level = tensor_gather(rois, idx_in_level)
result[idx_in_level] = pooler(per_level_feature, rois_per_level).numpy()
pooler_output = pooler(per_level_feature, rois_per_level)
all_idxs.extend(idx_in_level)
results.append(pooler_output)
return Tensor(result, dtype=x[0].dtype, device=x[0].device)
return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i:idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])])
class FPNPredictor:

View File

@ -515,7 +515,7 @@ class Tensor:
def abs(self): return self.relu() + (-self).relu()
def sign(self): return self / (self.abs() + 1e-10)
def reciprocal(self): return 1.0/self
def floor(self): i = self.cast(dtypes.int32).realize(); cond=i > self; return cond * (i - 1) + (1.0 - cond) * i
def floor(self): i = self.cast(dtypes.int32); return (self>0).where(i, i-1)
def ceil(self): return -1 * (-1 * self).floor()
# ***** activation functions (unary) *****