Reformat upstream python with black
parent
5e125e58df
commit
e93c99a890
107
cli.py
107
cli.py
|
@ -36,47 +36,90 @@ def add_all_parsers(parser):
|
|||
|
||||
|
||||
def _add_loss_parser(parser):
|
||||
group_loss = parser.add_argument_group('Loss parameters')
|
||||
group_loss.add_argument('--mu', type=float, default=0.0001, help='weight decay parameter')
|
||||
group_loss = parser.add_argument_group("Loss parameters")
|
||||
group_loss.add_argument(
|
||||
"--mu", type=float, default=0.0001, help="weight decay parameter"
|
||||
)
|
||||
|
||||
|
||||
def _add_training_parser(parser):
|
||||
group_training = parser.add_argument_group('Training parameters')
|
||||
group_training.add_argument('--lr', type=float, help='learning rate to use')
|
||||
group_training.add_argument('--batch_size', type=int, default=32, help='default is 32')
|
||||
group_training.add_argument('--n_epochs', type=int)
|
||||
group_training.add_argument('--pretrained', action='store_true')
|
||||
group_training.add_argument('--image_size', type=int, default=256)
|
||||
group_training.add_argument('--crop_size', type=int, default=224)
|
||||
group_training.add_argument('--epoch_decay', nargs='+', type=int, default=[])
|
||||
group_training.add_argument('--k', nargs='+', help='value of k for computing the topk loss and computing topk accuracy',
|
||||
required=True, type=int)
|
||||
group_training = parser.add_argument_group("Training parameters")
|
||||
group_training.add_argument("--lr", type=float, help="learning rate to use")
|
||||
group_training.add_argument(
|
||||
"--batch_size", type=int, default=32, help="default is 32"
|
||||
)
|
||||
group_training.add_argument("--n_epochs", type=int)
|
||||
group_training.add_argument("--pretrained", action="store_true")
|
||||
group_training.add_argument("--image_size", type=int, default=256)
|
||||
group_training.add_argument("--crop_size", type=int, default=224)
|
||||
group_training.add_argument("--epoch_decay", nargs="+", type=int, default=[])
|
||||
group_training.add_argument(
|
||||
"--k",
|
||||
nargs="+",
|
||||
help="value of k for computing the topk loss and computing topk accuracy",
|
||||
required=True,
|
||||
type=int,
|
||||
)
|
||||
|
||||
|
||||
def _add_model_parser(parser):
|
||||
group_model = parser.add_argument_group('Model parameters')
|
||||
group_model.add_argument('--model', choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
|
||||
'densenet121', 'densenet161', 'densenet169', 'densenet201',
|
||||
'mobilenet_v2', 'inception_v3', 'alexnet', 'squeezenet',
|
||||
'shufflenet', 'wide_resnet50_2', 'wide_resnet101_2',
|
||||
'vgg11', 'mobilenet_v3_large', 'mobilenet_v3_small',
|
||||
'inception_resnet_v2', 'inception_v4', 'efficientnet_b0',
|
||||
'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3',
|
||||
'efficientnet_b4', 'vit_base_patch16_224'],
|
||||
default='resnet50', help='choose the model you want to train on')
|
||||
group_model = parser.add_argument_group("Model parameters")
|
||||
group_model.add_argument(
|
||||
"--model",
|
||||
choices=[
|
||||
"resnet18",
|
||||
"resnet34",
|
||||
"resnet50",
|
||||
"resnet101",
|
||||
"resnet152",
|
||||
"densenet121",
|
||||
"densenet161",
|
||||
"densenet169",
|
||||
"densenet201",
|
||||
"mobilenet_v2",
|
||||
"inception_v3",
|
||||
"alexnet",
|
||||
"squeezenet",
|
||||
"shufflenet",
|
||||
"wide_resnet50_2",
|
||||
"wide_resnet101_2",
|
||||
"vgg11",
|
||||
"mobilenet_v3_large",
|
||||
"mobilenet_v3_small",
|
||||
"inception_resnet_v2",
|
||||
"inception_v4",
|
||||
"efficientnet_b0",
|
||||
"efficientnet_b1",
|
||||
"efficientnet_b2",
|
||||
"efficientnet_b3",
|
||||
"efficientnet_b4",
|
||||
"vit_base_patch16_224",
|
||||
],
|
||||
default="resnet50",
|
||||
help="choose the model you want to train on",
|
||||
)
|
||||
|
||||
|
||||
def _add_hardware_parser(parser):
|
||||
group_hardware = parser.add_argument_group('Hardware parameters')
|
||||
group_hardware.add_argument('--use_gpu', type=int, choices=[0, 1], default=torch.cuda.is_available())
|
||||
group_hardware = parser.add_argument_group("Hardware parameters")
|
||||
group_hardware.add_argument(
|
||||
"--use_gpu", type=int, choices=[0, 1], default=torch.cuda.is_available()
|
||||
)
|
||||
|
||||
|
||||
def _add_misc_parser(parser):
|
||||
group_misc = parser.add_argument_group('Miscellaneous parameters')
|
||||
group_misc.add_argument('--seed', type=int, help='set the seed for reproductible experiments')
|
||||
group_misc.add_argument('--num_workers', type=int, default=4,
|
||||
help='number of workers for the data loader. Default is one. You can bring it up. '
|
||||
'If you have memory errors go back to one')
|
||||
group_misc.add_argument('--root', help='location of the train val and test directories')
|
||||
group_misc.add_argument('--save_name_xp', help='name of the saving file')
|
||||
|
||||
group_misc = parser.add_argument_group("Miscellaneous parameters")
|
||||
group_misc.add_argument(
|
||||
"--seed", type=int, help="set the seed for reproductible experiments"
|
||||
)
|
||||
group_misc.add_argument(
|
||||
"--num_workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="number of workers for the data loader. Default is one. You can bring it up. "
|
||||
"If you have memory errors go back to one",
|
||||
)
|
||||
group_misc.add_argument(
|
||||
"--root", help="location of the train val and test directories"
|
||||
)
|
||||
group_misc.add_argument("--save_name_xp", help="name of the saving file")
|
||||
|
|
196
epoch.py
196
epoch.py
|
@ -26,14 +26,30 @@
|
|||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from utils import count_correct_topk, count_correct_avgk, update_correct_per_class, \
|
||||
update_correct_per_class_topk, update_correct_per_class_avgk
|
||||
from utils import (
|
||||
count_correct_topk,
|
||||
count_correct_avgk,
|
||||
update_correct_per_class,
|
||||
update_correct_per_class_topk,
|
||||
update_correct_per_class_avgk,
|
||||
)
|
||||
|
||||
import torch.nn.functional as F
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def train_epoch(model, optimizer, train_loader, criteria, loss_train, acc_train, topk_acc_train, list_k, n_train, use_gpu):
|
||||
def train_epoch(
|
||||
model,
|
||||
optimizer,
|
||||
train_loader,
|
||||
criteria,
|
||||
loss_train,
|
||||
acc_train,
|
||||
topk_acc_train,
|
||||
list_k,
|
||||
n_train,
|
||||
use_gpu,
|
||||
):
|
||||
"""Single train epoch pass. At the end of the epoch, updates the lists loss_train, acc_train and topk_acc_train"""
|
||||
model.train()
|
||||
# Initialize variables
|
||||
|
@ -43,7 +59,9 @@ def train_epoch(model, optimizer, train_loader, criteria, loss_train, acc_train,
|
|||
n_correct_topk_train = defaultdict(int)
|
||||
topk_acc_epoch_train = {}
|
||||
|
||||
for batch_idx, (batch_x_train, batch_y_train) in enumerate(tqdm(train_loader, desc='train', position=0)):
|
||||
for batch_idx, (batch_x_train, batch_y_train) in enumerate(
|
||||
tqdm(train_loader, desc="train", position=0)
|
||||
):
|
||||
if use_gpu:
|
||||
batch_x_train, batch_y_train = batch_x_train.cuda(), batch_y_train.cuda()
|
||||
optimizer.zero_grad()
|
||||
|
@ -56,9 +74,13 @@ def train_epoch(model, optimizer, train_loader, criteria, loss_train, acc_train,
|
|||
|
||||
# Update variables
|
||||
with torch.no_grad():
|
||||
n_correct_train += torch.sum(torch.eq(batch_y_train, torch.argmax(batch_output_train, dim=-1))).item()
|
||||
n_correct_train += torch.sum(
|
||||
torch.eq(batch_y_train, torch.argmax(batch_output_train, dim=-1))
|
||||
).item()
|
||||
for k in list_k:
|
||||
n_correct_topk_train[k] += count_correct_topk(scores=batch_output_train, labels=batch_y_train, k=k).item()
|
||||
n_correct_topk_train[k] += count_correct_topk(
|
||||
scores=batch_output_train, labels=batch_y_train, k=k
|
||||
).item()
|
||||
|
||||
# At the end of epoch compute average of statistics over batches and store them
|
||||
with torch.no_grad():
|
||||
|
@ -74,14 +96,26 @@ def train_epoch(model, optimizer, train_loader, criteria, loss_train, acc_train,
|
|||
return loss_epoch_train, epoch_accuracy_train, topk_acc_epoch_train
|
||||
|
||||
|
||||
def val_epoch(model, val_loader, criteria, loss_val, acc_val, topk_acc_val, avgk_acc_val,
|
||||
class_acc_val, list_k, dataset_attributes, use_gpu):
|
||||
def val_epoch(
|
||||
model,
|
||||
val_loader,
|
||||
criteria,
|
||||
loss_val,
|
||||
acc_val,
|
||||
topk_acc_val,
|
||||
avgk_acc_val,
|
||||
class_acc_val,
|
||||
list_k,
|
||||
dataset_attributes,
|
||||
use_gpu,
|
||||
):
|
||||
"""Single val epoch pass.
|
||||
At the end of the epoch, updates the lists loss_val, acc_val, topk_acc_val and avgk_acc_val"""
|
||||
At the end of the epoch, updates the lists loss_val, acc_val, topk_acc_val and avgk_acc_val
|
||||
"""
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
n_val = dataset_attributes['n_val']
|
||||
n_val = dataset_attributes["n_val"]
|
||||
# Initialization of variables
|
||||
loss_epoch_val = 0
|
||||
n_correct_val = 0
|
||||
|
@ -91,14 +125,19 @@ def val_epoch(model, val_loader, criteria, loss_val, acc_val, topk_acc_val, avgk
|
|||
lmbda_val = {}
|
||||
# Store class accuracy, and top-k and average-k class accuracy for every k in list_k
|
||||
class_acc_dict = {}
|
||||
class_acc_dict['class_acc'] = defaultdict(int)
|
||||
class_acc_dict['class_topk_acc'], class_acc_dict['class_avgk_acc'] = {}, {}
|
||||
class_acc_dict["class_acc"] = defaultdict(int)
|
||||
class_acc_dict["class_topk_acc"], class_acc_dict["class_avgk_acc"] = {}, {}
|
||||
for k in list_k:
|
||||
class_acc_dict['class_topk_acc'][k], class_acc_dict['class_avgk_acc'][k] = defaultdict(int), defaultdict(int)
|
||||
(
|
||||
class_acc_dict["class_topk_acc"][k],
|
||||
class_acc_dict["class_avgk_acc"][k],
|
||||
) = defaultdict(int), defaultdict(int)
|
||||
# Store estimated probas and labels of the whole validation set to compute lambda
|
||||
list_val_proba = []
|
||||
list_val_labels = []
|
||||
for batch_idx, (batch_x_val, batch_y_val) in enumerate(tqdm(val_loader, desc='val', position=0)):
|
||||
for batch_idx, (batch_x_val, batch_y_val) in enumerate(
|
||||
tqdm(val_loader, desc="val", position=0)
|
||||
):
|
||||
if use_gpu:
|
||||
batch_x_val, batch_y_val = batch_x_val.cuda(), batch_y_val.cuda()
|
||||
batch_output_val = model(batch_x_val)
|
||||
|
@ -110,12 +149,20 @@ def val_epoch(model, val_loader, criteria, loss_val, acc_val, topk_acc_val, avgk
|
|||
loss_batch_val = criteria(batch_output_val, batch_y_val)
|
||||
loss_epoch_val += loss_batch_val.item()
|
||||
|
||||
n_correct_val += torch.sum(torch.eq(batch_y_val, torch.argmax(batch_output_val, dim=-1))).item()
|
||||
update_correct_per_class(batch_proba, batch_y_val, class_acc_dict['class_acc'])
|
||||
n_correct_val += torch.sum(
|
||||
torch.eq(batch_y_val, torch.argmax(batch_output_val, dim=-1))
|
||||
).item()
|
||||
update_correct_per_class(
|
||||
batch_proba, batch_y_val, class_acc_dict["class_acc"]
|
||||
)
|
||||
# Update top-k count and top-k count for each class
|
||||
for k in list_k:
|
||||
n_correct_topk_val[k] += count_correct_topk(scores=batch_output_val, labels=batch_y_val, k=k).item()
|
||||
update_correct_per_class_topk(batch_proba, batch_y_val, class_acc_dict['class_topk_acc'][k], k)
|
||||
n_correct_topk_val[k] += count_correct_topk(
|
||||
scores=batch_output_val, labels=batch_y_val, k=k
|
||||
).item()
|
||||
update_correct_per_class_topk(
|
||||
batch_proba, batch_y_val, class_acc_dict["class_topk_acc"][k], k
|
||||
)
|
||||
|
||||
# Get probas and labels for the entire validation set
|
||||
val_probas = torch.cat(list_val_proba)
|
||||
|
@ -126,9 +173,18 @@ def val_epoch(model, val_loader, criteria, loss_val, acc_val, topk_acc_val, avgk
|
|||
|
||||
for k in list_k:
|
||||
# Computes threshold for every k and count nb of correctly classifier examples in the avg-k sense (globally and for each class)
|
||||
lmbda_val[k] = 0.5 * (sorted_probas[n_val * k - 1] + sorted_probas[n_val * k])
|
||||
n_correct_avgk_val[k] += count_correct_avgk(probas=val_probas, labels=val_labels, lmbda=lmbda_val[k]).item()
|
||||
update_correct_per_class_avgk(val_probas, val_labels, class_acc_dict['class_avgk_acc'][k], lmbda_val[k])
|
||||
lmbda_val[k] = 0.5 * (
|
||||
sorted_probas[n_val * k - 1] + sorted_probas[n_val * k]
|
||||
)
|
||||
n_correct_avgk_val[k] += count_correct_avgk(
|
||||
probas=val_probas, labels=val_labels, lmbda=lmbda_val[k]
|
||||
).item()
|
||||
update_correct_per_class_avgk(
|
||||
val_probas,
|
||||
val_labels,
|
||||
class_acc_dict["class_avgk_acc"][k],
|
||||
lmbda_val[k],
|
||||
)
|
||||
|
||||
# After seeing val set update the statistics over batches and store them
|
||||
loss_epoch_val /= batch_idx
|
||||
|
@ -138,13 +194,19 @@ def val_epoch(model, val_loader, criteria, loss_val, acc_val, topk_acc_val, avgk
|
|||
topk_acc_epoch_val[k] = n_correct_topk_val[k] / n_val
|
||||
avgk_acc_epoch_val[k] = n_correct_avgk_val[k] / n_val
|
||||
# Get class top-k acc and class avg-k acc
|
||||
for class_id in class_acc_dict['class_acc'].keys():
|
||||
n_class_val = dataset_attributes['class2num_instances']['val'][class_id]
|
||||
for class_id in class_acc_dict["class_acc"].keys():
|
||||
n_class_val = dataset_attributes["class2num_instances"]["val"][class_id]
|
||||
|
||||
class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / n_class_val
|
||||
class_acc_dict["class_acc"][class_id] = (
|
||||
class_acc_dict["class_acc"][class_id] / n_class_val
|
||||
)
|
||||
for k in list_k:
|
||||
class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / n_class_val
|
||||
class_acc_dict['class_avgk_acc'][k][class_id] = class_acc_dict['class_avgk_acc'][k][class_id] / n_class_val
|
||||
class_acc_dict["class_topk_acc"][k][class_id] = (
|
||||
class_acc_dict["class_topk_acc"][k][class_id] / n_class_val
|
||||
)
|
||||
class_acc_dict["class_avgk_acc"][k][class_id] = (
|
||||
class_acc_dict["class_avgk_acc"][k][class_id] / n_class_val
|
||||
)
|
||||
|
||||
# Update containers with current epoch values
|
||||
loss_val.append(loss_epoch_val)
|
||||
|
@ -153,27 +215,39 @@ def val_epoch(model, val_loader, criteria, loss_val, acc_val, topk_acc_val, avgk
|
|||
avgk_acc_val.append(avgk_acc_epoch_val)
|
||||
class_acc_val.append(class_acc_dict)
|
||||
|
||||
return loss_epoch_val, epoch_accuracy_val, topk_acc_epoch_val, avgk_acc_epoch_val, lmbda_val
|
||||
return (
|
||||
loss_epoch_val,
|
||||
epoch_accuracy_val,
|
||||
topk_acc_epoch_val,
|
||||
avgk_acc_epoch_val,
|
||||
lmbda_val,
|
||||
)
|
||||
|
||||
|
||||
def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_attributes):
|
||||
|
||||
def test_epoch(
|
||||
model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_attributes
|
||||
):
|
||||
print()
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
n_test = dataset_attributes['n_test']
|
||||
n_test = dataset_attributes["n_test"]
|
||||
loss_epoch_test = 0
|
||||
n_correct_test = 0
|
||||
topk_acc_epoch_test, avgk_acc_epoch_test = {}, {}
|
||||
n_correct_topk_test, n_correct_avgk_test = defaultdict(int), defaultdict(int)
|
||||
|
||||
class_acc_dict = {}
|
||||
class_acc_dict['class_acc'] = defaultdict(int)
|
||||
class_acc_dict['class_topk_acc'], class_acc_dict['class_avgk_acc'] = {}, {}
|
||||
class_acc_dict["class_acc"] = defaultdict(int)
|
||||
class_acc_dict["class_topk_acc"], class_acc_dict["class_avgk_acc"] = {}, {}
|
||||
for k in list_k:
|
||||
class_acc_dict['class_topk_acc'][k], class_acc_dict['class_avgk_acc'][k] = defaultdict(int), defaultdict(int)
|
||||
(
|
||||
class_acc_dict["class_topk_acc"][k],
|
||||
class_acc_dict["class_avgk_acc"][k],
|
||||
) = defaultdict(int), defaultdict(int)
|
||||
|
||||
for batch_idx, (batch_x_test, batch_y_test) in enumerate(tqdm(test_loader, desc='test', position=0)):
|
||||
for batch_idx, (batch_x_test, batch_y_test) in enumerate(
|
||||
tqdm(test_loader, desc="test", position=0)
|
||||
):
|
||||
if use_gpu:
|
||||
batch_x_test, batch_y_test = batch_x_test.cuda(), batch_y_test.cuda()
|
||||
batch_output_test = model(batch_x_test)
|
||||
|
@ -181,13 +255,31 @@ def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_att
|
|||
loss_batch_test = criteria(batch_output_test, batch_y_test)
|
||||
loss_epoch_test += loss_batch_test.item()
|
||||
|
||||
n_correct_test += torch.sum(torch.eq(batch_y_test, torch.argmax(batch_output_test, dim=-1))).item()
|
||||
update_correct_per_class(batch_proba_test, batch_y_test, class_acc_dict['class_acc'])
|
||||
n_correct_test += torch.sum(
|
||||
torch.eq(batch_y_test, torch.argmax(batch_output_test, dim=-1))
|
||||
).item()
|
||||
update_correct_per_class(
|
||||
batch_proba_test, batch_y_test, class_acc_dict["class_acc"]
|
||||
)
|
||||
for k in list_k:
|
||||
n_correct_topk_test[k] += count_correct_topk(scores=batch_output_test, labels=batch_y_test, k=k).item()
|
||||
n_correct_avgk_test[k] += count_correct_avgk(probas=batch_proba_test, labels=batch_y_test, lmbda=lmbda[k]).item()
|
||||
update_correct_per_class_topk(batch_output_test, batch_y_test, class_acc_dict['class_topk_acc'][k], k)
|
||||
update_correct_per_class_avgk(batch_proba_test, batch_y_test, class_acc_dict['class_avgk_acc'][k], lmbda[k])
|
||||
n_correct_topk_test[k] += count_correct_topk(
|
||||
scores=batch_output_test, labels=batch_y_test, k=k
|
||||
).item()
|
||||
n_correct_avgk_test[k] += count_correct_avgk(
|
||||
probas=batch_proba_test, labels=batch_y_test, lmbda=lmbda[k]
|
||||
).item()
|
||||
update_correct_per_class_topk(
|
||||
batch_output_test,
|
||||
batch_y_test,
|
||||
class_acc_dict["class_topk_acc"][k],
|
||||
k,
|
||||
)
|
||||
update_correct_per_class_avgk(
|
||||
batch_proba_test,
|
||||
batch_y_test,
|
||||
class_acc_dict["class_avgk_acc"][k],
|
||||
lmbda[k],
|
||||
)
|
||||
|
||||
# After seeing test set update the statistics over batches and store them
|
||||
loss_epoch_test /= batch_idx
|
||||
|
@ -196,11 +288,23 @@ def test_epoch(model, test_loader, criteria, list_k, lmbda, use_gpu, dataset_att
|
|||
topk_acc_epoch_test[k] = n_correct_topk_test[k] / n_test
|
||||
avgk_acc_epoch_test[k] = n_correct_avgk_test[k] / n_test
|
||||
|
||||
for class_id in class_acc_dict['class_acc'].keys():
|
||||
n_class_test = dataset_attributes['class2num_instances']['test'][class_id]
|
||||
class_acc_dict['class_acc'][class_id] = class_acc_dict['class_acc'][class_id] / n_class_test
|
||||
for class_id in class_acc_dict["class_acc"].keys():
|
||||
n_class_test = dataset_attributes["class2num_instances"]["test"][class_id]
|
||||
class_acc_dict["class_acc"][class_id] = (
|
||||
class_acc_dict["class_acc"][class_id] / n_class_test
|
||||
)
|
||||
for k in list_k:
|
||||
class_acc_dict['class_topk_acc'][k][class_id] = class_acc_dict['class_topk_acc'][k][class_id] / n_class_test
|
||||
class_acc_dict['class_avgk_acc'][k][class_id] = class_acc_dict['class_avgk_acc'][k][class_id] / n_class_test
|
||||
class_acc_dict["class_topk_acc"][k][class_id] = (
|
||||
class_acc_dict["class_topk_acc"][k][class_id] / n_class_test
|
||||
)
|
||||
class_acc_dict["class_avgk_acc"][k][class_id] = (
|
||||
class_acc_dict["class_avgk_acc"][k][class_id] / n_class_test
|
||||
)
|
||||
|
||||
return loss_epoch_test, acc_epoch_test, topk_acc_epoch_test, avgk_acc_epoch_test, class_acc_dict
|
||||
return (
|
||||
loss_epoch_test,
|
||||
acc_epoch_test,
|
||||
topk_acc_epoch_test,
|
||||
avgk_acc_epoch_test,
|
||||
class_acc_dict,
|
||||
)
|
||||
|
|
155
main.py
155
main.py
|
@ -40,90 +40,159 @@ from cli import add_all_parsers
|
|||
|
||||
def train(args):
|
||||
set_seed(args, use_gpu=torch.cuda.is_available())
|
||||
train_loader, val_loader, test_loader, dataset_attributes = get_data(args.root, args.image_size, args.crop_size,
|
||||
args.batch_size, args.num_workers, args.pretrained)
|
||||
train_loader, val_loader, test_loader, dataset_attributes = get_data(
|
||||
args.root,
|
||||
args.image_size,
|
||||
args.crop_size,
|
||||
args.batch_size,
|
||||
args.num_workers,
|
||||
args.pretrained,
|
||||
)
|
||||
|
||||
model = get_model(args, n_classes=dataset_attributes['n_classes'])
|
||||
model = get_model(args, n_classes=dataset_attributes["n_classes"])
|
||||
criteria = CrossEntropyLoss()
|
||||
|
||||
if args.use_gpu:
|
||||
print('USING GPU')
|
||||
print("USING GPU")
|
||||
torch.cuda.set_device(0)
|
||||
model.cuda()
|
||||
criteria.cuda()
|
||||
|
||||
optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.mu, nesterov=True)
|
||||
optimizer = SGD(
|
||||
model.parameters(),
|
||||
lr=args.lr,
|
||||
momentum=0.9,
|
||||
weight_decay=args.mu,
|
||||
nesterov=True,
|
||||
)
|
||||
|
||||
# Containers for storing metrics over epochs
|
||||
loss_train, acc_train, topk_acc_train = [], [], []
|
||||
loss_val, acc_val, topk_acc_val, avgk_acc_val, class_acc_val = [], [], [], [], []
|
||||
|
||||
save_name = args.save_name_xp.strip()
|
||||
save_dir = os.path.join(os.getcwd(), 'results', save_name)
|
||||
save_dir = os.path.join(os.getcwd(), "results", save_name)
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
|
||||
print('args.k : ', args.k)
|
||||
print("args.k : ", args.k)
|
||||
|
||||
lmbda_best_acc = None
|
||||
best_val_acc = float('-inf')
|
||||
best_val_acc = float("-inf")
|
||||
|
||||
for epoch in tqdm(range(args.n_epochs), desc='epoch', position=0):
|
||||
for epoch in tqdm(range(args.n_epochs), desc="epoch", position=0):
|
||||
t = time.time()
|
||||
optimizer = update_optimizer(optimizer, lr_schedule=args.epoch_decay, epoch=epoch)
|
||||
optimizer = update_optimizer(
|
||||
optimizer, lr_schedule=args.epoch_decay, epoch=epoch
|
||||
)
|
||||
|
||||
loss_epoch_train, acc_epoch_train, topk_acc_epoch_train = train_epoch(model, optimizer, train_loader,
|
||||
criteria, loss_train, acc_train,
|
||||
topk_acc_train, args.k,
|
||||
dataset_attributes['n_train'],
|
||||
args.use_gpu)
|
||||
loss_epoch_train, acc_epoch_train, topk_acc_epoch_train = train_epoch(
|
||||
model,
|
||||
optimizer,
|
||||
train_loader,
|
||||
criteria,
|
||||
loss_train,
|
||||
acc_train,
|
||||
topk_acc_train,
|
||||
args.k,
|
||||
dataset_attributes["n_train"],
|
||||
args.use_gpu,
|
||||
)
|
||||
|
||||
loss_epoch_val, acc_epoch_val, topk_acc_epoch_val, \
|
||||
avgk_acc_epoch_val, lmbda_val = val_epoch(model, val_loader, criteria,
|
||||
loss_val, acc_val, topk_acc_val, avgk_acc_val,
|
||||
class_acc_val, args.k, dataset_attributes, args.use_gpu)
|
||||
(
|
||||
loss_epoch_val,
|
||||
acc_epoch_val,
|
||||
topk_acc_epoch_val,
|
||||
avgk_acc_epoch_val,
|
||||
lmbda_val,
|
||||
) = val_epoch(
|
||||
model,
|
||||
val_loader,
|
||||
criteria,
|
||||
loss_val,
|
||||
acc_val,
|
||||
topk_acc_val,
|
||||
avgk_acc_val,
|
||||
class_acc_val,
|
||||
args.k,
|
||||
dataset_attributes,
|
||||
args.use_gpu,
|
||||
)
|
||||
|
||||
# save model at every epoch
|
||||
save(model, optimizer, epoch, os.path.join(save_dir, save_name + '_weights.tar'))
|
||||
save(
|
||||
model, optimizer, epoch, os.path.join(save_dir, save_name + "_weights.tar")
|
||||
)
|
||||
|
||||
# save model with best val accuracy
|
||||
if acc_epoch_val > best_val_acc:
|
||||
best_val_acc = acc_epoch_val
|
||||
lmbda_best_acc = lmbda_val
|
||||
save(model, optimizer, epoch, os.path.join(save_dir, save_name + '_weights_best_acc.tar'))
|
||||
save(
|
||||
model,
|
||||
optimizer,
|
||||
epoch,
|
||||
os.path.join(save_dir, save_name + "_weights_best_acc.tar"),
|
||||
)
|
||||
|
||||
print()
|
||||
print(f'epoch {epoch} took {time.time()-t:.2f}')
|
||||
print(f'loss_train : {loss_epoch_train}')
|
||||
print(f'loss_val : {loss_epoch_val}')
|
||||
print(f'acc_train : {acc_epoch_train} / topk_acc_train : {topk_acc_epoch_train}')
|
||||
print(f'acc_val : {acc_epoch_val} / topk_acc_val : {topk_acc_epoch_val} / '
|
||||
f'avgk_acc_val : {avgk_acc_epoch_val}')
|
||||
print(f"epoch {epoch} took {time.time()-t:.2f}")
|
||||
print(f"loss_train : {loss_epoch_train}")
|
||||
print(f"loss_val : {loss_epoch_val}")
|
||||
print(
|
||||
f"acc_train : {acc_epoch_train} / topk_acc_train : {topk_acc_epoch_train}"
|
||||
)
|
||||
print(
|
||||
f"acc_val : {acc_epoch_val} / topk_acc_val : {topk_acc_epoch_val} / "
|
||||
f"avgk_acc_val : {avgk_acc_epoch_val}"
|
||||
)
|
||||
|
||||
# load weights corresponding to best val accuracy and evaluate on test
|
||||
load_model(model, os.path.join(save_dir, save_name + '_weights_best_acc.tar'), args.use_gpu)
|
||||
loss_test_ba, acc_test_ba, topk_acc_test_ba, \
|
||||
avgk_acc_test_ba, class_acc_test = test_epoch(model, test_loader, criteria, args.k,
|
||||
lmbda_best_acc, args.use_gpu,
|
||||
dataset_attributes)
|
||||
load_model(
|
||||
model, os.path.join(save_dir, save_name + "_weights_best_acc.tar"), args.use_gpu
|
||||
)
|
||||
(
|
||||
loss_test_ba,
|
||||
acc_test_ba,
|
||||
topk_acc_test_ba,
|
||||
avgk_acc_test_ba,
|
||||
class_acc_test,
|
||||
) = test_epoch(
|
||||
model,
|
||||
test_loader,
|
||||
criteria,
|
||||
args.k,
|
||||
lmbda_best_acc,
|
||||
args.use_gpu,
|
||||
dataset_attributes,
|
||||
)
|
||||
|
||||
# Save the results as a dictionary and save it as a pickle file in desired location
|
||||
|
||||
results = {'loss_train': loss_train, 'acc_train': acc_train, 'topk_acc_train': topk_acc_train,
|
||||
'loss_val': loss_val, 'acc_val': acc_val, 'topk_acc_val': topk_acc_val, 'class_acc_val': class_acc_val,
|
||||
'avgk_acc_val': avgk_acc_val,
|
||||
'test_results': {'loss': loss_test_ba,
|
||||
'accuracy': acc_test_ba,
|
||||
'topk_accuracy': topk_acc_test_ba,
|
||||
'avgk_accuracy': avgk_acc_test_ba,
|
||||
'class_acc_dict': class_acc_test},
|
||||
'params': args.__dict__}
|
||||
results = {
|
||||
"loss_train": loss_train,
|
||||
"acc_train": acc_train,
|
||||
"topk_acc_train": topk_acc_train,
|
||||
"loss_val": loss_val,
|
||||
"acc_val": acc_val,
|
||||
"topk_acc_val": topk_acc_val,
|
||||
"class_acc_val": class_acc_val,
|
||||
"avgk_acc_val": avgk_acc_val,
|
||||
"test_results": {
|
||||
"loss": loss_test_ba,
|
||||
"accuracy": acc_test_ba,
|
||||
"topk_accuracy": topk_acc_test_ba,
|
||||
"avgk_accuracy": avgk_acc_test_ba,
|
||||
"class_acc_dict": class_acc_test,
|
||||
},
|
||||
"params": args.__dict__,
|
||||
}
|
||||
|
||||
with open(os.path.join(save_dir, save_name + '.pkl'), 'wb') as f:
|
||||
with open(os.path.join(save_dir, save_name + ".pkl"), "wb") as f:
|
||||
pickle.dump(results, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
add_all_parsers(parser)
|
||||
args = parser.parse_args()
|
||||
|
|
220
utils.py
220
utils.py
|
@ -33,9 +33,27 @@ import os
|
|||
from collections import Counter
|
||||
|
||||
|
||||
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152, inception_v3, mobilenet_v2, densenet121, \
|
||||
densenet161, densenet169, densenet201, alexnet, squeezenet1_0, shufflenet_v2_x1_0, wide_resnet50_2, wide_resnet101_2,\
|
||||
vgg11, mobilenet_v3_large, mobilenet_v3_small
|
||||
from torchvision.models import (
|
||||
resnet18,
|
||||
resnet34,
|
||||
resnet50,
|
||||
resnet101,
|
||||
resnet152,
|
||||
inception_v3,
|
||||
mobilenet_v2,
|
||||
densenet121,
|
||||
densenet161,
|
||||
densenet169,
|
||||
densenet201,
|
||||
alexnet,
|
||||
squeezenet1_0,
|
||||
shufflenet_v2_x1_0,
|
||||
wide_resnet50_2,
|
||||
wide_resnet101_2,
|
||||
vgg11,
|
||||
mobilenet_v3_large,
|
||||
mobilenet_v3_small,
|
||||
)
|
||||
from torchvision.datasets import ImageFolder
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
|
@ -44,7 +62,7 @@ from torchvision.transforms import CenterCrop
|
|||
|
||||
def set_seed(args, use_gpu, print_out=True):
|
||||
if print_out:
|
||||
print('Seed:\t {}'.format(args.seed))
|
||||
print("Seed:\t {}".format(args.seed))
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
@ -68,7 +86,9 @@ def update_correct_per_class_topk(batch_output, batch_y, d, k):
|
|||
|
||||
|
||||
def update_correct_per_class_avgk(val_probas, val_labels, d, lmbda):
|
||||
ground_truth_probas = torch.gather(val_probas, dim=1, index=val_labels.unsqueeze(-1))
|
||||
ground_truth_probas = torch.gather(
|
||||
val_probas, dim=1, index=val_labels.unsqueeze(-1)
|
||||
)
|
||||
for true_label, predicted_label in zip(val_labels, ground_truth_probas):
|
||||
d[true_label.item()] += (predicted_label >= lmbda).item()
|
||||
|
||||
|
@ -97,19 +117,19 @@ def load_model(model, filename, use_gpu):
|
|||
if not os.path.exists(filename):
|
||||
raise FileNotFoundError
|
||||
|
||||
device = 'cuda:0' if use_gpu else 'cpu'
|
||||
device = "cuda:0" if use_gpu else "cpu"
|
||||
d = torch.load(filename, map_location=device)
|
||||
model.load_state_dict(d['model'])
|
||||
return d['epoch']
|
||||
model.load_state_dict(d["model"])
|
||||
return d["epoch"]
|
||||
|
||||
|
||||
def load_optimizer(optimizer, filename, use_gpu):
|
||||
if not os.path.exists(filename):
|
||||
raise FileNotFoundError
|
||||
|
||||
device = 'cuda:0' if use_gpu else 'cpu'
|
||||
device = "cuda:0" if use_gpu else "cpu"
|
||||
d = torch.load(filename, map_location=device)
|
||||
optimizer.load_state_dict(d['optimizer'])
|
||||
optimizer.load_state_dict(d["optimizer"])
|
||||
|
||||
|
||||
def save(model, optimizer, epoch, location):
|
||||
|
@ -117,16 +137,18 @@ def save(model, optimizer, epoch, location):
|
|||
if not os.path.exists(dir):
|
||||
os.makedirs(dir)
|
||||
|
||||
d = {'epoch': epoch,
|
||||
'model': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict()}
|
||||
d = {
|
||||
"epoch": epoch,
|
||||
"model": model.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
}
|
||||
torch.save(d, location)
|
||||
|
||||
|
||||
def decay_lr(optimizer):
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] *= 0.1
|
||||
print('Switching lr to {}'.format(optimizer.param_groups[0]['lr']))
|
||||
param_group["lr"] *= 0.1
|
||||
print("Switching lr to {}".format(optimizer.param_groups[0]["lr"]))
|
||||
return optimizer
|
||||
|
||||
|
||||
|
@ -137,55 +159,90 @@ def update_optimizer(optimizer, lr_schedule, epoch):
|
|||
|
||||
|
||||
def get_model(args, n_classes):
|
||||
pytorch_models = {'resnet18': resnet18, 'resnet34': resnet34, 'resnet50': resnet50, 'resnet101': resnet101,
|
||||
'resnet152': resnet152, 'densenet121': densenet121, 'densenet161': densenet161,
|
||||
'densenet169': densenet169, 'densenet201': densenet201, 'mobilenet_v2': mobilenet_v2,
|
||||
'inception_v3': inception_v3, 'alexnet': alexnet, 'squeezenet': squeezenet1_0,
|
||||
'shufflenet': shufflenet_v2_x1_0, 'wide_resnet50_2': wide_resnet50_2,
|
||||
'wide_resnet101_2': wide_resnet101_2, 'vgg11': vgg11, 'mobilenet_v3_large': mobilenet_v3_large,
|
||||
'mobilenet_v3_small': mobilenet_v3_small
|
||||
pytorch_models = {
|
||||
"resnet18": resnet18,
|
||||
"resnet34": resnet34,
|
||||
"resnet50": resnet50,
|
||||
"resnet101": resnet101,
|
||||
"resnet152": resnet152,
|
||||
"densenet121": densenet121,
|
||||
"densenet161": densenet161,
|
||||
"densenet169": densenet169,
|
||||
"densenet201": densenet201,
|
||||
"mobilenet_v2": mobilenet_v2,
|
||||
"inception_v3": inception_v3,
|
||||
"alexnet": alexnet,
|
||||
"squeezenet": squeezenet1_0,
|
||||
"shufflenet": shufflenet_v2_x1_0,
|
||||
"wide_resnet50_2": wide_resnet50_2,
|
||||
"wide_resnet101_2": wide_resnet101_2,
|
||||
"vgg11": vgg11,
|
||||
"mobilenet_v3_large": mobilenet_v3_large,
|
||||
"mobilenet_v3_small": mobilenet_v3_small,
|
||||
}
|
||||
timm_models = {
|
||||
"inception_resnet_v2",
|
||||
"inception_v4",
|
||||
"efficientnet_b0",
|
||||
"efficientnet_b1",
|
||||
"efficientnet_b2",
|
||||
"efficientnet_b3",
|
||||
"efficientnet_b4",
|
||||
"vit_base_patch16_224",
|
||||
}
|
||||
timm_models = {'inception_resnet_v2', 'inception_v4', 'efficientnet_b0', 'efficientnet_b1',
|
||||
'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'vit_base_patch16_224'}
|
||||
|
||||
if args.model in pytorch_models.keys() and not args.pretrained:
|
||||
if args.model == 'inception_v3':
|
||||
model = pytorch_models[args.model](pretrained=False, num_classes=n_classes, aux_logits=False)
|
||||
if args.model == "inception_v3":
|
||||
model = pytorch_models[args.model](
|
||||
pretrained=False, num_classes=n_classes, aux_logits=False
|
||||
)
|
||||
else:
|
||||
model = pytorch_models[args.model](pretrained=False, num_classes=n_classes)
|
||||
elif args.model in pytorch_models.keys() and args.pretrained:
|
||||
if args.model in {'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'wide_resnet50_2',
|
||||
'wide_resnet101_2', 'shufflenet'}:
|
||||
if args.model in {
|
||||
"resnet18",
|
||||
"resnet34",
|
||||
"resnet50",
|
||||
"resnet101",
|
||||
"resnet152",
|
||||
"wide_resnet50_2",
|
||||
"wide_resnet101_2",
|
||||
"shufflenet",
|
||||
}:
|
||||
model = pytorch_models[args.model](pretrained=True)
|
||||
num_ftrs = model.fc.in_features
|
||||
model.fc = nn.Linear(num_ftrs, n_classes)
|
||||
elif args.model in {'alexnet', 'vgg11'}:
|
||||
elif args.model in {"alexnet", "vgg11"}:
|
||||
model = pytorch_models[args.model](pretrained=True)
|
||||
num_ftrs = model.classifier[6].in_features
|
||||
model.classifier[6] = nn.Linear(num_ftrs, n_classes)
|
||||
elif args.model in {'densenet121', 'densenet161', 'densenet169', 'densenet201'}:
|
||||
elif args.model in {"densenet121", "densenet161", "densenet169", "densenet201"}:
|
||||
model = pytorch_models[args.model](pretrained=True)
|
||||
num_ftrs = model.classifier.in_features
|
||||
model.classifier = nn.Linear(num_ftrs, n_classes)
|
||||
elif args.model == 'mobilenet_v2':
|
||||
elif args.model == "mobilenet_v2":
|
||||
model = pytorch_models[args.model](pretrained=True)
|
||||
num_ftrs = model.classifier[1].in_features
|
||||
model.classifier[1] = nn.Linear(num_ftrs, n_classes)
|
||||
elif args.model == 'inception_v3':
|
||||
elif args.model == "inception_v3":
|
||||
model = inception_v3(pretrained=True, aux_logits=False)
|
||||
num_ftrs = model.fc.in_features
|
||||
model.fc = nn.Linear(num_ftrs, n_classes)
|
||||
elif args.model == 'squeezenet':
|
||||
elif args.model == "squeezenet":
|
||||
model = pytorch_models[args.model](pretrained=True)
|
||||
model.classifier[1] = nn.Conv2d(512, n_classes, kernel_size=(1, 1), stride=(1, 1))
|
||||
model.classifier[1] = nn.Conv2d(
|
||||
512, n_classes, kernel_size=(1, 1), stride=(1, 1)
|
||||
)
|
||||
model.num_classes = n_classes
|
||||
elif args.model == 'mobilenet_v3_large' or args.model == 'mobilenet_v3_small':
|
||||
elif args.model == "mobilenet_v3_large" or args.model == "mobilenet_v3_small":
|
||||
model = pytorch_models[args.model](pretrained=True)
|
||||
num_ftrs = model.classifier[-1].in_features
|
||||
model.classifier[-1] = nn.Linear(num_ftrs, n_classes)
|
||||
|
||||
elif args.model in timm_models:
|
||||
model = timm.create_model(args.model, pretrained=args.pretrained, num_classes=n_classes)
|
||||
model = timm.create_model(
|
||||
args.model, pretrained=args.pretrained, num_classes=n_classes
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -204,44 +261,81 @@ class Plantnet(ImageFolder):
|
|||
|
||||
|
||||
def get_data(root, image_size, crop_size, batch_size, num_workers, pretrained):
|
||||
|
||||
if pretrained:
|
||||
transform_train = transforms.Compose([transforms.Resize(size=image_size), transforms.RandomCrop(size=crop_size),
|
||||
transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])])
|
||||
transform_test = transforms.Compose([transforms.Resize(size=image_size), transforms.CenterCrop(size=crop_size),
|
||||
transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])])
|
||||
transform_train = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size=image_size),
|
||||
transforms.RandomCrop(size=crop_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
),
|
||||
]
|
||||
)
|
||||
transform_test = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size=image_size),
|
||||
transforms.CenterCrop(size=crop_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
),
|
||||
]
|
||||
)
|
||||
else:
|
||||
transform_train = transforms.Compose([transforms.Resize(size=image_size), transforms.RandomCrop(size=crop_size),
|
||||
transforms.ToTensor(), transforms.Normalize(mean=[0.4425, 0.4695, 0.3266],
|
||||
std=[0.2353, 0.2219, 0.2325])])
|
||||
transform_test = transforms.Compose([transforms.Resize(size=image_size), transforms.CenterCrop(size=crop_size),
|
||||
transforms.ToTensor(), transforms.Normalize(mean=[0.4425, 0.4695, 0.3266],
|
||||
std=[0.2353, 0.2219, 0.2325])])
|
||||
transform_train = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size=image_size),
|
||||
transforms.RandomCrop(size=crop_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.4425, 0.4695, 0.3266], std=[0.2353, 0.2219, 0.2325]
|
||||
),
|
||||
]
|
||||
)
|
||||
transform_test = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size=image_size),
|
||||
transforms.CenterCrop(size=crop_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.4425, 0.4695, 0.3266], std=[0.2353, 0.2219, 0.2325]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
trainset = Plantnet(root, 'train', transform=transform_train)
|
||||
trainset = Plantnet(root, "train", transform=transform_train)
|
||||
train_class_to_num_instances = Counter(trainset.targets)
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
|
||||
shuffle=True, num_workers=num_workers)
|
||||
trainloader = torch.utils.data.DataLoader(
|
||||
trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
|
||||
)
|
||||
|
||||
valset = Plantnet(root, 'val', transform=transform_test)
|
||||
valset = Plantnet(root, "val", transform=transform_test)
|
||||
|
||||
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
|
||||
shuffle=True, num_workers=num_workers)
|
||||
valloader = torch.utils.data.DataLoader(
|
||||
valset, batch_size=batch_size, shuffle=True, num_workers=num_workers
|
||||
)
|
||||
|
||||
testset = Plantnet(root, 'test', transform=transform_test)
|
||||
testset = Plantnet(root, "test", transform=transform_test)
|
||||
test_class_to_num_instances = Counter(testset.targets)
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
|
||||
shuffle=False, num_workers=num_workers)
|
||||
testloader = torch.utils.data.DataLoader(
|
||||
testset, batch_size=batch_size, shuffle=False, num_workers=num_workers
|
||||
)
|
||||
|
||||
val_class_to_num_instances = Counter(valset.targets)
|
||||
n_classes = len(trainset.classes)
|
||||
|
||||
dataset_attributes = {'n_train': len(trainset), 'n_val': len(valset), 'n_test': len(testset), 'n_classes': n_classes,
|
||||
'class2num_instances': {'train': train_class_to_num_instances,
|
||||
'val': val_class_to_num_instances,
|
||||
'test': test_class_to_num_instances},
|
||||
'class_to_idx': trainset.class_to_idx}
|
||||
dataset_attributes = {
|
||||
"n_train": len(trainset),
|
||||
"n_val": len(valset),
|
||||
"n_test": len(testset),
|
||||
"n_classes": n_classes,
|
||||
"class2num_instances": {
|
||||
"train": train_class_to_num_instances,
|
||||
"val": val_class_to_num_instances,
|
||||
"test": test_class_to_num_instances,
|
||||
},
|
||||
"class_to_idx": trainset.class_to_idx,
|
||||
}
|
||||
|
||||
return trainloader, valloader, testloader, dataset_attributes
|
||||
|
|
Loading…
Reference in New Issue