
fixing class metrics

funding
camille garcin 2021-07-06 15:03:55 +02:00
parent a04765857b
commit 8e9bacb099
3 changed files with 17 additions and 13 deletions


@@ -151,16 +151,17 @@ def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_attributes):
             if use_gpu:
                 batch_x_test, batch_y_test = batch_x_test.cuda(), batch_y_test.cuda()
             batch_output_test = model(batch_x_test)
-            batch_ouput_probra_test = F.softmax(batch_output_test)
+            batch_proba_test = F.softmax(batch_output_test)
             loss_batch_test = criteria(batch_output_test, batch_y_test)
             loss_epoch_test += loss_batch_test.item()
             n_correct_test += torch.sum(torch.eq(batch_y_test, torch.argmax(batch_output_test, dim=-1))).item()
+            update_correct_per_class(batch_proba_test, batch_y_test, class_acc_dict['class_acc'])
             for k in list_k:
                 n_correct_topk_test[k] += count_correct_topk(scores=batch_output_test, labels=batch_y_test, k=k).item()
-                n_correct_avgk_test[k] += count_correct_avgk(probas=batch_ouput_probra_test, labels=batch_y_test, lmbda=lmbda[k]).item()
+                n_correct_avgk_test[k] += count_correct_avgk(probas=batch_proba_test, labels=batch_y_test, lmbda=lmbda[k]).item()
                 update_correct_per_class_topk(batch_output_test, batch_y_test, class_acc_dict['class_topk_acc'][k], k)
-                update_correct_per_class_avgk(batch_ouput_probra_test, batch_y_test, class_acc_dict['class_avgk_acc'][k], lmbda[k])
+                update_correct_per_class_avgk(batch_proba_test, batch_y_test, class_acc_dict['class_avgk_acc'][k], lmbda[k])
         # After seeing test set update the statistics over batches and store them
         loss_epoch_test /= batch_idx
@@ -170,11 +171,10 @@ def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_attributes):
             avgk_acc_epoch_test[k] = n_correct_avgk_test[k] / n_test
         for class_id in class_acc_dict['class_acc'].keys():
-            n_class_test = dataset_attributes['class2num_instances']['test'][
-                class_id]
+            n_class_test = dataset_attributes['class2num_instances']['test'][class_id]
             class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / n_class_test
             for k in list_k:
                 class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / n_class_test
                 class_acc_dict['class_avgk_acc'][k][class_id] = class_acc_dict['class_avgk_acc'][k][class_id] / n_class_test
-    return loss_epoch_test, acc_epoch_test, topk_acc_epoch_test, avgk_acc_epoch_test
+    return loss_epoch_test, acc_epoch_test, topk_acc_epoch_test, avgk_acc_epoch_test, class_acc_dict

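The per-class bookkeeping above relies on helpers such as update_correct_per_class, update_correct_per_class_topk and update_correct_per_class_avgk. Below is a minimal sketch of the simplest of these, assuming class_acc_dict['class_acc'] is a plain dict mapping each class_id to a running count of correct top-1 predictions, pre-initialized to zero; the helper actually defined in this repository may be implemented differently.

import torch

def update_correct_per_class(probas, labels, class_correct):
    # Hypothetical sketch, not the repository's implementation: for every
    # sample in the batch, add 1 to its ground-truth class entry when the
    # top-1 prediction matches the label.
    preds = torch.argmax(probas, dim=-1)
    for label, pred in zip(labels.tolist(), preds.tolist()):
        class_correct[label] += int(label == pred)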
main.py

@@ -74,19 +74,21 @@ def train(args):
     # load weights corresponding to best val accuracy and evaluate on test
     load_model(model, os.path.join(save_dir, save_name + '_weights_best_acc.tar'), args.use_gpu)
-    loss_test_ba, acc_test_ba, topk_acc_test_ba, avgk_acc_test_ba = test_epoch(model, test_loader, criteria, args.k,
-                                                                               lmbda_best_acc, args.use_gpu,
-                                                                               dataset_attributes)
+    loss_test_ba, acc_test_ba, topk_acc_test_ba, \
+    avgk_acc_test_ba, class_acc_test = test_epoch(model, test_loader, criteria, args.k,
+                                                  lmbda_best_acc, args.use_gpu,
+                                                  dataset_attributes)
     # Save the results as a dictionary and save it as a pickle file in desired location
     results = {'loss_train': loss_train, 'acc_train': acc_train, 'topk_acc_train': topk_acc_train,
-               'loss_val': loss_val, 'acc_val': acc_val, 'topk_acc_val': topk_acc_val,
+               'loss_val': loss_val, 'acc_val': acc_val, 'topk_acc_val': topk_acc_val, 'class_acc_val': class_acc_val,
               'avgk_acc_val': avgk_acc_val,
               'test_results': {'loss': loss_test_ba,
                                 'accuracy': acc_test_ba,
                                 'topk_accuracy': topk_acc_test_ba,
-                                'avgk_accuracy': avgk_acc_test_ba},
+                                'avgk_accuracy': avgk_acc_test_ba,
+                                'class_acc_dict': class_acc_test},
               'params': args.__dict__}
     with open(os.path.join(save_dir, save_name + '.pkl'), 'wb') as f:

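With this change, the per-class metrics are stored alongside the aggregate scores in the pickled results dictionary. A usage sketch for reading them back, where the file name below is a placeholder (the real path is built from save_dir and save_name above):

import pickle

# Placeholder path; the actual file is os.path.join(save_dir, save_name + '.pkl').
with open('results/run.pkl', 'rb') as f:
    results = pickle.load(f)

class_acc = results['test_results']['class_acc_dict']['class_acc']
# For example, list the ten classes with the lowest test accuracy.
worst = sorted(class_acc.items(), key=lambda kv: kv[1])[:10]
print(worst)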

@@ -151,6 +151,7 @@ def get_data(args):
                                               shuffle=True, num_workers=args.num_workers)
     testset = Plantnet(args.root, 'test', transform=transform)
+    test_class_to_num_instances = Counter(testset.targets)
     testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.num_workers)
@@ -158,9 +159,10 @@ def get_data(args):
     n_classes = len(trainset.classes)
     dataset_attributes = {'n_train': len(trainset), 'n_val': len(valset), 'n_test': len(testset), 'n_classes': n_classes,
-                          'lr_schedule': [],
+                          'lr_schedule': [40, 50, 60],
                           'class2num_instances': {'train': train_class_to_num_instances,
-                                                  'val': val_class_to_num_instances},
+                                                  'val': val_class_to_num_instances,
+                                                  'test': test_class_to_num_instances},
                           'class_to_idx': trainset.class_to_idx}
     return trainloader, valloader, testloader, dataset_attributes
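The new Counter over testset.targets is what populates the 'test' entry of class2num_instances, which test_epoch divides by to turn per-class correct counts into per-class accuracies. A toy illustration with made-up labels:

from collections import Counter

targets = [0, 0, 1, 2, 2, 2]           # made-up integer class labels
class_to_num_instances = Counter(targets)
print(class_to_num_instances)          # Counter({2: 3, 0: 2, 1: 1})
# In test_epoch, each per-class correct count is divided by these totals,
# so classes with few test images are evaluated on their own scale.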