1
0
Fork 0

adding code for all models

funding
camille garcin 2022-08-02 11:53:04 +02:00
parent 70ce402fb9
commit 224544b12c
4 changed files with 97 additions and 38 deletions

24
cli.py
View File

@@ -7,26 +7,36 @@ def add_all_parsers(parser):
_add_model_parser(parser)
_add_hardware_parser(parser)
_add_misc_parser(parser)
_add_dataset_parser(parser)
def _add_loss_parser(parser):
group_loss = parser.add_argument_group('Loss parameters')
group_loss.add_argument('--mu', type=float, default=0., help='weight decay parameter')
group_loss.add_argument('--mu', type=float, default=0.0001, help='weight decay parameter')
def _add_training_parser(parser):
group_training = parser.add_argument_group('Training parameters')
group_training.add_argument('--lr', type=float, help='learning rate to use')
group_training.add_argument('--batch_size', type=int, default=256, help='default is 256')
group_training.add_argument('--batch_size', type=int, default=32, help='default is 32')
group_training.add_argument('--n_epochs', type=int)
group_training.add_argument('--pretrained', action='store_true')
group_training.add_argument('--image_size', type=int, default=256)
group_training.add_argument('--crop_size', type=int, default=224)
group_training.add_argument('--epoch_decay', nargs='+', type=int, default=[])
group_training.add_argument('--k', nargs='+', help='value of k for computing the topk loss and computing topk accuracy',
required=True, type=int)
def _add_model_parser(parser):
group_model = parser.add_argument_group('Model parameters')
group_model.add_argument('--model', choices=['resnet50', 'densenet121', 'densenet169', 'mobilenet_v2', 'inception_resnetv2'],
group_model.add_argument('--model', choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'densenet121', 'densenet161', 'densenet169', 'densenet201',
'mobilenet_v2', 'inception_v3', 'alexnet', 'squeezenet',
'shufflenet', 'wide_resnet50_2', 'wide_resnet101_2',
'vgg11', 'mobilenet_v3_large', 'mobilenet_v3_small',
'inception_resnet_v2', 'inception_v4', 'efficientnet_b0',
'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3',
'efficientnet_b4', 'vit_base_patch16_224'],
default='resnet50', help='choose the model you want to train on')
@@ -35,12 +45,6 @@ def _add_hardware_parser(parser):
group_hardware.add_argument('--use_gpu', type=int, choices=[0, 1], default=torch.cuda.is_available())
def _add_dataset_parser(parser):
group_dataset = parser.add_argument_group('Dataset parameters')
group_dataset.add_argument('--size_image', type=int, default=256,
help='size you want to resize the images to')
def _add_misc_parser(parser):
group_misc = parser.add_argument_group('Miscellaneous parameters')
group_misc.add_argument('--seed', type=int, help='set the seed for reproductible experiments')

View File

