1
0
Fork 0

fix enet bugs, now is mousetrap

pull/40/head
George Hotz 2020-10-31 10:28:07 -07:00
parent a852143572
commit 68cba88e8f
1 changed files with 25 additions and 7 deletions

View File

@ -3,6 +3,7 @@
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
# a rough copy of
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
import sys
import io
import numpy as np
np.set_printoptions(suppress=True)
@ -37,7 +38,8 @@ class MBConvBlock:
self._project_conv = Tensor.zeros(output_filters, oup, 1, 1)
self._bn2 = BatchNorm2D(output_filters)
def __call__(self, x):
def __call__(self, inputs):
x = inputs
if self._expand_conv:
x = swish(self._bn0(x.conv2d(self._expand_conv)))
x = x.pad2d(padding=(self.pad, self.pad, self.pad, self.pad))
@ -51,6 +53,8 @@ class MBConvBlock:
x = x.mul(x_squeezed.sigmoid())
x = self._bn2(x.conv2d(self._project_conv))
if x.shape == inputs.shape:
x = x.add(inputs)
return swish(x)
class EfficientNet:
@ -82,13 +86,14 @@ class EfficientNet:
def forward(self, x):
x = x.pad2d(padding=(0,1,0,1))
x = swish(self._bn0(x.conv2d(self._conv_stem, stride=2)))
for b in self._blocks:
for block in self._blocks:
print(x.shape)
x = b(x)
x = block(x)
x = swish(self._bn1(x.conv2d(self._conv_head)))
x = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, 1280))
x = x.avg_pool2d(kernel_size=x.shape[2:4])
x = x.reshape(shape=(-1, 1280))
#x = x.dropout(0.2)
return swish(x.dot(self._fc).add(self._fc_bias))
return x.dot(self._fc).add(self._fc_bias)
def load_weights_from_torch(self):
# load b0
@ -116,9 +121,13 @@ if __name__ == "__main__":
model = EfficientNet()
model.load_weights_from_torch()
# load cat image and preprocess
# load image and preprocess
from PIL import Image
img = Image.open(io.BytesIO(fetch("https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg")))
if len(sys.argv) > 1:
url = sys.argv[1]
else:
url = "https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg"
img = Image.open(io.BytesIO(fetch(url)))
img = img.resize((398, 224))
img = np.array(img)
img = img[:, 87:-87]
@ -143,6 +152,15 @@ if __name__ == "__main__":
import time
st = time.time()
out = model.forward(Tensor(img))
# if you want to look at the outputs
"""
import matplotlib.pyplot as plt
plt.plot(out.data[0])
plt.show()
"""
print("did inference in %.2f s" % (time.time()-st))
print(np.argmax(out.data), np.max(out.data), lbls[np.argmax(out.data)])
#print("NOT", np.argmin(out.data), np.min(out.data), lbls[np.argmin(out.data)])