
renaming variables and adding comments

funding
camille garcin 2021-06-30 14:55:52 +02:00
parent 7609839662
commit a04765857b
3 changed files with 109 additions and 109 deletions

epoch.py (137 changed lines)

@@ -1,18 +1,22 @@
import torch
from tqdm import tqdm
from utils import count_correct_top_k, count_correct_average_k, update_correct_per_class, \
update_correct_per_class_topk, update_correct_per_class_averagek
from utils import count_correct_topk, count_correct_avgk, update_correct_per_class, \
update_correct_per_class_topk, update_correct_per_class_avgk
import torch.nn.functional as F
from collections import defaultdict
def train_epoch(model, optimizer, train_loader, criteria, loss_train, train_accuracy, topk_train_accuracy, list_k, n_train, use_gpu):
def train_epoch(model, optimizer, train_loader, criteria, loss_train, acc_train, topk_acc_train, list_k, n_train, use_gpu):
"""Single train epoch pass. At the end of the epoch, updates the lists loss_train, acc_train and topk_acc_train"""
model.train()
# Initialize variables
loss_epoch_train = 0
n_correct_train = 0
n_correct_top_k_train = defaultdict(int)
epoch_top_k_accuracy_train = {}
# Containers for tracking nb of correctly classified examples (in the top-k sense) and top-k accuracy for each k in list_k
n_correct_topk_train = defaultdict(int)
topk_acc_epoch_train = {}
for batch_idx, (batch_x_train, batch_y_train) in enumerate(tqdm(train_loader, desc='train', position=0)):
if use_gpu:
batch_x_train, batch_y_train = batch_x_train.cuda(), batch_y_train.cuda()
@@ -23,40 +27,49 @@ def train_epoch(model, optimizer, train_loader, criteria, loss_train, train_accu
loss_epoch_train += loss_batch_train.item()
loss_batch_train.backward()
optimizer.step()
# Update variables
with torch.no_grad():
n_correct_train += torch.sum(torch.eq(batch_y_train, torch.argmax(batch_output_train, dim=-1))).item()
for k in list_k:
n_correct_top_k_train[k] += count_correct_top_k(scores=batch_output_train, labels=batch_y_train, k=k).item()
n_correct_topk_train[k] += count_correct_topk(scores=batch_output_train, labels=batch_y_train, k=k).item()
# At the end of the epoch compute the average of the statistics over batches and store them
with torch.no_grad():
loss_epoch_train /= batch_idx
epoch_accuracy_train = n_correct_train / n_train
for k in list_k:
epoch_top_k_accuracy_train[k] = n_correct_top_k_train[k] / n_train
topk_acc_epoch_train[k] = n_correct_topk_train[k] / n_train
loss_train.append(loss_epoch_train), train_accuracy.append(epoch_accuracy_train), topk_train_accuracy.append(epoch_top_k_accuracy_train)
loss_train.append(loss_epoch_train)
acc_train.append(epoch_accuracy_train)
topk_acc_train.append(topk_acc_epoch_train)
return loss_epoch_train, epoch_accuracy_train, epoch_top_k_accuracy_train
return loss_epoch_train, epoch_accuracy_train, topk_acc_epoch_train
def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accuracy, averagek_val_accuracy,
class_accuracies_val, list_k, dataset_attributes, use_gpu):
def val_epoch(model, val_loader, criteria, loss_val, acc_val, topk_acc_val, avgk_acc_val,
class_acc_val, list_k, dataset_attributes, use_gpu):
"""Single val epoch pass.
At the end of the epoch, updates the lists loss_val, acc_val, topk_acc_val and avgk_acc_val"""
model.eval()
with torch.no_grad():
n_val = dataset_attributes['n_val']
# Initialization of variables
loss_epoch_val = 0
n_correct_val = 0
n_correct_top_k_val = defaultdict(int)
n_correct_average_k_val = defaultdict(int)
epoch_top_k_accuracy_val, epoch_average_k_accuracy_val, lmbda_val = {}, {}, {}
n_correct_topk_val, n_correct_avgk_val = defaultdict(int), defaultdict(int)
topk_acc_epoch_val, avgk_acc_epoch_val = {}, {}
# Store avg-k threshold for every k in list_k
lmbda_val = {}
# Store class accuracy, and top-k and average-k class accuracy for every k in list_k
class_acc_dict = {}
class_acc_dict['class_acc'] = defaultdict(int)
class_acc_dict['class_topk_acc'], class_acc_dict['class_averagek_acc'] = {}, {}
class_acc_dict['class_topk_acc'], class_acc_dict['class_avgk_acc'] = {}, {}
for k in list_k:
class_acc_dict['class_topk_acc'][k], class_acc_dict['class_averagek_acc'][k] = defaultdict(int), defaultdict(int)
class_acc_dict['class_topk_acc'][k], class_acc_dict['class_avgk_acc'][k] = defaultdict(int), defaultdict(int)
# Store estimated probas and labels of the whole validation set to compute lambda
list_val_proba = []
list_val_labels = []
for batch_idx, (batch_x_val, batch_y_val) in enumerate(tqdm(val_loader, desc='val', position=0)):
@@ -64,6 +77,7 @@ def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accu
batch_x_val, batch_y_val = batch_x_val.cuda(), batch_y_val.cuda()
batch_output_val = model(batch_x_val)
batch_proba = F.softmax(batch_output_val)
# Store batch probas and labels
list_val_proba.append(batch_proba)
list_val_labels.append(batch_y_val)
@@ -72,43 +86,48 @@ def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accu
n_correct_val += torch.sum(torch.eq(batch_y_val, torch.argmax(batch_output_val, dim=-1))).item()
update_correct_per_class(batch_proba, batch_y_val, class_acc_dict['class_acc'])
# Update top-k count and top-k count for each class
for k in list_k:
n_correct_top_k_val[k] += count_correct_top_k(scores=batch_output_val, labels=batch_y_val, k=k).item()
n_correct_topk_val[k] += count_correct_topk(scores=batch_output_val, labels=batch_y_val, k=k).item()
update_correct_per_class_topk(batch_proba, batch_y_val, class_acc_dict['class_topk_acc'][k], k)
# Get probas and labels for the entire validation set
val_probas = torch.cat(list_val_proba)
val_labels = torch.cat(list_val_labels)
flat_val_probas = torch.flatten(val_probas)
sorted_probas, _ = torch.sort(flat_val_probas, descending=True)
for k in list_k:
lmbda_val[k] = 0.5 * (sorted_probas[dataset_attributes['n_val'] * k - 1] + sorted_probas[dataset_attributes['n_val'] * k])
n_correct_average_k_val[k] += count_correct_average_k(probas=val_probas, labels=val_labels, lmbda=lmbda_val[k]).item()
update_correct_per_class_averagek(val_probas, val_labels, class_acc_dict['class_averagek_acc'][k], lmbda_val[k])
# Compute the threshold for every k and count the nb of correctly classified examples in the avg-k sense (globally and for each class)
lmbda_val[k] = 0.5 * (sorted_probas[n_val * k - 1] + sorted_probas[n_val * k])
n_correct_avgk_val[k] += count_correct_avgk(probas=val_probas, labels=val_labels, lmbda=lmbda_val[k]).item()
update_correct_per_class_avgk(val_probas, val_labels, class_acc_dict['class_avgk_acc'][k], lmbda_val[k])
# After seeing val update the statistics over batches and store them
# After seeing val set update the statistics over batches and store them
loss_epoch_val /= batch_idx
epoch_accuracy_val = n_correct_val / dataset_attributes['n_val']
epoch_accuracy_val = n_correct_val / n_val
# Get top-k acc and avg-k acc
for k in list_k:
epoch_top_k_accuracy_val[k] = n_correct_top_k_val[k] / dataset_attributes['n_val']
epoch_average_k_accuracy_val[k] = n_correct_average_k_val[k] / dataset_attributes['n_val']
topk_acc_epoch_val[k] = n_correct_topk_val[k] / n_val
avgk_acc_epoch_val[k] = n_correct_avgk_val[k] / n_val
# Get class top-k acc and class avg-k acc
for class_id in class_acc_dict['class_acc'].keys():
class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / dataset_attributes['class2num_instances']['val'][class_id]
n_class_val = dataset_attributes['class2num_instances']['val'][class_id]
class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / n_class_val
for k in list_k:
class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / dataset_attributes['class2num_instances']['val'][class_id]
class_acc_dict['class_averagek_acc'][k][class_id] = class_acc_dict['class_averagek_acc'][k][class_id] / dataset_attributes['class2num_instances']['val'][class_id]
class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / n_class_val
class_acc_dict['class_avgk_acc'][k][class_id] = class_acc_dict['class_avgk_acc'][k][class_id] / n_class_val
loss_val.append(loss_epoch_val), val_accuracy.append(epoch_accuracy_val), topk_val_accuracy.append(epoch_top_k_accuracy_val), averagek_val_accuracy.append(epoch_average_k_accuracy_val),
class_accuracies_val.append(class_acc_dict)
# Update containers with current epoch values
loss_val.append(loss_epoch_val)
acc_val.append(epoch_accuracy_val)
topk_acc_val.append(topk_acc_epoch_val)
avgk_acc_val.append(avgk_acc_epoch_val)
class_acc_val.append(class_acc_dict)
return loss_epoch_val, epoch_accuracy_val, epoch_top_k_accuracy_val, epoch_average_k_accuracy_val, lmbda_val
return loss_epoch_val, epoch_accuracy_val, topk_acc_epoch_val, avgk_acc_epoch_val, lmbda_val
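As context for the threshold used above: val_epoch flattens and sorts all validation softmax probabilities and sets lmbda_val[k] to the midpoint between the (n_val * k)-th and (n_val * k + 1)-th largest values, so that on average k classes per example sit above the threshold; test_epoch then reuses this lmbda. A minimal standalone sketch of that step (function name illustrative, not taken from the repository):

import torch

def compute_avgk_threshold(val_probas, k):
    """Threshold such that, on average, k probabilities per example exceed it.

    val_probas: (n_val, n_classes) softmax outputs over the whole validation set.
    """
    sorted_probas, _ = torch.sort(val_probas.flatten(), descending=True)
    n_val = val_probas.shape[0]
    # Midpoint between the (n_val*k)-th and (n_val*k + 1)-th largest probabilities
    return 0.5 * (sorted_probas[n_val * k - 1] + sorted_probas[n_val * k])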
def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_attributes):
@@ -116,17 +135,17 @@ def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_att
print()
model.eval()
with torch.no_grad():
n_test = dataset_attributes['n_test']
loss_epoch_test = 0
n_correct_test = 0
epoch_top_k_accuracy_test, epoch_average_k_accuracy_test = {}, {}
n_correct_top_k_test, n_correct_average_k_test = defaultdict(int), defaultdict(int)
topk_acc_epoch_test, avgk_acc_epoch_test = {}, {}
n_correct_topk_test, n_correct_avgk_test = defaultdict(int), defaultdict(int)
class_acc_dict = {}
class_acc_dict['class_acc'] = defaultdict(int)
class_acc_dict['class_topk_acc'], class_acc_dict['class_averagek_acc'] = {}, {}
class_acc_dict['class_topk_acc'], class_acc_dict['class_avgk_acc'] = {}, {}
for k in list_k:
class_acc_dict['class_topk_acc'][k], class_acc_dict['class_averagek_acc'][k] = defaultdict(int), defaultdict(int)
class_acc_dict['class_topk_acc'][k], class_acc_dict['class_avgk_acc'][k] = defaultdict(int), defaultdict(int)
for batch_idx, (batch_x_test, batch_y_test) in enumerate(tqdm(test_loader, desc='test', position=0)):
if use_gpu:
@@ -138,28 +157,24 @@ def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_att
n_correct_test += torch.sum(torch.eq(batch_y_test, torch.argmax(batch_output_test, dim=-1))).item()
for k in list_k:
n_correct_top_k_test[k] += count_correct_top_k(scores=batch_output_test, labels=batch_y_test, k=k).item()
n_correct_average_k_test[k] += count_correct_average_k(probas=batch_ouput_probra_test, labels=batch_y_test, lmbda=lmbda[k]).item()
n_correct_topk_test[k] += count_correct_topk(scores=batch_output_test, labels=batch_y_test, k=k).item()
n_correct_avgk_test[k] += count_correct_avgk(probas=batch_ouput_probra_test, labels=batch_y_test, lmbda=lmbda[k]).item()
update_correct_per_class_topk(batch_output_test, batch_y_test, class_acc_dict['class_topk_acc'][k], k)
update_correct_per_class_averagek(batch_ouput_probra_test, batch_y_test, class_acc_dict['class_averagek_acc'][k], lmbda[k])
update_correct_per_class_avgk(batch_ouput_probra_test, batch_y_test, class_acc_dict['class_avgk_acc'][k], lmbda[k])
# After seeing test test update the statistics over batches and store them
# After seeing test set update the statistics over batches and store them
loss_epoch_test /= batch_idx
epoch_accuracy_test = n_correct_test / dataset_attributes['n_test']
acc_epoch_test = n_correct_test / n_test
for k in list_k:
epoch_top_k_accuracy_test[k] = n_correct_top_k_test[k] / dataset_attributes['n_test']
epoch_average_k_accuracy_test[k] = n_correct_average_k_test[k] / dataset_attributes['n_test']
topk_acc_epoch_test[k] = n_correct_topk_test[k] / n_test
avgk_acc_epoch_test[k] = n_correct_avgk_test[k] / n_test
for class_id in class_acc_dict['class_acc'].keys():
class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / dataset_attributes['class2num_instances']['test'][class_id]
n_class_test = dataset_attributes['class2num_instances']['test'][class_id]
class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / n_class_test
for k in list_k:
class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / dataset_attributes['class2num_instances']['test'][class_id]
class_acc_dict['class_averagek_acc'][k][class_id] = class_acc_dict['class_averagek_acc'][k][class_id] / dataset_attributes['class2num_instances']['test'][class_id]
class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / n_class_test
class_acc_dict['class_avgk_acc'][k][class_id] = class_acc_dict['class_avgk_acc'][k][class_id] / n_class_test
return loss_epoch_test, epoch_accuracy_test, epoch_top_k_accuracy_test, epoch_average_k_accuracy_test
return loss_epoch_test, acc_epoch_test, topk_acc_epoch_test, avgk_acc_epoch_test
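The counting helpers renamed in this commit are defined in utils.py (shown further below only in part). For reference, implementations consistent with how they are called here would look roughly like the following sketch; it is written from the call sites, not necessarily the repository's exact code:

import torch

def count_correct_topk(scores, labels, k):
    # True label among the k highest-scoring classes for each example
    topk_indices = scores.topk(k, dim=1).indices               # (n_batch, k)
    return (topk_indices == labels.unsqueeze(1)).any(dim=1).sum()

def count_correct_avgk(probas, labels, lmbda):
    # True-class probability above the global threshold lmbda
    true_class_probas = probas.gather(1, labels.unsqueeze(1)).squeeze(1)
    return (true_class_probas >= lmbda).sum()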

main.py (68 changed lines)

@@ -2,9 +2,7 @@ import os
from tqdm import tqdm
import pickle
import argparse
import warnings
import time
import numpy as np
import torch
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
@@ -12,8 +10,6 @@ from torch.nn import CrossEntropyLoss
from utils import set_seed, load_model, save, get_model, update_optimizer, get_data
from epoch import train_epoch, val_epoch, test_epoch
from cli import add_all_parsers
# TEMPORARY HACK #
warnings.filterwarnings("ignore")
def train(args):
@@ -30,11 +26,10 @@ def train(args):
optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.mu, nesterov=True)
# Containers for storing statistics over epochs
loss_train, train_accuracy, topk_train_accuracy = [], [], []
loss_val, val_accuracy, topk_val_accuracy, average_k_val_accuracy, class_accuracies_val = [], [], [], [], []
# Containers for storing metrics over epochs
loss_train, acc_train, topk_acc_train = [], [], []
loss_val, acc_val, topk_acc_val, avgk_acc_val, class_acc_val = [], [], [], [], []
best_val_accuracy = np.float('-inf')
save_name = args.save_name_xp.strip()
save_dir = os.path.join(os.getcwd(), 'results', save_name)
if not os.path.exists(save_dir):
@@ -43,58 +38,55 @@ def train(args):
print('args.k : ', args.k)
lmbda_best_acc = None
best_val_acc = float('-inf')
for epoch in tqdm(range(args.n_epochs), desc='epoch', position=0):
t = time.time()
optimizer = update_optimizer(optimizer, lr_schedule=dataset_attributes['lr_schedule'], epoch=epoch)
loss_epoch_train, epoch_accuracy_train, epoch_top_k_accuracy_train = train_epoch(model, optimizer, train_loader,
criteria, loss_train,
train_accuracy,
topk_train_accuracy, args.k,
dataset_attributes['n_train'],
args.use_gpu)
loss_epoch_train, acc_epoch_train, topk_acc_epoch_train = train_epoch(model, optimizer, train_loader,
criteria, loss_train, acc_train,
topk_acc_train, args.k,
dataset_attributes['n_train'],
args.use_gpu)
loss_epoch_val, epoch_accuracy_val, epoch_top_k_accuracy_val, \
epoch_average_k_accuracy_val, lmbda_val = val_epoch(model, val_loader, criteria,
loss_val, val_accuracy,
topk_val_accuracy, average_k_val_accuracy,
class_accuracies_val, args.k, dataset_attributes,
args.use_gpu)
loss_epoch_val, acc_epoch_val, topk_acc_epoch_val, \
avgk_acc_epoch_val, lmbda_val = val_epoch(model, val_loader, criteria,
loss_val, acc_val, topk_acc_val, avgk_acc_val,
class_acc_val, args.k, dataset_attributes, args.use_gpu)
# no matter what, save model at every epoch
# save model at every epoch
save(model, optimizer, epoch, os.path.join(save_dir, save_name + '_weights.tar'))
# save model with best val accuracy
if epoch_accuracy_val > best_val_accuracy:
best_val_accuracy = epoch_accuracy_val
if acc_epoch_val > best_val_acc:
best_val_acc = acc_epoch_val
lmbda_best_acc = lmbda_val
save(model, optimizer, epoch, os.path.join(save_dir, save_name + '_weights_best_acc.tar'))
print()
print(f'epoch {epoch} took {time.time()-t:.2f}')
print(f'loss_epoch_train : {loss_epoch_train}')
print(f'loss_epoch_val : {loss_epoch_val}')
print(f'train accuracy : {epoch_accuracy_train} / train top_k accuracy : {epoch_top_k_accuracy_train}')
print(f'val accuracy : {epoch_accuracy_val} / val top_k accuracy : {epoch_top_k_accuracy_val} / '
f'val average_k accuracy : {epoch_average_k_accuracy_val}')
print(f'loss_train : {loss_epoch_train}')
print(f'loss_val : {loss_epoch_val}')
print(f'acc_train : {acc_epoch_train} / topk_acc_train : {topk_acc_epoch_train}')
print(f'acc_val : {acc_epoch_val} / topk_acc_val : {topk_acc_epoch_val} / '
f'avgk_acc_val : {avgk_acc_epoch_val}')
# load weights corresponding to best val accuracy and evaluate on test
load_model(model, os.path.join(save_dir, save_name + '_weights_best_acc.tar'), args.use_gpu)
loss_test_ba, accuracy_test_ba, \
top_k_accuracy_test_ba, average_k_accuracy_test_ba = test_epoch(model, test_loader, criteria, args.k,
lmbda_best_acc, args.use_gpu,
dataset_attributes)
loss_test_ba, acc_test_ba, topk_acc_test_ba, avgk_acc_test_ba = test_epoch(model, test_loader, criteria, args.k,
lmbda_best_acc, args.use_gpu,
dataset_attributes)
# Save the results as a dictionary and write it to a pickle file in the desired location
results = {'loss_train': loss_train, 'train_accuracy': train_accuracy, 'topk_train_accuracy': topk_train_accuracy,
'loss_val': loss_val, 'val_accuracy': val_accuracy, 'topk_val_accuracy': topk_val_accuracy,
'average_k_val_accuracy': average_k_val_accuracy,
results = {'loss_train': loss_train, 'acc_train': acc_train, 'topk_acc_train': topk_acc_train,
'loss_val': loss_val, 'acc_val': acc_val, 'topk_acc_val': topk_acc_val,
'avgk_acc_val': avgk_acc_val,
'test_results': {'loss': loss_test_ba,
'accuracy': accuracy_test_ba,
'topk-accuracy': top_k_accuracy_test_ba,
'averagek-accuracy': average_k_accuracy_test_ba},
'accuracy': acc_test_ba,
'topk_accuracy': topk_acc_test_ba,
'avgk_accuracy': avgk_acc_test_ba},
'params': args.__dict__}
with open(os.path.join(save_dir, save_name + '.pkl'), 'wb') as f:
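With the renamed keys, the pickled results file written above can be inspected as follows ('my_experiment' is a placeholder for args.save_name_xp):

import os
import pickle

save_name = 'my_experiment'  # placeholder for args.save_name_xp
save_dir = os.path.join(os.getcwd(), 'results', save_name)
with open(os.path.join(save_dir, save_name + '.pkl'), 'rb') as f:
    results = pickle.load(f)

# Per-epoch curves (lists indexed by epoch); topk/avgk entries are {k: accuracy} dicts
print(results['loss_train'][-1], results['acc_train'][-1])
print(results['topk_acc_val'][-1], results['avgk_acc_val'][-1])
# Test metrics for the checkpoint with the best validation accuracy
print(results['test_results']['accuracy'], results['test_results']['avgk_accuracy'])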

utils.py (13 changed lines)

@@ -7,7 +7,6 @@ from collections import Counter
from torchvision.models import resnet18, resnet50, resnet101, resnet152
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
@@ -39,13 +38,13 @@ def update_correct_per_class_topk(batch_output, batch_y, d, k):
d[true_label.item()] += torch.sum(true_label == predicted_labels).item()
def update_correct_per_class_averagek(val_probas, val_labels, d, lmbda):
def update_correct_per_class_avgk(val_probas, val_labels, d, lmbda):
ground_truth_probas = torch.gather(val_probas, dim=1, index=val_labels.unsqueeze(-1))
for true_label, predicted_label in zip(val_labels, ground_truth_probas):
d[true_label.item()] += (predicted_label >= lmbda).item()
def count_correct_top_k(scores, labels, k):
def count_correct_topk(scores, labels, k):
"""Given a tensor of scores of size (n_batch, n_classes) and a tensor of
labels of size n_batch, computes the number of correctly predicted examples
in the batch (in the top_k accuracy sense).
@@ -55,7 +54,7 @@ def count_correct_top_k(scores, labels, k):
return torch.eq(labels, top_k_scores).sum()
def count_correct_average_k(probas, labels, lmbda):
def count_correct_avgk(probas, labels, lmbda):
"""Given a tensor of scores of size (n_batch, n_classes) and a tensor of
labels of size n_batch, computes the number of correctly predicted exemples
in the batch (in the top_k accuracy sense).
@@ -65,12 +64,6 @@ def count_correct_average_k(probas, labels, lmbda):
return res
def compute_lambda_batch(batch_proba, k):
sorted_probas, _ = torch.sort(torch.flatten(batch_proba), descending=True)
batch_lambda = 0.5 * (sorted_probas[len(batch_proba) * k - 1] + sorted_probas[len(batch_proba) * k])
return batch_lambda
def load_model(model, filename, use_gpu):
if not os.path.exists(filename):
raise FileNotFoundError
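Finally, a tiny hand-worked example of the two renamed counters; the expected outputs are computed by hand from the semantics sketched after epoch.py, under the assumption that the repository's helpers behave the same on this input:

import torch
import torch.nn.functional as F
from utils import count_correct_topk, count_correct_avgk  # or use the sketch above

scores = torch.tensor([[2.0, 1.0, 0.1],
                       [0.2, 0.1, 3.0]])
labels = torch.tensor([1, 2])
probas = F.softmax(scores, dim=1)

# Top-2 keeps the two highest-scoring classes per example:
# example 0 -> {0, 1} (label 1 covered), example 1 -> {2, 0} (label 2 covered)
print(count_correct_topk(scores, labels, k=2).item())    # 2

# Average-k keeps every class whose probability clears the threshold lmbda:
# example 0: p(class 1) ~ 0.24 < 0.30 (missed), example 1: p(class 2) ~ 0.90 (covered)
lmbda = 0.30
print(count_correct_avgk(probas, labels, lmbda).item())   # 1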