@@ -14,12 +14,14 @@ from cli import add_all_parsers
def train(args):
set_seed(args, use_gpu=torch.cuda.is_available())
train_loader, val_loader, test_loader, dataset_attributes = get_data(args)
train_loader, val_loader, test_loader, dataset_attributes = get_data(args.root, args.image_size, args.crop_size,
args.batch_size, args.num_workers, args.pretrained)
model = get_model(args, n_classes=dataset_attributes['n_classes'])
criteria = CrossEntropyLoss()
if args.use_gpu:
print('USING GPU')
torch.cuda.set_device(0)
model.cuda()
criteria.cuda()
@@ -42,7 +44,7 @@ def train(args):
for epoch in tqdm(range(args.n_epochs), desc='epoch', position=0):
t = time.time()
optimizer = update_optimizer(optimizer, lr_schedule=dataset_attributes['lr_schedule'], epoch=epoch)
optimizer = update_optimizer(optimizer, lr_schedule=args.epoch_decay, epoch=epoch)
loss_epoch_train, acc_epoch_train, topk_acc_epoch_train = train_epoch(model, optimizer, train_loader,
criteria, loss_train, acc_train,

View File

@@ -4,8 +4,8 @@ channels:
dependencies:
- tqdm
- pytorch=1.4.0
- torchvision=0.5.0
- pytorch
- torchvision
- pip:
- timm

101
utils.py
View File

@@ -1,4 +1,5 @@
import torch
import torch.nn as nn
import random
import timm
import numpy as np
@@ -6,7 +7,9 @@ import os
from collections import Counter
from torchvision.models import resnet50, densenet121, densenet169, mobilenet_v2
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152, inception_v3, mobilenet_v2, densenet121, \
densenet161, densenet169, densenet201, alexnet, squeezenet1_0, shufflenet_v2_x1_0, wide_resnet50_2, wide_resnet101_2,\
vgg11, mobilenet_v3_large, mobilenet_v3_small
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
@@ -108,12 +111,57 @@ def update_optimizer(optimizer, lr_schedule, epoch):
def get_model(args, n_classes):
model_dict = {'resnet50': resnet50, 'densenet121': densenet121, 'densenet169': densenet169, 'mobilenet_v2': mobilenet_v2}
pytorch_models = {'resnet18': resnet18, 'resnet34': resnet34, 'resnet50': resnet50, 'resnet101': resnet101,
'resnet152': resnet152, 'densenet121': densenet121, 'densenet161': densenet161,
'densenet169': densenet169, 'densenet201': densenet201, 'mobilenet_v2': mobilenet_v2,
'inception_v3': inception_v3, 'alexnet': alexnet, 'squeezenet': squeezenet1_0,
'shufflenet': shufflenet_v2_x1_0, 'wide_resnet50_2': wide_resnet50_2,
'wide_resnet101_2': wide_resnet101_2, 'vgg11': vgg11, 'mobilenet_v3_large': mobilenet_v3_large,
'mobilenet_v3_small': mobilenet_v3_small
}
timm_models = {'inception_resnet_v2', 'inception_v4', 'efficientnet_b0', 'efficientnet_b1',
'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'vit_base_patch16_224'}
if args.model == 'inception_resnetv2':
model = timm.create_model('inception_resnet_v2', pretrained=False, num_classes=n_classes)
if args.model in pytorch_models.keys() and not args.pretrained:
if args.model == 'inception_v3':
model = pytorch_models[args.model](pretrained=False, num_classes=n_classes, aux_logits=False)
else:
model = pytorch_models[args.model](pretrained=False, num_classes=n_classes)
elif args.model in pytorch_models.keys() and args.pretrained:
if args.model in {'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'wide_resnet50_2',
'wide_resnet101_2', 'shufflenet'}:
model = pytorch_models[args.model](pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, n_classes)
elif args.model in {'alexnet', 'vgg11'}:
model = pytorch_models[args.model](pretrained=True)
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, n_classes)
elif args.model in {'densenet121', 'densenet161', 'densenet169', 'densenet201'}:
model = pytorch_models[args.model](pretrained=True)
num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, n_classes)
elif args.model == 'mobilenet_v2':
model = pytorch_models[args.model](pretrained=True)
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, n_classes)
elif args.model == 'inception_v3':
model = inception_v3(pretrained=True, aux_logits=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, n_classes)
elif args.model == 'squeezenet':
model = pytorch_models[args.model](pretrained=True)
model.classifier[1] = nn.Conv2d(512, n_classes, kernel_size=(1, 1), stride=(1, 1))
model.num_classes = n_classes
elif args.model == 'mobilenet_v3_large' or args.model == 'mobilenet_v3_small':
model = pytorch_models[args.model](pretrained=True)
num_ftrs = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(num_ftrs, n_classes)
elif args.model in timm_models:
model = timm.create_model(args.model, pretrained=args.pretrained, num_classes=n_classes)
else:
model = model_dict[args.model](pretrained=False, num_classes=n_classes)
raise NotImplementedError
return model
@@ -129,37 +177,42 @@ class Plantnet(ImageFolder):
return os.path.join(self.root, self.split)
class MaxCenterCrop:
def __call__(self, sample):
min_size = min(sample.size[0], sample.size[1])
return CenterCrop(min_size)(sample)
def get_data(root, image_size, crop_size, batch_size, num_workers, pretrained):
if pretrained:
transform_train = transforms.Compose([transforms.Resize(size=image_size), transforms.RandomCrop(size=crop_size),
transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
transform_test = transforms.Compose([transforms.Resize(size=image_size), transforms.CenterCrop(size=crop_size),
transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
else:
transform_train = transforms.Compose([transforms.Resize(size=image_size), transforms.RandomCrop(size=crop_size),
transforms.ToTensor(), transforms.Normalize(mean=[0.4425, 0.4695, 0.3266],
std=[0.2353, 0.2219, 0.2325])])
transform_test = transforms.Compose([transforms.Resize(size=image_size), transforms.CenterCrop(size=crop_size),
transforms.ToTensor(), transforms.Normalize(mean=[0.4425, 0.4695, 0.3266],
std=[0.2353, 0.2219, 0.2325])])
def get_data(args):
transform = transforms.Compose(
[MaxCenterCrop(), transforms.Resize(args.size_image), transforms.ToTensor()])
trainset = Plantnet(args.root, 'images_train', transform=transform)
trainset = Plantnet(root, 'train', transform=transform_train)
train_class_to_num_instances = Counter(trainset.targets)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
shuffle=True, num_workers=args.num_workers)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=num_workers)
valset = Plantnet(args.root, 'images_val', transform=transform)
valset = Plantnet(root, 'val', transform=transform_test)
valloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size,
shuffle=True, num_workers=args.num_workers)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
shuffle=True, num_workers=num_workers)
testset = Plantnet(args.root, 'images_test', transform=transform)
testset = Plantnet(root, 'test', transform=transform_test)
test_class_to_num_instances = Counter(testset.targets)
testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
shuffle=False, num_workers=args.num_workers)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
shuffle=False, num_workers=num_workers)
val_class_to_num_instances = Counter(valset.targets)
n_classes = len(trainset.classes)
dataset_attributes = {'n_train': len(trainset), 'n_val': len(valset), 'n_test': len(testset), 'n_classes': n_classes,
'lr_schedule': [40, 50, 60],
'class2num_instances': {'train': train_class_to_num_instances,
'val': val_class_to_num_instances,
'test': test_class_to_num_instances},