do test_conv_with_bn test
parent
5495c7d64e
commit
623fb1ef28
|
@ -19,7 +19,6 @@ def fetch_cifar(train=True):
|
||||||
cifar10_std = np.array([0.24703225141799082, 0.24348516474564, 0.26158783926049628], dtype=np.float32).reshape(1,3,1,1)
|
cifar10_std = np.array([0.24703225141799082, 0.24348516474564, 0.26158783926049628], dtype=np.float32).reshape(1,3,1,1)
|
||||||
tt = tarfile.open(fileobj=io.BytesIO(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')), mode='r:gz')
|
tt = tarfile.open(fileobj=io.BytesIO(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')), mode='r:gz')
|
||||||
if train:
|
if train:
|
||||||
# TODO: data_batch 2-5
|
|
||||||
db = [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)]
|
db = [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)]
|
||||||
else:
|
else:
|
||||||
db = [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")]
|
db = [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")]
|
||||||
|
|
|
@ -94,13 +94,12 @@ class TestMNIST(unittest.TestCase):
|
||||||
train(model, X_train, Y_train, optimizer, steps=100)
|
train(model, X_train, Y_train, optimizer, steps=100)
|
||||||
assert evaluate(model, X_test, Y_test) > 0.94 # torch gets 0.9415 sometimes
|
assert evaluate(model, X_test, Y_test) > 0.94 # torch gets 0.9415 sometimes
|
||||||
|
|
||||||
@unittest.skip("slow and training batchnorm is broken")
|
|
||||||
def test_conv_with_bn(self):
|
def test_conv_with_bn(self):
|
||||||
np.random.seed(1337)
|
np.random.seed(1337)
|
||||||
model = TinyConvNet(has_batchnorm=True)
|
model = TinyConvNet(has_batchnorm=True)
|
||||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
optimizer = optim.AdamW(model.parameters(), lr=0.003)
|
||||||
train(model, X_train, Y_train, optimizer, steps=100)
|
train(model, X_train, Y_train, optimizer, steps=200)
|
||||||
assert evaluate(model, X_test, Y_test) > 0.7 # TODO: batchnorm doesn't work!!!
|
assert evaluate(model, X_test, Y_test) > 0.94
|
||||||
|
|
||||||
def test_sgd(self):
|
def test_sgd(self):
|
||||||
np.random.seed(1337)
|
np.random.seed(1337)
|
||||||
|
|
Loading…
Reference in New Issue