diff --git a/epoch.py b/epoch.py index 241e886..72c9b52 100644 --- a/epoch.py +++ b/epoch.py @@ -1,6 +1,8 @@ import torch from tqdm import tqdm -from utils import count_correct_top_k, count_correct_average_k +from utils import count_correct_top_k, count_correct_average_k, update_correct_per_class, \ + update_correct_per_class_topk, update_correct_per_class_averagek + import torch.nn.functional as F from collections import defaultdict @@ -12,6 +14,8 @@ def train_epoch(model, optimizer, train_loader, criteria, loss_train, train_accu n_correct_top_k_train = defaultdict(int) epoch_top_k_accuracy_train = {} for batch_idx, (batch_x_train, batch_y_train) in enumerate(tqdm(train_loader, desc='train', position=0)): +        # NOTE(review): removed a debug-only early exit here (was: break after 3 +        # batches) so every epoch trains on the full loader if use_gpu: batch_x_train, batch_y_train = batch_x_train.cuda(), batch_y_train.cuda() optimizer.zero_grad() @@ -39,16 +43,21 @@ return loss_epoch_train, epoch_accuracy_train, epoch_top_k_accuracy_train -def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accuracy, averagek_val_accuracy, list_k, - dataset_attributes, use_gpu): +def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accuracy, averagek_val_accuracy, + class_accuracies_val, list_k, dataset_attributes, use_gpu): model.eval() with torch.no_grad(): loss_epoch_val = 0 n_correct_val = 0 n_correct_top_k_val = defaultdict(int) - epoch_top_k_accuracy_val, epoch_average_k_accuracy_val, lmbda_val = {}, {}, {} n_correct_average_k_val = defaultdict(int) + epoch_top_k_accuracy_val, epoch_average_k_accuracy_val, lmbda_val = {}, {}, {} + class_acc_dict = {} + class_acc_dict['class_acc'] = defaultdict(int) + class_acc_dict['class_topk_acc'], class_acc_dict['class_averagek_acc'] = {}, {} + for k in list_k: + class_acc_dict['class_topk_acc'][k], class_acc_dict['class_averagek_acc'][k] = defaultdict(int), defaultdict(int) list_val_proba = [] list_val_labels = [] @@
-64,8 +73,11 @@ def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accu loss_epoch_val += loss_batch_val.item() n_correct_val += torch.sum(torch.eq(batch_y_val, torch.argmax(batch_output_val, dim=-1))).item() + update_correct_per_class(batch_proba, batch_y_val, class_acc_dict['class_acc']) + for k in list_k: n_correct_top_k_val[k] += count_correct_top_k(scores=batch_output_val, labels=batch_y_val, k=k).item() + update_correct_per_class_topk(batch_proba, batch_y_val, class_acc_dict['class_topk_acc'][k], k) val_probas = torch.cat(list_val_proba) val_labels = torch.cat(list_val_labels) @@ -75,6 +87,7 @@ def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accu for k in list_k: lmbda_val[k] = 0.5 * (sorted_probas[dataset_attributes['n_val'] * k - 1] + sorted_probas[dataset_attributes['n_val'] * k]) n_correct_average_k_val[k] += count_correct_average_k(probas=val_probas, labels=val_labels, lmbda=lmbda_val[k]).item() + update_correct_per_class_averagek(val_probas, val_labels, class_acc_dict['class_averagek_acc'][k], lmbda_val[k]) # After seeing val update the statistics over batches and store them loss_epoch_val /= batch_idx @@ -83,13 +96,24 @@ def val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accu epoch_top_k_accuracy_val[k] = n_correct_top_k_val[k] / dataset_attributes['n_val'] epoch_average_k_accuracy_val[k] = n_correct_average_k_val[k] / dataset_attributes['n_val'] + for class_id in class_acc_dict['class_acc'].keys(): + class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / dataset_attributes['class2num_instances']['val'][ + class_id] + for k in list_k: + class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / dataset_attributes['class2num_instances']['val'][ + class_id] + class_acc_dict['class_averagek_acc'][k][class_id] = class_acc_dict['class_averagek_acc'][k][class_id] / \ + dataset_attributes['class2num_instances']['val'][ + 
class_id] + loss_val.append(loss_epoch_val), val_accuracy.append(epoch_accuracy_val), topk_val_accuracy.append( - epoch_top_k_accuracy_val), averagek_val_accuracy.append(epoch_average_k_accuracy_val) + epoch_top_k_accuracy_val), averagek_val_accuracy.append(epoch_average_k_accuracy_val), + class_accuracies_val.append(class_acc_dict) return loss_epoch_val, epoch_accuracy_val, epoch_top_k_accuracy_val, epoch_average_k_accuracy_val, lmbda_val -def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, n_test): +def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_attributes): print() model.eval() @@ -98,6 +122,14 @@ def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, n_test): n_correct_test = 0 epoch_top_k_accuracy_test, epoch_average_k_accuracy_test = {}, {} n_correct_top_k_test, n_correct_average_k_test = defaultdict(int), defaultdict(int) + + class_acc_dict = {} + class_acc_dict['class_acc'] = defaultdict(int) + class_acc_dict['class_topk_acc'], class_acc_dict['class_averagek_acc'] = {}, {} + for k in list_k: + class_acc_dict['class_topk_acc'][k], class_acc_dict['class_averagek_acc'][k] = defaultdict( + int), defaultdict(int) + for batch_idx, (batch_x_test, batch_y_test) in enumerate(tqdm(test_loader, desc='test', position=0)): if use_gpu: batch_x_test, batch_y_test = batch_x_test.cuda(), batch_y_test.cuda() @@ -111,12 +143,26 @@ def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, n_test): n_correct_top_k_test[k] += count_correct_top_k(scores=batch_output_test, labels=batch_y_test, k=k).item() n_correct_average_k_test[k] += count_correct_average_k(probas=batch_ouput_probra_test, labels=batch_y_test, lmbda=lmbda[k]).item() + update_correct_per_class_topk(batch_output_test, batch_y_test, class_acc_dict['class_topk_acc'][k], k) + update_correct_per_class_averagek(batch_ouput_probra_test, batch_y_test, class_acc_dict['class_averagek_acc'][k], + lmbda[k]) +        update_correct_per_class(batch_output_test, batch_y_test, class_acc_dict['class_acc']) # After seeing test test update the statistics
over batches and store them loss_epoch_test /= batch_idx - epoch_accuracy_test = n_correct_test / n_test + epoch_accuracy_test = n_correct_test / dataset_attributes['n_test'] for k in list_k: - epoch_top_k_accuracy_test[k] = n_correct_top_k_test[k] / n_test - epoch_average_k_accuracy_test[k] = n_correct_average_k_test[k] / n_test + epoch_top_k_accuracy_test[k] = n_correct_top_k_test[k] / dataset_attributes['n_test'] + epoch_average_k_accuracy_test[k] = n_correct_average_k_test[k] / dataset_attributes['n_test'] + + for class_id in class_acc_dict['class_acc'].keys(): + class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / dataset_attributes['class2num_instances']['test'][ + class_id] + for k in list_k: + class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / dataset_attributes['class2num_instances']['test'][ + class_id] + class_acc_dict['class_averagek_acc'][k][class_id] = class_acc_dict['class_averagek_acc'][k][class_id] / \ + dataset_attributes['class2num_instances']['test'][ + class_id] return loss_epoch_test, epoch_accuracy_test, epoch_top_k_accuracy_test, epoch_average_k_accuracy_test diff --git a/main.py b/main.py index 7810566..4e5eb14 100644 --- a/main.py +++ b/main.py @@ -32,7 +32,7 @@ def train(args): # Containers for storing statistics over epochs loss_train, train_accuracy, topk_train_accuracy = [], [], [] - loss_val, val_accuracy, topk_val_accuracy, average_k_val_accuracy = [], [], [], [] + loss_val, val_accuracy, topk_val_accuracy, average_k_val_accuracy, class_accuracies_val = [], [], [], [], [] best_val_accuracy = np.float('-inf') save_name = args.save_name_xp.strip() @@ -59,7 +59,8 @@ def train(args): epoch_average_k_accuracy_val, lmbda_val = val_epoch(model, val_loader, criteria, loss_val, val_accuracy, topk_val_accuracy, average_k_val_accuracy, - args.k, dataset_attributes, args.use_gpu) + class_accuracies_val, args.k, dataset_attributes, + args.use_gpu) # no matter what, save model 
at every epoch save(model, optimizer, epoch, os.path.join(save_dir, save_name + '_weights.tar')) @@ -83,7 +84,7 @@ def train(args): loss_test_ba, accuracy_test_ba, \ top_k_accuracy_test_ba, average_k_accuracy_test_ba = test_epoch(model, test_loader, criteria, args.k, lmbda_best_acc, args.use_gpu, - dataset_attributes['n_test']) + dataset_attributes) # Save the results as a dictionary and save it as a pickle file in desired location diff --git a/utils.py b/utils.py index 5f1c313..80fda35 100644 --- a/utils.py +++ b/utils.py @@ -24,6 +24,27 @@ def set_seed(args, use_gpu, print_out=True): torch.cuda.manual_seed(args.seed) +def update_correct_per_class(batch_output, batch_y, d): + predicted_class = torch.argmax(batch_output, dim=-1) + for true_label, predicted_label in zip(batch_y, predicted_class): + if true_label == predicted_label: + d[true_label.item()] += 1 + else: + d[true_label.item()] += 0 + + +def update_correct_per_class_topk(batch_output, batch_y, d, k): + topk_labels_pred = torch.argsort(batch_output, dim=-1, descending=True)[:, :k] + for true_label, predicted_labels in zip(batch_y, topk_labels_pred): + d[true_label.item()] += torch.sum(true_label == predicted_labels).item() + + +def update_correct_per_class_averagek(val_probas, val_labels, d, lmbda): + ground_truth_probas = torch.gather(val_probas, dim=1, index=val_labels.unsqueeze(-1)) + for true_label, predicted_label in zip(val_labels, ground_truth_probas): + d[true_label.item()] += (predicted_label >= lmbda).item() + + def count_correct_top_k(scores, labels, k): """Given a tensor of scores of size (n_batch, n_classes) and a tensor of labels of size n_batch, computes the number of correctly predicted exemples