1
0
Fork 0

adding evaluation for multiple k

funding
camille garcin 2021-06-10 14:39:22 +02:00
parent 81fa56e64a
commit 37b412a699
4 changed files with 46 additions and 31 deletions

View File

@ -16,7 +16,7 @@ If you have installed anaconda, you can run the following command :
In order to train a model on the PlantNet-300K dataset, run the following command :
```python main.py --lr=0.05 --n_epochs=80 --model=resnet50 --root=path_to_data --save_name_xp=xp1```
```python main.py --lr=0.05 --n_epochs=80 --k 1 3 5 10 --model=resnet50 --root=path_to_data --save_name_xp=xp1```
You must provide in the "root" option the path to the train val and test folders.
The "save_name_xp" option is the name of the directory where the weights of the model and the results (metrics) will be stored.

6
cli.py
View File

@ -13,8 +13,6 @@ def add_all_parsers(parser):
def _add_loss_parser(parser):
group_loss = parser.add_argument_group('Loss parameters')
group_loss.add_argument('--k', type=int, help='value of k for computing the topk loss and calculating topk accuracy',
default=3)
group_loss.add_argument('--mu', type=float, default=0., help='weight decay parameter')
@ -23,6 +21,8 @@ def _add_training_parser(parser):
group_training.add_argument('--lr', type=float, help='learning rate to use')
group_training.add_argument('--batch_size', type=int, default=256, help='default is 256')
group_training.add_argument('--n_epochs', type=int)
group_training.add_argument('--k', nargs='+', help='value of k for computing the topk loss and calculating topk accuracy',
required=True, type=int)
def _add_model_parser(parser):
@ -38,7 +38,7 @@ def _add_hardware_parser(parser):
def _add_dataset_parser(parser):
group_dataset = parser.add_argument_group('Dataset parameters')
group_dataset.add_argument('--size_image', type=int, default=128,
group_dataset.add_argument('--size_image', type=int, default=256,
help='size you want to resize the images to')

View File

@ -2,13 +2,15 @@ import torch
from tqdm import tqdm
from utils import count_correct_top_k, count_correct_average_k
import torch.nn.functional as F
from collections import defaultdict
def train_epoch(model, optimizer, train_loader, criteria, loss_train, train_accuracy, topk_train_accuracy, k, n_train, use_gpu):
def train_epoch(model, optimizer, train_loader, criteria, loss_train, train_accuracy, topk_train_accuracy, list_k, n_train, use_gpu):
model.train()
loss_epoch_train = 0
n_correct_train = 0
n_correct_top_k_train = 0
n_correct_top_k_train = defaultdict(int)
epoch_top_k_accuracy_train = {}
for batch_idx, (batch_x_train, batch_y_train) in enumerate(tqdm(train_loader, desc='train', position=0)):
if use_gpu:
batch_x_train, batch_y_train = batch_x_train.cuda(), batch_y_train.cuda()
@ -21,13 +23,15 @@ def train_epoch(model, optimizer, train_loader, criteria, loss_train, train_accu
optimizer.step()
with torch.no_grad():
n_correct_train += torch.sum(torch.eq(batch_y_train, torch.argmax(batch_output_train, dim=-1))).item()
n_correct_top_k_train += count_correct_top_k(scores=batch_output_train, labels=batch_y_train,
k=k).item()
for k in list_k:
n_correct_top_k_train[k] += count_correct_top_k(scores=batch_output_train, labels=batch_y_train,
k=k).item()
# At the end of epoch compute average of statistics over batches and store them
with torch.no_grad():
loss_epoch_train /= batch_idx
epoch_accuracy_train = n_correct_train / n_train
epoch_top_k_accuracy_train = n_correct_top_k_train / n_train
for k in list_k:
epoch_top_k_accuracy_train[k] = n_correct_top_k_train[k] / n_train
loss_train.append(loss_epoch_train), train_accuracy.append(epoch_accuracy_train), topk_train_accuracy.append(
epoch_top_k_accuracy_train)
@ -35,14 +39,17 @@ def train_epoch(model, optimizer, train_loader, criteria, loss_train, train_accu
return loss_epoch_train, epoch_accuracy_train, epoch_top_k_accuracy_train
def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accuracy, averagek_val_accuracy, k,
def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accuracy, averagek_val_accuracy, list_k,
dataset_attributes, use_gpu):
model.eval()
with torch.no_grad():
loss_epoch_val = 0
n_correct_val = 0
n_correct_top_k_val = 0
n_correct_top_k_val = defaultdict(int)
epoch_top_k_accuracy_val, epoch_average_k_accuracy_val, lmbda_val = {}, {}, {}
n_correct_average_k_val = defaultdict(int)
list_val_proba = []
list_val_labels = []
for batch_idx, (batch_x_val, batch_y_val) in enumerate(tqdm(val_loader, desc='val', position=0)):
@ -57,21 +64,24 @@ def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accu
loss_epoch_val += loss_batch_val.item()
n_correct_val += torch.sum(torch.eq(batch_y_val, torch.argmax(batch_output_val, dim=-1))).item()
n_correct_top_k_val += count_correct_top_k(scores=batch_output_val, labels=batch_y_val, k=k).item()
for k in list_k:
n_correct_top_k_val[k] += count_correct_top_k(scores=batch_output_val, labels=batch_y_val, k=k).item()
val_probas = torch.cat(list_val_proba)
val_labels = torch.cat(list_val_labels)
flat_val_probas = torch.flatten(val_probas)
sorted_probas, _ = torch.sort(flat_val_probas, descending=True)
lmbda_val = 0.5 * (sorted_probas[dataset_attributes['n_val'] * k - 1] + sorted_probas[dataset_attributes['n_val'] * k])
n_correct_average_k_val = count_correct_average_k(probas=val_probas, labels=val_labels, lmbda=lmbda_val).item()
for k in list_k:
lmbda_val[k] = 0.5 * (sorted_probas[dataset_attributes['n_val'] * k - 1] + sorted_probas[dataset_attributes['n_val'] * k])
n_correct_average_k_val[k] += count_correct_average_k(probas=val_probas, labels=val_labels, lmbda=lmbda_val[k]).item()
# After seeing val update the statistics over batches and store them
loss_epoch_val /= batch_idx
epoch_accuracy_val = n_correct_val / dataset_attributes['n_val']
epoch_top_k_accuracy_val = n_correct_top_k_val / dataset_attributes['n_val']
epoch_average_k_accuracy_val = n_correct_average_k_val / dataset_attributes['n_val']
for k in list_k:
epoch_top_k_accuracy_val[k] = n_correct_top_k_val[k] / dataset_attributes['n_val']
epoch_average_k_accuracy_val[k] = n_correct_average_k_val[k] / dataset_attributes['n_val']
loss_val.append(loss_epoch_val), val_accuracy.append(epoch_accuracy_val), topk_val_accuracy.append(
epoch_top_k_accuracy_val), averagek_val_accuracy.append(epoch_average_k_accuracy_val)
@ -79,15 +89,15 @@ def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accu
return loss_epoch_val, epoch_accuracy_val, epoch_top_k_accuracy_val, epoch_average_k_accuracy_val, lmbda_val
def test_epoch(model, test_loader, criteria, k, lmbda, use_gpu, n_test):
def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, n_test):
print()
model.eval()
with torch.no_grad():
loss_epoch_test = 0
n_correct_test = 0
n_correct_top_k_test = 0
n_correct_average_k_test = 0
epoch_top_k_accuracy_test, epoch_average_k_accuracy_test = {}, {}
n_correct_top_k_test, n_correct_average_k_test = defaultdict(int), defaultdict(int)
for batch_idx, (batch_x_test, batch_y_test) in enumerate(tqdm(test_loader, desc='test', position=0)):
if use_gpu:
batch_x_test, batch_y_test = batch_x_test.cuda(), batch_y_test.cuda()
@ -97,14 +107,16 @@ def test_epoch(model, test_loader, criteria, k, lmbda, use_gpu, n_test):
loss_epoch_test += loss_batch_test.item()
n_correct_test += torch.sum(torch.eq(batch_y_test, torch.argmax(batch_output_test, dim=-1))).item()
n_correct_top_k_test += count_correct_top_k(scores=batch_output_test, labels=batch_y_test, k=k).item()
n_correct_average_k_test += count_correct_average_k(probas=batch_ouput_probra_test, labels=batch_y_test,
lmbda=lmbda).item()
for k in list_k:
n_correct_top_k_test[k] += count_correct_top_k(scores=batch_output_test, labels=batch_y_test, k=k).item()
n_correct_average_k_test[k] += count_correct_average_k(probas=batch_ouput_probra_test, labels=batch_y_test,
lmbda=lmbda[k]).item()
# After seeing test test update the statistics over batches and store them
loss_epoch_test /= batch_idx
epoch_accuracy_test = n_correct_test / n_test
epoch_top_k_accuracy_test = n_correct_top_k_test / n_test
epoch_average_k_accuracy_test = n_correct_average_k_test / n_test
for k in list_k:
epoch_top_k_accuracy_test[k] = n_correct_top_k_test[k] / n_test
epoch_average_k_accuracy_test[k] = n_correct_average_k_test[k] / n_test
return loss_epoch_test, epoch_accuracy_test, epoch_top_k_accuracy_test, epoch_average_k_accuracy_test

17
main.py
View File

@ -40,6 +40,8 @@ def train(args):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
print('args.k : ', args.k)
lmbda_best_acc = None
for epoch in tqdm(range(args.n_epochs), desc='epoch', position=0):
@ -72,8 +74,9 @@ def train(args):
print(f'epoch {epoch} took {time.time()-t:.2f}')
print(f'loss_epoch_train : {loss_epoch_train}')
print(f'loss_epoch_val : {loss_epoch_val}')
print(f'train accuracy : {epoch_accuracy_train} / train top_{args.k} accuracy : {epoch_top_k_accuracy_train}')
print(f'val accuracy : {epoch_accuracy_val} / val top_{args.k} accuracy : {epoch_top_k_accuracy_val} / val average_{args.k} accuracy : {epoch_average_k_accuracy_val}')
print(f'train accuracy : {epoch_accuracy_train} / train top_k accuracy : {epoch_top_k_accuracy_train}')
print(f'val accuracy : {epoch_accuracy_val} / val top_k accuracy : {epoch_top_k_accuracy_val} / '
f'val average_k accuracy : {epoch_average_k_accuracy_val}')
# load weights corresponding to best val accuracy and evaluate on test
load_model(model, os.path.join(save_dir, save_name + '_weights_best_acc.tar'), args.use_gpu)
@ -87,11 +90,11 @@ def train(args):
results = {'loss_train': loss_train, 'train_accuracy': train_accuracy, 'topk_train_accuracy': topk_train_accuracy,
'loss_val': loss_val, 'val_accuracy': val_accuracy, 'topk_val_accuracy': topk_val_accuracy,
'average_k_val_accuracy': average_k_val_accuracy,
'test_results': {'best_val_accuracy': {'loss': loss_test_ba,
'accuracy': accuracy_test_ba,
'topk-accuracy': top_k_accuracy_test_ba,
'averagek-accuracy': average_k_accuracy_test_ba}
}, 'params': args.__dict__}
'test_results': {'loss': loss_test_ba,
'accuracy': accuracy_test_ba,
'topk-accuracy': top_k_accuracy_test_ba,
'averagek-accuracy': average_k_accuracy_test_ba},
'params': args.__dict__}
with open(os.path.join(save_dir, save_name + '.pkl'), 'wb') as f:
pickle.dump(results, f)