1
0
Fork 0

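Don't track running stats (pass track_running_stats=False through EfficientNet to BatchNorm2D); get_parameters only collects tensors with requires_grad set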

pull/353/head
George Hotz 2022-07-02 14:38:45 -07:00
parent 07b438aa8b
commit 8cf1aed0f4
4 changed files with 11 additions and 10 deletions

View File

@ -27,7 +27,7 @@ BACKWARD = int(os.getenv("BACKWARD", 0))
if __name__ == "__main__":
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
model = EfficientNet(NUM, classes=1000, has_se=False)
model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False)
parameters = get_parameters(model)
optimizer = optim.Adam(parameters, lr=0.001)

View File

@ -22,7 +22,7 @@ def fetch(url):
def get_parameters(obj):
parameters = []
if isinstance(obj, Tensor):
parameters.append(obj)
if obj.requires_grad: parameters.append(obj)
elif isinstance(obj, list) or isinstance(obj, tuple):
for x in obj:
parameters.extend(get_parameters(x))

View File

@ -5,11 +5,11 @@ from tinygrad.nn import BatchNorm2D
from extra.utils import fetch, fake_torch_load, get_child
class MBConvBlock:
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se):
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
oup = expand_ratio * input_filters
if expand_ratio != 1:
self._expand_conv = Tensor.uniform(oup, input_filters, 1, 1)
self._bn0 = BatchNorm2D(oup)
self._bn0 = BatchNorm2D(oup, track_running_stats=track_running_stats)
else:
self._expand_conv = None
@ -20,7 +20,7 @@ class MBConvBlock:
self.pad = [(kernel_size-1)//2]*4
self._depthwise_conv = Tensor.uniform(oup, 1, kernel_size, kernel_size)
self._bn1 = BatchNorm2D(oup)
self._bn1 = BatchNorm2D(oup, track_running_stats=track_running_stats)
self.has_se = has_se
if self.has_se:
@ -31,7 +31,7 @@ class MBConvBlock:
self._se_expand_bias = Tensor.zeros(oup)
self._project_conv = Tensor.uniform(output_filters, oup, 1, 1)
self._bn2 = BatchNorm2D(output_filters)
self._bn2 = BatchNorm2D(output_filters, track_running_stats=track_running_stats)
def __call__(self, inputs):
x = inputs
@ -52,7 +52,7 @@ class MBConvBlock:
return x
class EfficientNet:
def __init__(self, number=0, classes=1000, has_se=True):
def __init__(self, number=0, classes=1000, has_se=True, track_running_stats=True):
self.number = number
global_params = [
# width, depth
@ -82,7 +82,7 @@ class EfficientNet:
out_channels = round_filters(32)
self._conv_stem = Tensor.uniform(out_channels, 3, 3, 3)
self._bn0 = BatchNorm2D(out_channels)
self._bn0 = BatchNorm2D(out_channels, track_running_stats=track_running_stats)
blocks_args = [
[1, 3, (1,1), 1, 32, 16, 0.25],
[2, 3, (2,2), 6, 16, 24, 0.25],
@ -97,14 +97,14 @@ class EfficientNet:
for num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio in blocks_args:
input_filters, output_filters = round_filters(input_filters), round_filters(output_filters)
for n in range(round_repeats(num_repeats)):
self._blocks.append(MBConvBlock(kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se=has_se))
self._blocks.append(MBConvBlock(kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se=has_se, track_running_stats=track_running_stats))
input_filters = output_filters
strides = (1,1)
in_channels = round_filters(320)
out_channels = round_filters(1280)
self._conv_head = Tensor.uniform(out_channels, in_channels, 1, 1)
self._bn1 = BatchNorm2D(out_channels)
self._bn1 = BatchNorm2D(out_channels, track_running_stats=track_running_stats)
self._fc = Tensor.uniform(out_channels, classes)
self._fc_bias = Tensor.zeros(classes)

View File

@ -22,6 +22,7 @@ class BatchNorm2D:
y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1]))
batch_var = (y*y).mean(axis=(0,2,3))
# NOTE: wow, this is done all throughout training in most PyTorch models
if self.track_running_stats:
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var