remove numpy from pooler
parent
70be81aebc
commit
98f2b1fa2e
|
@ -828,7 +828,6 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling
|
|||
bin_size_w = roi_width / pooled_width
|
||||
|
||||
exact_sampling = sampling_ratio > 0
|
||||
|
||||
roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
|
||||
roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()
|
||||
|
||||
|
@ -907,7 +906,6 @@ class LevelMapper:
|
|||
self.eps = eps
|
||||
|
||||
def __call__(self, boxlists):
|
||||
# TODO: remove numpy
|
||||
s = Tensor.sqrt(Tensor.cat(*[boxlist.area() for boxlist in boxlists]))
|
||||
target_lvls = (self.lvl0 + Tensor.log2(s / self.s0 + self.eps)).floor()
|
||||
target_lvls = target_lvls.clip(min_=self.k_min, max_=self.k_max)
|
||||
|
@ -954,21 +952,18 @@ class Pooler:
|
|||
return self.poolers[0](x[0], rois)
|
||||
|
||||
levels = self.map_levels(boxes)
|
||||
|
||||
num_rois = rois.shape[0]
|
||||
num_channels = x[0].shape[1]
|
||||
output_size = self.output_size[0]
|
||||
|
||||
result = np.zeros(
|
||||
(num_rois, num_channels, output_size, output_size), dtype=x[0].dtype.np
|
||||
)
|
||||
results = []
|
||||
all_idxs = []
|
||||
for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
|
||||
# this is fine because no grad will flow through index
|
||||
idx_in_level = [idx for idx, x in enumerate((levels.numpy() == level)) if x != 0]
|
||||
if len(idx_in_level) > 0:
|
||||
rois_per_level = tensor_gather(rois, idx_in_level)
|
||||
result[idx_in_level] = pooler(per_level_feature, rois_per_level).numpy()
|
||||
pooler_output = pooler(per_level_feature, rois_per_level)
|
||||
all_idxs.extend(idx_in_level)
|
||||
results.append(pooler_output)
|
||||
|
||||
return Tensor(result, dtype=x[0].dtype, device=x[0].device)
|
||||
return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i:idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])])
|
||||
|
||||
|
||||
class FPNPredictor:
|
||||
|
|
|
@ -515,7 +515,7 @@ class Tensor:
|
|||
def abs(self): return self.relu() + (-self).relu()
|
||||
def sign(self): return self / (self.abs() + 1e-10)
|
||||
def reciprocal(self): return 1.0/self
|
||||
def floor(self): i = self.cast(dtypes.int32).realize(); cond=i > self; return cond * (i - 1) + (1.0 - cond) * i
|
||||
def floor(self): i = self.cast(dtypes.int32); return (self>0).where(i, i-1)
|
||||
def ceil(self): return -1 * (-1 * self).floor()
|
||||
|
||||
# ***** activation functions (unary) *****
|
||||
|
|
Loading…
Reference in New Issue