fix enet (EfficientNet) bugs; output is now mousetrap
parent
a852143572
commit
68cba88e8f
|
@ -3,6 +3,7 @@
|
|||
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
|
||||
# a rough copy of
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
||||
import sys
|
||||
import io
|
||||
import numpy as np
|
||||
np.set_printoptions(suppress=True)
|
||||
|
@ -37,7 +38,8 @@ class MBConvBlock:
|
|||
self._project_conv = Tensor.zeros(output_filters, oup, 1, 1)
|
||||
self._bn2 = BatchNorm2D(output_filters)
|
||||
|
||||
def __call__(self, x):
|
||||
def __call__(self, inputs):
|
||||
x = inputs
|
||||
if self._expand_conv:
|
||||
x = swish(self._bn0(x.conv2d(self._expand_conv)))
|
||||
x = x.pad2d(padding=(self.pad, self.pad, self.pad, self.pad))
|
||||
|
@ -51,6 +53,8 @@ class MBConvBlock:
|
|||
x = x.mul(x_squeezed.sigmoid())
|
||||
|
||||
x = self._bn2(x.conv2d(self._project_conv))
|
||||
if x.shape == inputs.shape:
|
||||
x = x.add(inputs)
|
||||
return swish(x)
|
||||
|
||||
class EfficientNet:
|
||||
|
@ -82,13 +86,14 @@ class EfficientNet:
|
|||
def forward(self, x):
|
||||
x = x.pad2d(padding=(0,1,0,1))
|
||||
x = swish(self._bn0(x.conv2d(self._conv_stem, stride=2)))
|
||||
for b in self._blocks:
|
||||
for block in self._blocks:
|
||||
print(x.shape)
|
||||
x = b(x)
|
||||
x = block(x)
|
||||
x = swish(self._bn1(x.conv2d(self._conv_head)))
|
||||
x = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, 1280))
|
||||
x = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||
x = x.reshape(shape=(-1, 1280))
|
||||
#x = x.dropout(0.2)
|
||||
return swish(x.dot(self._fc).add(self._fc_bias))
|
||||
return x.dot(self._fc).add(self._fc_bias)
|
||||
|
||||
def load_weights_from_torch(self):
|
||||
# load b0
|
||||
|
@ -116,9 +121,13 @@ if __name__ == "__main__":
|
|||
model = EfficientNet()
|
||||
model.load_weights_from_torch()
|
||||
|
||||
# load cat image and preprocess
|
||||
# load image and preprocess
|
||||
from PIL import Image
|
||||
img = Image.open(io.BytesIO(fetch("https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg")))
|
||||
if len(sys.argv) > 1:
|
||||
url = sys.argv[1]
|
||||
else:
|
||||
url = "https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg"
|
||||
img = Image.open(io.BytesIO(fetch(url)))
|
||||
img = img.resize((398, 224))
|
||||
img = np.array(img)
|
||||
img = img[:, 87:-87]
|
||||
|
@ -143,6 +152,15 @@ if __name__ == "__main__":
|
|||
import time
|
||||
st = time.time()
|
||||
out = model.forward(Tensor(img))
|
||||
|
||||
# if you want to look at the outputs
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
plt.plot(out.data[0])
|
||||
plt.show()
|
||||
"""
|
||||
|
||||
print("did inference in %.2f s" % (time.time()-st))
|
||||
print(np.argmax(out.data), np.max(out.data), lbls[np.argmax(out.data)])
|
||||
#print("NOT", np.argmin(out.data), np.min(out.data), lbls[np.argmin(out.data)])
|
||||
|
||||
|
|
Loading…
Reference in New Issue