
renaming variables and adding comments

Branch: funding
camille garcin 2021-06-30 14:55:52 +02:00
parent 7609839662
commit a04765857b
3 changed files with 109 additions and 109 deletions

epoch.py (137 changes)

@@ -1,18 +1,22 @@
 import torch
 from tqdm import tqdm
-from utils import count_correct_top_k, count_correct_average_k, update_correct_per_class, \
-    update_correct_per_class_topk, update_correct_per_class_averagek
+from utils import count_correct_topk, count_correct_avgk, update_correct_per_class, \
+    update_correct_per_class_topk, update_correct_per_class_avgk
 import torch.nn.functional as F
 from collections import defaultdict


-def train_epoch(model, optimizer, train_loader, criteria, loss_train, train_accuracy, topk_train_accuracy, list_k, n_train, use_gpu):
+def train_epoch(model, optimizer, train_loader, criteria, loss_train, acc_train, topk_acc_train, list_k, n_train, use_gpu):
+    """Single train epoch pass. At the end of the epoch, updates the lists loss_train, acc_train and topk_acc_train"""
     model.train()
+    # Initialize variables
     loss_epoch_train = 0
     n_correct_train = 0
-    n_correct_top_k_train = defaultdict(int)
-    epoch_top_k_accuracy_train = {}
+    # Containers for tracking nb of correctly classified examples (in the top-k sense) and top-k accuracy for each k in list_k
+    n_correct_topk_train = defaultdict(int)
+    topk_acc_epoch_train = {}

     for batch_idx, (batch_x_train, batch_y_train) in enumerate(tqdm(train_loader, desc='train', position=0)):
         if use_gpu:
             batch_x_train, batch_y_train = batch_x_train.cuda(), batch_y_train.cuda()
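Editor's note: count_correct_topk (renamed here from count_correct_top_k) counts how many examples in a batch have their true label among the k highest scores; the real implementation lives in utils.py further down. A minimal sketch of the idea, under the assumption of a (n_batch, n_classes) score tensor (hypothetical helper name, not the repo's exact body):

    import torch

    def count_correct_topk_sketch(scores, labels, k):
        # indices of the k largest scores per example: shape (n_batch, k)
        topk_indices = torch.topk(scores, k, dim=1).indices
        # an example counts as correct if its label appears among its k predictions
        return (topk_indices == labels.unsqueeze(1)).any(dim=1).sum()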
@@ -23,40 +27,49 @@ def train_epoch(model, optimizer, train_loader, criteria, loss_train, train_accu
         loss_epoch_train += loss_batch_train.item()
         loss_batch_train.backward()
         optimizer.step()

+        # Update variables
         with torch.no_grad():
             n_correct_train += torch.sum(torch.eq(batch_y_train, torch.argmax(batch_output_train, dim=-1))).item()
             for k in list_k:
-                n_correct_top_k_train[k] += count_correct_top_k(scores=batch_output_train, labels=batch_y_train,
-                                                                k=k).item()
+                n_correct_topk_train[k] += count_correct_topk(scores=batch_output_train, labels=batch_y_train, k=k).item()

     # At the end of epoch compute average of statistics over batches and store them
     with torch.no_grad():
         loss_epoch_train /= batch_idx
         epoch_accuracy_train = n_correct_train / n_train
         for k in list_k:
-            epoch_top_k_accuracy_train[k] = n_correct_top_k_train[k] / n_train
+            topk_acc_epoch_train[k] = n_correct_topk_train[k] / n_train

-    loss_train.append(loss_epoch_train), train_accuracy.append(epoch_accuracy_train), topk_train_accuracy.append(
-        epoch_top_k_accuracy_train)
+    loss_train.append(loss_epoch_train)
+    acc_train.append(epoch_accuracy_train)
+    topk_acc_train.append(topk_acc_epoch_train)

-    return loss_epoch_train, epoch_accuracy_train, epoch_top_k_accuracy_train
+    return loss_epoch_train, epoch_accuracy_train, topk_acc_epoch_train


-def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accuracy, averagek_val_accuracy,
-              class_accuracies_val, list_k, dataset_attributes, use_gpu):
+def val_epoch(model, val_loader, criteria, loss_val, acc_val, topk_acc_val, avgk_acc_val,
+              class_acc_val, list_k, dataset_attributes, use_gpu):
+    """Single val epoch pass.
+    At the end of the epoch, updates the lists loss_val, acc_val, topk_acc_val and avgk_acc_val"""
     model.eval()
     with torch.no_grad():
+        n_val = dataset_attributes['n_val']
+        # Initialization of variables
         loss_epoch_val = 0
         n_correct_val = 0
-        n_correct_top_k_val = defaultdict(int)
-        n_correct_average_k_val = defaultdict(int)
-        epoch_top_k_accuracy_val, epoch_average_k_accuracy_val, lmbda_val = {}, {}, {}
+        n_correct_topk_val, n_correct_avgk_val = defaultdict(int), defaultdict(int)
+        topk_acc_epoch_val, avgk_acc_epoch_val = {}, {}
+        # Store avg-k threshold for every k in list_k
+        lmbda_val = {}

+        # Store class accuracy, and top-k and average-k class accuracy for every k in list_k
         class_acc_dict = {}
         class_acc_dict['class_acc'] = defaultdict(int)
-        class_acc_dict['class_topk_acc'], class_acc_dict['class_averagek_acc'] = {}, {}
+        class_acc_dict['class_topk_acc'], class_acc_dict['class_avgk_acc'] = {}, {}
         for k in list_k:
-            class_acc_dict['class_topk_acc'][k], class_acc_dict['class_averagek_acc'][k] = defaultdict(int), defaultdict(int)
+            class_acc_dict['class_topk_acc'][k], class_acc_dict['class_avgk_acc'][k] = defaultdict(int), defaultdict(int)

+        # Store estimated probas and labels of the whole validation set to compute lambda
         list_val_proba = []
         list_val_labels = []

         for batch_idx, (batch_x_val, batch_y_val) in enumerate(tqdm(val_loader, desc='val', position=0)):
@@ -64,6 +77,7 @@ def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accu
                 batch_x_val, batch_y_val = batch_x_val.cuda(), batch_y_val.cuda()

             batch_output_val = model(batch_x_val)
             batch_proba = F.softmax(batch_output_val)
+            # Store batch probas and labels
             list_val_proba.append(batch_proba)
             list_val_labels.append(batch_y_val)
@@ -72,43 +86,48 @@ def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accu
             n_correct_val += torch.sum(torch.eq(batch_y_val, torch.argmax(batch_output_val, dim=-1))).item()
             update_correct_per_class(batch_proba, batch_y_val, class_acc_dict['class_acc'])
+            # Update top-k count and top-k count for each class
             for k in list_k:
-                n_correct_top_k_val[k] += count_correct_top_k(scores=batch_output_val, labels=batch_y_val, k=k).item()
+                n_correct_topk_val[k] += count_correct_topk(scores=batch_output_val, labels=batch_y_val, k=k).item()
                 update_correct_per_class_topk(batch_proba, batch_y_val, class_acc_dict['class_topk_acc'][k], k)

+        # Get probas and labels for the entire validation set
         val_probas = torch.cat(list_val_proba)
         val_labels = torch.cat(list_val_labels)

         flat_val_probas = torch.flatten(val_probas)
         sorted_probas, _ = torch.sort(flat_val_probas, descending=True)
         for k in list_k:
-            lmbda_val[k] = 0.5 * (sorted_probas[dataset_attributes['n_val'] * k - 1] + sorted_probas[dataset_attributes['n_val'] * k])
-            n_correct_average_k_val[k] += count_correct_average_k(probas=val_probas, labels=val_labels, lmbda=lmbda_val[k]).item()
-            update_correct_per_class_averagek(val_probas, val_labels, class_acc_dict['class_averagek_acc'][k], lmbda_val[k])
+            # Computes threshold for every k and count nb of correctly classifier examples in the avg-k sense (globally and for each class)
+            lmbda_val[k] = 0.5 * (sorted_probas[n_val * k - 1] + sorted_probas[n_val * k])
+            n_correct_avgk_val[k] += count_correct_avgk(probas=val_probas, labels=val_labels, lmbda=lmbda_val[k]).item()
+            update_correct_per_class_avgk(val_probas, val_labels, class_acc_dict['class_avgk_acc'][k], lmbda_val[k])

-        # After seeing val update the statistics over batches and store them
+        # After seeing val set update the statistics over batches and store them
         loss_epoch_val /= batch_idx
-        epoch_accuracy_val = n_correct_val / dataset_attributes['n_val']
+        epoch_accuracy_val = n_correct_val / n_val
+        # Get top-k acc and avg-k acc
         for k in list_k:
-            epoch_top_k_accuracy_val[k] = n_correct_top_k_val[k] / dataset_attributes['n_val']
-            epoch_average_k_accuracy_val[k] = n_correct_average_k_val[k] / dataset_attributes['n_val']
+            topk_acc_epoch_val[k] = n_correct_topk_val[k] / n_val
+            avgk_acc_epoch_val[k] = n_correct_avgk_val[k] / n_val

+        # Get class top-k acc and class avg-k acc
         for class_id in class_acc_dict['class_acc'].keys():
-            class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / dataset_attributes['class2num_instances']['val'][
-                class_id]
+            n_class_val = dataset_attributes['class2num_instances']['val'][class_id]
+            class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / n_class_val
             for k in list_k:
-                class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / dataset_attributes['class2num_instances']['val'][
-                    class_id]
-                class_acc_dict['class_averagek_acc'][k][class_id] = class_acc_dict['class_averagek_acc'][k][class_id] / \
-                                                                    dataset_attributes['class2num_instances']['val'][class_id]
+                class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / n_class_val
+                class_acc_dict['class_avgk_acc'][k][class_id] = class_acc_dict['class_avgk_acc'][k][class_id] / n_class_val

-        loss_val.append(loss_epoch_val), val_accuracy.append(epoch_accuracy_val), topk_val_accuracy.append(
-            epoch_top_k_accuracy_val), averagek_val_accuracy.append(epoch_average_k_accuracy_val),
-        class_accuracies_val.append(class_acc_dict)
+        # Update containers with current epoch values
+        loss_val.append(loss_epoch_val)
+        acc_val.append(epoch_accuracy_val)
+        topk_acc_val.append(topk_acc_epoch_val)
+        avgk_acc_val.append(avgk_acc_epoch_val)
+        class_acc_val.append(class_acc_dict)

-    return loss_epoch_val, epoch_accuracy_val, epoch_top_k_accuracy_val, epoch_average_k_accuracy_val, lmbda_val
+    return loss_epoch_val, epoch_accuracy_val, topk_acc_epoch_val, avgk_acc_epoch_val, lmbda_val


 def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_attributes):
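Editor's note: the threshold lmbda_val[k] is what makes average-k work. With n_val examples and all softmax probabilities flattened and sorted over the whole validation set, the (n_val * k)-th largest probability is the cutoff at which exactly n_val * k (probability, class) pairs survive, i.e. k predicted labels per example on average; taking the midpoint between that value and the next places the threshold strictly between kept and discarded probabilities. A small self-contained illustration of the same formula (hypothetical numbers, assuming no exact ties):

    import torch

    n_val, k = 4, 2                    # 4 validation examples, average budget of 2 labels each
    probas = torch.rand(n_val, 10)     # stand-in for the softmax outputs val_probas
    probas = probas / probas.sum(dim=1, keepdim=True)
    sorted_probas, _ = torch.sort(torch.flatten(probas), descending=True)
    lmbda = 0.5 * (sorted_probas[n_val * k - 1] + sorted_probas[n_val * k])
    # exactly n_val * k = 8 probabilities clear the threshold,
    # so thresholding at lmbda keeps k = 2 labels per example on average
    assert (probas >= lmbda).sum().item() == n_val * k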
@@ -116,17 +135,17 @@ def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_att
     print()
     model.eval()
     with torch.no_grad():
+        n_test = dataset_attributes['n_test']
         loss_epoch_test = 0
         n_correct_test = 0
-        epoch_top_k_accuracy_test, epoch_average_k_accuracy_test = {}, {}
-        n_correct_top_k_test, n_correct_average_k_test = defaultdict(int), defaultdict(int)
+        topk_acc_epoch_test, avgk_acc_epoch_test = {}, {}
+        n_correct_topk_test, n_correct_avgk_test = defaultdict(int), defaultdict(int)

         class_acc_dict = {}
         class_acc_dict['class_acc'] = defaultdict(int)
-        class_acc_dict['class_topk_acc'], class_acc_dict['class_averagek_acc'] = {}, {}
+        class_acc_dict['class_topk_acc'], class_acc_dict['class_avgk_acc'] = {}, {}
         for k in list_k:
-            class_acc_dict['class_topk_acc'][k], class_acc_dict['class_averagek_acc'][k] = defaultdict(
-                int), defaultdict(int)
+            class_acc_dict['class_topk_acc'][k], class_acc_dict['class_avgk_acc'][k] = defaultdict(int), defaultdict(int)

         for batch_idx, (batch_x_test, batch_y_test) in enumerate(tqdm(test_loader, desc='test', position=0)):
             if use_gpu:
@@ -138,28 +157,24 @@ def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_att
             n_correct_test += torch.sum(torch.eq(batch_y_test, torch.argmax(batch_output_test, dim=-1))).item()
             for k in list_k:
-                n_correct_top_k_test[k] += count_correct_top_k(scores=batch_output_test, labels=batch_y_test, k=k).item()
-                n_correct_average_k_test[k] += count_correct_average_k(probas=batch_ouput_probra_test, labels=batch_y_test,
-                                                                       lmbda=lmbda[k]).item()
+                n_correct_topk_test[k] += count_correct_topk(scores=batch_output_test, labels=batch_y_test, k=k).item()
+                n_correct_avgk_test[k] += count_correct_avgk(probas=batch_ouput_probra_test, labels=batch_y_test, lmbda=lmbda[k]).item()
                 update_correct_per_class_topk(batch_output_test, batch_y_test, class_acc_dict['class_topk_acc'][k], k)
-                update_correct_per_class_averagek(batch_ouput_probra_test, batch_y_test, class_acc_dict['class_averagek_acc'][k],
-                                                  lmbda[k])
+                update_correct_per_class_avgk(batch_ouput_probra_test, batch_y_test, class_acc_dict['class_avgk_acc'][k], lmbda[k])

-        # After seeing test test update the statistics over batches and store them
+        # After seeing test set update the statistics over batches and store them
         loss_epoch_test /= batch_idx
-        epoch_accuracy_test = n_correct_test / dataset_attributes['n_test']
+        acc_epoch_test = n_correct_test / n_test
         for k in list_k:
-            epoch_top_k_accuracy_test[k] = n_correct_top_k_test[k] / dataset_attributes['n_test']
-            epoch_average_k_accuracy_test[k] = n_correct_average_k_test[k] / dataset_attributes['n_test']
+            topk_acc_epoch_test[k] = n_correct_topk_test[k] / n_test
+            avgk_acc_epoch_test[k] = n_correct_avgk_test[k] / n_test

         for class_id in class_acc_dict['class_acc'].keys():
-            class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / dataset_attributes['class2num_instances']['test'][
-                class_id]
+            n_class_test = dataset_attributes['class2num_instances']['test'][class_id]
+            class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / n_class_test
             for k in list_k:
-                class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / dataset_attributes['class2num_instances']['test'][
-                    class_id]
-                class_acc_dict['class_averagek_acc'][k][class_id] = class_acc_dict['class_averagek_acc'][k][class_id] / \
-                                                                    dataset_attributes['class2num_instances']['test'][class_id]
+                class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / n_class_test
+                class_acc_dict['class_avgk_acc'][k][class_id] = class_acc_dict['class_avgk_acc'][k][class_id] / n_class_test

-    return loss_epoch_test, epoch_accuracy_test, epoch_top_k_accuracy_test, epoch_average_k_accuracy_test
+    return loss_epoch_test, acc_epoch_test, topk_acc_epoch_test, avgk_acc_epoch_test
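Editor's note: the two set-valued metrics computed above differ in one way. Top-k grants every example exactly k guesses, while average-k thresholds probabilities at a single lmbda so that examples get k guesses only on average, letting ambiguous inputs spend more of the budget than easy ones. That is also why lmbda is fit on the full validation set rather than per batch (and presumably why the per-batch compute_lambda_batch helper is deleted from utils.py below).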

main.py (68 changes)

@@ -2,9 +2,7 @@ import os
 from tqdm import tqdm
 import pickle
 import argparse
-import warnings
 import time
-import numpy as np
 import torch
 from torch.optim import SGD
 from torch.nn import CrossEntropyLoss
@@ -12,8 +10,6 @@ from torch.nn import CrossEntropyLoss
 from utils import set_seed, load_model, save, get_model, update_optimizer, get_data
 from epoch import train_epoch, val_epoch, test_epoch
 from cli import add_all_parsers

-# TEMPORARY HACK #
-warnings.filterwarnings("ignore")


 def train(args):
@@ -30,11 +26,10 @@ def train(args):
     optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.mu, nesterov=True)

-    # Containers for storing statistics over epochs
-    loss_train, train_accuracy, topk_train_accuracy = [], [], []
-    loss_val, val_accuracy, topk_val_accuracy, average_k_val_accuracy, class_accuracies_val = [], [], [], [], []
-    best_val_accuracy = np.float('-inf')
+    # Containers for storing metrics over epochs
+    loss_train, acc_train, topk_acc_train = [], [], []
+    loss_val, acc_val, topk_acc_val, avgk_acc_val, class_acc_val = [], [], [], [], []

     save_name = args.save_name_xp.strip()
     save_dir = os.path.join(os.getcwd(), 'results', save_name)
     if not os.path.exists(save_dir):
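Editor's note: dropping `best_val_accuracy = np.float('-inf')` here (it reappears in the next hunk as `best_val_acc = float('-inf')`) also removes the last numpy usage from main.py. `np.float` is just a deprecated alias for the builtin `float` (deprecated in NumPy 1.20, later removed), so the switch loses nothing.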
@@ -43,58 +38,55 @@ def train(args):
     print('args.k : ', args.k)

     lmbda_best_acc = None
+    best_val_acc = float('-inf')

     for epoch in tqdm(range(args.n_epochs), desc='epoch', position=0):
         t = time.time()
         optimizer = update_optimizer(optimizer, lr_schedule=dataset_attributes['lr_schedule'], epoch=epoch)

-        loss_epoch_train, epoch_accuracy_train, epoch_top_k_accuracy_train = train_epoch(model, optimizer, train_loader,
-                                                                                         criteria, loss_train,
-                                                                                         train_accuracy,
-                                                                                         topk_train_accuracy, args.k,
-                                                                                         dataset_attributes['n_train'],
-                                                                                         args.use_gpu)
+        loss_epoch_train, acc_epoch_train, topk_acc_epoch_train = train_epoch(model, optimizer, train_loader,
+                                                                              criteria, loss_train, acc_train,
+                                                                              topk_acc_train, args.k,
+                                                                              dataset_attributes['n_train'],
+                                                                              args.use_gpu)

-        loss_epoch_val, epoch_accuracy_val, epoch_top_k_accuracy_val, \
-            epoch_average_k_accuracy_val, lmbda_val = val_epoch(model, val_loader, criteria,
-                                                                loss_val, val_accuracy,
-                                                                topk_val_accuracy, average_k_val_accuracy,
-                                                                class_accuracies_val, args.k, dataset_attributes,
-                                                                args.use_gpu)
+        loss_epoch_val, acc_epoch_val, topk_acc_epoch_val, \
+            avgk_acc_epoch_val, lmbda_val = val_epoch(model, val_loader, criteria,
+                                                      loss_val, acc_val, topk_acc_val, avgk_acc_val,
+                                                      class_acc_val, args.k, dataset_attributes, args.use_gpu)

-        # no matter what, save model at every epoch
+        # save model at every epoch
         save(model, optimizer, epoch, os.path.join(save_dir, save_name + '_weights.tar'))

         # save model with best val accuracy
-        if epoch_accuracy_val > best_val_accuracy:
-            best_val_accuracy = epoch_accuracy_val
+        if acc_epoch_val > best_val_acc:
+            best_val_acc = acc_epoch_val
             lmbda_best_acc = lmbda_val
             save(model, optimizer, epoch, os.path.join(save_dir, save_name + '_weights_best_acc.tar'))

         print()
         print(f'epoch {epoch} took {time.time()-t:.2f}')
-        print(f'loss_epoch_train : {loss_epoch_train}')
-        print(f'loss_epoch_val : {loss_epoch_val}')
-        print(f'train accuracy : {epoch_accuracy_train} / train top_k accuracy : {epoch_top_k_accuracy_train}')
-        print(f'val accuracy : {epoch_accuracy_val} / val top_k accuracy : {epoch_top_k_accuracy_val} / '
-              f'val average_k accuracy : {epoch_average_k_accuracy_val}')
+        print(f'loss_train : {loss_epoch_train}')
+        print(f'loss_val : {loss_epoch_val}')
+        print(f'acc_train : {acc_epoch_train} / topk_acc_train : {topk_acc_epoch_train}')
+        print(f'acc_val : {acc_epoch_val} / topk_acc_val : {topk_acc_epoch_val} / '
+              f'avgk_acc_val : {avgk_acc_epoch_val}')

     # load weights corresponding to best val accuracy and evaluate on test
     load_model(model, os.path.join(save_dir, save_name + '_weights_best_acc.tar'), args.use_gpu)
-    loss_test_ba, accuracy_test_ba, \
-        top_k_accuracy_test_ba, average_k_accuracy_test_ba = test_epoch(model, test_loader, criteria, args.k,
-                                                                        lmbda_best_acc, args.use_gpu,
-                                                                        dataset_attributes)
+    loss_test_ba, acc_test_ba, topk_acc_test_ba, avgk_acc_test_ba = test_epoch(model, test_loader, criteria, args.k,
+                                                                               lmbda_best_acc, args.use_gpu,
+                                                                               dataset_attributes)

     # Save the results as a dictionary and save it as a pickle file in desired location
-    results = {'loss_train': loss_train, 'train_accuracy': train_accuracy, 'topk_train_accuracy': topk_train_accuracy,
-               'loss_val': loss_val, 'val_accuracy': val_accuracy, 'topk_val_accuracy': topk_val_accuracy,
-               'average_k_val_accuracy': average_k_val_accuracy,
+    results = {'loss_train': loss_train, 'acc_train': acc_train, 'topk_acc_train': topk_acc_train,
+               'loss_val': loss_val, 'acc_val': acc_val, 'topk_acc_val': topk_acc_val,
+               'avgk_acc_val': avgk_acc_val,
                'test_results': {'loss': loss_test_ba,
-                                'accuracy': accuracy_test_ba,
-                                'topk-accuracy': top_k_accuracy_test_ba,
-                                'averagek-accuracy': average_k_accuracy_test_ba},
+                                'accuracy': acc_test_ba,
+                                'topk_accuracy': topk_acc_test_ba,
+                                'avgk_accuracy': avgk_acc_test_ba},
                'params': args.__dict__}

     with open(os.path.join(save_dir, save_name + '.pkl'), 'wb') as f:
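Editor's note: since the keys of the results dictionary changed, downstream analysis scripts need the new names. A hypothetical loading snippet (the experiment name is a placeholder, not from the repo):

    import os
    import pickle

    save_name = 'my_experiment'  # placeholder for args.save_name_xp
    with open(os.path.join('results', save_name, save_name + '.pkl'), 'rb') as f:
        results = pickle.load(f)

    # per-epoch curves now live under the shortened keys
    print(results['acc_val'])                        # list of per-epoch val accuracies
    print(results['test_results']['avgk_accuracy'])  # final avg-k test accuracy per k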

utils.py (13 changes)

@@ -7,7 +7,6 @@ from collections import Counter
 from torchvision.models import resnet18, resnet50, resnet101, resnet152
 import torch.nn as nn
-from torch.nn import CrossEntropyLoss
 from torchvision.datasets import ImageFolder
 import torchvision.transforms as transforms
@@ -39,13 +38,13 @@ def update_correct_per_class_topk(batch_output, batch_y, d, k):
         d[true_label.item()] += torch.sum(true_label == predicted_labels).item()


-def update_correct_per_class_averagek(val_probas, val_labels, d, lmbda):
+def update_correct_per_class_avgk(val_probas, val_labels, d, lmbda):
     ground_truth_probas = torch.gather(val_probas, dim=1, index=val_labels.unsqueeze(-1))
     for true_label, predicted_label in zip(val_labels, ground_truth_probas):
         d[true_label.item()] += (predicted_label >= lmbda).item()


-def count_correct_top_k(scores, labels, k):
+def count_correct_topk(scores, labels, k):
     """Given a tensor of scores of size (n_batch, n_classes) and a tensor of
     labels of size n_batch, computes the number of correctly predicted exemples
     in the batch (in the top_k accuracy sense).
@@ -55,7 +54,7 @@ def count_correct_top_k(scores, labels, k):
     return torch.eq(labels, top_k_scores).sum()


-def count_correct_average_k(probas, labels, lmbda):
+def count_correct_avgk(probas, labels, lmbda):
     """Given a tensor of scores of size (n_batch, n_classes) and a tensor of
     labels of size n_batch, computes the number of correctly predicted exemples
     in the batch (in the top_k accuracy sense).
@@ -65,12 +64,6 @@ def count_correct_average_k(probas, labels, lmbda):
     return res


-def compute_lambda_batch(batch_proba, k):
-    sorted_probas, _ = torch.sort(torch.flatten(batch_proba), descending=True)
-    batch_lambda = 0.5 * (sorted_probas[len(batch_proba) * k - 1] + sorted_probas[len(batch_proba) * k])
-    return batch_lambda


 def load_model(model, filename, use_gpu):
     if not os.path.exists(filename):
         raise FileNotFoundError
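Editor's note: the body of count_correct_avgk is collapsed in this view (only `return res` is visible). From its call sites and from update_correct_per_class_avgk above, it counts examples whose true-class probability clears the threshold. A plausible sketch under that assumption (hypothetical helper name, not necessarily the repo's exact code):

    import torch

    def count_correct_avgk_sketch(probas, labels, lmbda):
        # probability assigned to each example's true class: shape (n_batch,)
        gt_probas = torch.gather(probas, dim=1, index=labels.unsqueeze(-1)).squeeze(-1)
        # an example is correct in the average-k sense if its true class survives the threshold
        return (gt_probas >= lmbda).sum()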