1
0
Fork 0

enet work

pull/28/head
George Hotz 2020-10-27 21:23:02 -07:00
parent 0ec279951f
commit 03d9c98f5b
2 changed files with 46 additions and 4 deletions

View File

@ -10,9 +10,11 @@ def swish(x):
return x.mul(x.sigmoid())
class BatchNorm2D:
def __init__(self, sz):
def __init__(self, sz, eps=0.001):
self.eps = eps
self.weight = Tensor.zeros(sz)
self.bias = Tensor.zeros(sz)
# TODO: need running_mean and running_var
self.running_mean = Tensor.zeros(sz)
self.running_var = Tensor.zeros(sz)
@ -20,7 +22,9 @@ class BatchNorm2D:
def __call__(self, x):
# this work at inference?
x = x.sub(self.running_mean.reshape(shape=[1, -1, 1, 1]))
x = x.mul(self.weight.reshape(shape=[1, -1, 1, 1]))
x = x.div(self.running_var.add(Tensor([self.eps])).reshape(shape=[1, -1, 1, 1]).sqrt())
x = x.add(self.bias.reshape(shape=[1, -1, 1, 1]))
return x
@ -102,6 +106,8 @@ class EfficientNet:
return swish(x.dot(self._fc).add(self._fc_bias))
if __name__ == "__main__":
import numpy as np
np.set_printoptions(suppress=True)
# instantiate my net
model = EfficientNet()
@ -114,7 +120,7 @@ if __name__ == "__main__":
if '_blocks.' in k:
k = "%s[%s].%s" % tuple(k.split(".", 2))
mk = "model."+k
print(k, v.shape)
#print(k, v.shape)
try:
mv = eval(mk)
except AttributeError:
@ -125,6 +131,7 @@ if __name__ == "__main__":
mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T
#b0 = pickle.loads(b0)
out = model.forward(Tensor.zeros(1, 3, 224, 224))
print(out)
img = np.zeros((1, 3, 224, 224), np.float32) + 0.5
out = model.forward(Tensor(img))
print(out.data[:, 0:10])

View File

@ -24,6 +24,41 @@ class Add(Function):
def backward(ctx, grad_output):
return grad_output, grad_output
register('add', Add)
class Sub(Function):
@staticmethod
def forward(ctx, x, y):
return x-y
@staticmethod
def backward(ctx, grad_output):
# this right?
return grad_output, -grad_output
register('sub', Sub)
class Div(Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return x/y
@staticmethod
def backward(ctx, grad_output):
# this right?
x,y = ctx.saved_tensors
return y/grad_output, x/grad_output
register('div', Div)
class Sqrt(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return np.sqrt(x)
@staticmethod
def backward(ctx, grad_output):
raise Exception("write this")
register('sqrt', Sqrt)
class Dot(Function):
@staticmethod