don't track_running_stats, parameters must require_grad
parent 07b438aa8b
commit 8cf1aed0f4
@@ -27,7 +27,7 @@ BACKWARD = int(os.getenv("BACKWARD", 0))
 if __name__ == "__main__":
   print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
-  model = EfficientNet(NUM, classes=1000, has_se=False)
+  model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False)
   parameters = get_parameters(model)
   optimizer = optim.Adam(parameters, lr=0.001)
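With track_running_stats=False, the BatchNorm2D layers normalize with per-batch statistics only, so a from-scratch training run never pays for the running-average updates. A rough sketch of the intended training step, assuming the tinygrad API of this era; X, Y, and the loss construction are illustrative placeholders:

from tinygrad.tensor import Tensor  # module path as of this era of tinygrad

X, Y = ..., ...                          # a batch of images and encoded labels, not shown
out = model.forward(Tensor(X))           # BN normalizes with batch statistics only
loss = out.logsoftmax().mul(Tensor(Y)).mean()  # crude NLL-style loss, illustrative
loss.backward()
optimizer.step()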
@@ -22,7 +22,7 @@ def fetch(url):
 def get_parameters(obj):
   parameters = []
   if isinstance(obj, Tensor):
-    parameters.append(obj)
+    if obj.requires_grad: parameters.append(obj)
   elif isinstance(obj, list) or isinstance(obj, tuple):
     for x in obj:
       parameters.extend(get_parameters(x))
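This is the "parameters must require_grad" half of the commit: tensors that are not trainable (for example buffers like the BatchNorm running statistics, assuming they are created with requires_grad=False) no longer end up in the list handed to the optimizer. A self-contained toy illustration of the filter, not tinygrad code:

class T:  # stand-in for Tensor
  def __init__(self, requires_grad): self.requires_grad = requires_grad

def get_parameters(obj):
  parameters = []
  if isinstance(obj, T):
    if obj.requires_grad: parameters.append(obj)
  elif isinstance(obj, (list, tuple)):
    for x in obj: parameters.extend(get_parameters(x))
  return parameters

weight, running_mean = T(True), T(False)
assert get_parameters([weight, [running_mean]]) == [weight]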
@@ -5,11 +5,11 @@ from tinygrad.nn import BatchNorm2D
 from extra.utils import fetch, fake_torch_load, get_child

 class MBConvBlock:
-  def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se):
+  def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
     oup = expand_ratio * input_filters
     if expand_ratio != 1:
       self._expand_conv = Tensor.uniform(oup, input_filters, 1, 1)
-      self._bn0 = BatchNorm2D(oup)
+      self._bn0 = BatchNorm2D(oup, track_running_stats=track_running_stats)
     else:
       self._expand_conv = None

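The default of track_running_stats=True keeps existing callers unchanged; only the training script above opts out. A hypothetical construction of the changed block, with argument values that are illustrative only:

block = MBConvBlock(kernel_size=3, strides=(1,1), expand_ratio=6,
                    input_filters=16, output_filters=24, se_ratio=0.25,
                    has_se=False, track_running_stats=False)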
@@ -20,7 +20,7 @@ class MBConvBlock:
     self.pad = [(kernel_size-1)//2]*4

     self._depthwise_conv = Tensor.uniform(oup, 1, kernel_size, kernel_size)
-    self._bn1 = BatchNorm2D(oup)
+    self._bn1 = BatchNorm2D(oup, track_running_stats=track_running_stats)

     self.has_se = has_se
     if self.has_se:
@@ -31,7 +31,7 @@ class MBConvBlock:
       self._se_expand_bias = Tensor.zeros(oup)

     self._project_conv = Tensor.uniform(output_filters, oup, 1, 1)
-    self._bn2 = BatchNorm2D(output_filters)
+    self._bn2 = BatchNorm2D(output_filters, track_running_stats=track_running_stats)

   def __call__(self, inputs):
     x = inputs
@@ -52,7 +52,7 @@ class MBConvBlock:
     return x

 class EfficientNet:
-  def __init__(self, number=0, classes=1000, has_se=True):
+  def __init__(self, number=0, classes=1000, has_se=True, track_running_stats=True):
     self.number = number
     global_params = [
       # width, depth
@@ -82,7 +82,7 @@ class EfficientNet:

     out_channels = round_filters(32)
     self._conv_stem = Tensor.uniform(out_channels, 3, 3, 3)
-    self._bn0 = BatchNorm2D(out_channels)
+    self._bn0 = BatchNorm2D(out_channels, track_running_stats=track_running_stats)
     blocks_args = [
       [1, 3, (1,1), 1, 32, 16, 0.25],
       [2, 3, (2,2), 6, 16, 24, 0.25],
@@ -97,14 +97,14 @@ class EfficientNet:
     for num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio in blocks_args:
       input_filters, output_filters = round_filters(input_filters), round_filters(output_filters)
       for n in range(round_repeats(num_repeats)):
-        self._blocks.append(MBConvBlock(kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se=has_se))
+        self._blocks.append(MBConvBlock(kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se=has_se, track_running_stats=track_running_stats))
         input_filters = output_filters
         strides = (1,1)

     in_channels = round_filters(320)
     out_channels = round_filters(1280)
     self._conv_head = Tensor.uniform(out_channels, in_channels, 1, 1)
-    self._bn1 = BatchNorm2D(out_channels)
+    self._bn1 = BatchNorm2D(out_channels, track_running_stats=track_running_stats)
     self._fc = Tensor.uniform(out_channels, classes)
     self._fc_bias = Tensor.zeros(classes)

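A detail visible in this hunk, unrelated to the flag itself: within each stage, only the first repeat uses the configured stride; later repeats run at stride (1,1) with input_filters matching output_filters, so only the first block downsamples. A toy trace of the repeat loop, with illustrative values:

def expand_stage(num_repeats, strides, input_filters, output_filters):
  blocks = []
  for n in range(num_repeats):
    blocks.append((strides, input_filters, output_filters))
    input_filters = output_filters   # later repeats keep the channel count
    strides = (1,1)                  # only the first repeat downsamples
  return blocks

assert expand_stage(2, (2,2), 16, 24) == [((2,2), 16, 24), ((1,1), 24, 24)]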
@@ -22,6 +22,7 @@ class BatchNorm2D:
       y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1]))
       batch_var = (y*y).mean(axis=(0,2,3))

       # NOTE: wow, this is done all throughout training in most PyTorch models
+      if self.track_running_stats:
         self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
         self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
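The guarded update is an exponential moving average in which momentum weights the new batch statistic (the PyTorch convention), so momentum=0.1 keeps 90% of the old running value each step. A standalone numpy sketch of the same update, independent of tinygrad:

import numpy as np

momentum, running_mean = 0.1, np.zeros(3)
for _ in range(100):
  batch = np.random.randn(16, 3)    # per-step batch; true mean ~ 0
  batch_mean = batch.mean(axis=0)
  running_mean = (1 - momentum) * running_mean + momentum * batch_mean
# running_mean converges toward the long-run batch mean (~0 here)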