adding code for all models
parent
70ce402fb9
commit
224544b12c
24
cli.py
24
cli.py
|
@ -7,26 +7,36 @@ def add_all_parsers(parser):
|
|||
_add_model_parser(parser)
|
||||
_add_hardware_parser(parser)
|
||||
_add_misc_parser(parser)
|
||||
_add_dataset_parser(parser)
|
||||
|
||||
|
||||
def _add_loss_parser(parser):
|
||||
group_loss = parser.add_argument_group('Loss parameters')
|
||||
group_loss.add_argument('--mu', type=float, default=0., help='weight decay parameter')
|
||||
group_loss.add_argument('--mu', type=float, default=0.0001, help='weight decay parameter')
|
||||
|
||||
|
||||
def _add_training_parser(parser):
|
||||
group_training = parser.add_argument_group('Training parameters')
|
||||
group_training.add_argument('--lr', type=float, help='learning rate to use')
|
||||
group_training.add_argument('--batch_size', type=int, default=256, help='default is 256')
|
||||
group_training.add_argument('--batch_size', type=int, default=32, help='default is 32')
|
||||
group_training.add_argument('--n_epochs', type=int)
|
||||
group_training.add_argument('--pretrained', action='store_true')
|
||||
group_training.add_argument('--image_size', type=int, default=256)
|
||||
group_training.add_argument('--crop_size', type=int, default=224)
|
||||
group_training.add_argument('--epoch_decay', nargs='+', type=int, default=[])
|
||||
group_training.add_argument('--k', nargs='+', help='value of k for computing the topk loss and computing topk accuracy',
|
||||
required=True, type=int)
|
||||
|
||||
|
||||
def _add_model_parser(parser):
|
||||
group_model = parser.add_argument_group('Model parameters')
|
||||
group_model.add_argument('--model', choices=['resnet50', 'densenet121', 'densenet169', 'mobilenet_v2', 'inception_resnetv2'],
|
||||
group_model.add_argument('--model', choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
|
||||
'densenet121', 'densenet161', 'densenet169', 'densenet201',
|
||||
'mobilenet_v2', 'inception_v3', 'alexnet', 'squeezenet',
|
||||
'shufflenet', 'wide_resnet50_2', 'wide_resnet101_2',
|
||||
'vgg11', 'mobilenet_v3_large', 'mobilenet_v3_small',
|
||||
'inception_resnet_v2', 'inception_v4', 'efficientnet_b0',
|
||||
'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3',
|
||||
'efficientnet_b4', 'vit_base_patch16_224'],
|
||||
default='resnet50', help='choose the model you want to train on')
|
||||
|
||||
|
||||
|
@ -35,12 +45,6 @@ def _add_hardware_parser(parser):
|
|||
group_hardware.add_argument('--use_gpu', type=int, choices=[0, 1], default=torch.cuda.is_available())
|
||||
|
||||
|
||||
def _add_dataset_parser(parser):
|
||||
group_dataset = parser.add_argument_group('Dataset parameters')
|
||||
group_dataset.add_argument('--size_image', type=int, default=256,
|
||||
help='size you want to resize the images to')
|
||||
|
||||
|
||||
def _add_misc_parser(parser):
|
||||
group_misc = parser.add_argument_group('Miscellaneous parameters')
|
||||
group_misc.add_argument('--seed', type=int, help='set the seed for reproductible experiments')
|
||||
|
|
6
main.py
6
main.py
|
@ -14,12 +14,14 @@ from cli import add_all_parsers
|
|||
|
||||
def train(args):
|
||||
set_seed(args, use_gpu=torch.cuda.is_available())
|
||||
train_loader, val_loader, test_loader, dataset_attributes = get_data(args)
|
||||
train_loader, val_loader, test_loader, dataset_attributes = get_data(args.root, args.image_size, args.crop_size,
|
||||
args.batch_size, args.num_workers, args.pretrained)
|
||||
|
||||
model = get_model(args, n_classes=dataset_attributes['n_classes'])
|
||||
criteria = CrossEntropyLoss()
|
||||
|
||||
if args.use_gpu:
|
||||
print('USING GPU')
|
||||
torch.cuda.set_device(0)
|
||||
model.cuda()
|
||||
criteria.cuda()
|
||||
|
@ -42,7 +44,7 @@ def train(args):
|
|||
|
||||
for epoch in tqdm(range(args.n_epochs), desc='epoch', position=0):
|
||||
t = time.time()
|
||||
optimizer = update_optimizer(optimizer, lr_schedule=dataset_attributes['lr_schedule'], epoch=epoch)
|
||||
optimizer = update_optimizer(optimizer, lr_schedule=args.epoch_decay, epoch=epoch)
|
||||
|
||||
loss_epoch_train, acc_epoch_train, topk_acc_epoch_train = train_epoch(model, optimizer, train_loader,
|
||||
criteria, loss_train, acc_train,
|
||||
|
|
|
@ -4,8 +4,8 @@ channels:
|
|||
|
||||
dependencies:
|
||||
- tqdm
|
||||
- pytorch=1.4.0
|
||||
- torchvision=0.5.0
|
||||
- pytorch
|
||||
- torchvision
|
||||
- pip:
|
||||
- timm
|
||||
|
||||
|
|
101
utils.py
101
utils.py
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import random
|
||||
import timm
|
||||
import numpy as np
|
||||
|
@ -6,7 +7,9 @@ import os
|
|||
from collections import Counter
|
||||
|
||||
|
||||
from torchvision.models import resnet50, densenet121, densenet169, mobilenet_v2
|
||||
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152, inception_v3, mobilenet_v2, densenet121, \
|
||||
densenet161, densenet169, densenet201, alexnet, squeezenet1_0, shufflenet_v2_x1_0, wide_resnet50_2, wide_resnet101_2,\
|
||||
vgg11, mobilenet_v3_large, mobilenet_v3_small
|
||||
from torchvision.datasets import ImageFolder
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
|
@ -108,12 +111,57 @@ def update_optimizer(optimizer, lr_schedule, epoch):
|
|||
|
||||
|
||||
def get_model(args, n_classes):
|
||||
model_dict = {'resnet50': resnet50, 'densenet121': densenet121, 'densenet169': densenet169, 'mobilenet_v2': mobilenet_v2}
|
||||
pytorch_models = {'resnet18': resnet18, 'resnet34': resnet34, 'resnet50': resnet50, 'resnet101': resnet101,
|
||||
'resnet152': resnet152, 'densenet121': densenet121, 'densenet161': densenet161,
|
||||
'densenet169': densenet169, 'densenet201': densenet201, 'mobilenet_v2': mobilenet_v2,
|
||||
'inception_v3': inception_v3, 'alexnet': alexnet, 'squeezenet': squeezenet1_0,
|
||||
'shufflenet': shufflenet_v2_x1_0, 'wide_resnet50_2': wide_resnet50_2,
|
||||
'wide_resnet101_2': wide_resnet101_2, 'vgg11': vgg11, 'mobilenet_v3_large': mobilenet_v3_large,
|
||||
'mobilenet_v3_small': mobilenet_v3_small
|
||||
}
|
||||
timm_models = {'inception_resnet_v2', 'inception_v4', 'efficientnet_b0', 'efficientnet_b1',
|
||||
'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'vit_base_patch16_224'}
|
||||
|
||||
if args.model == 'inception_resnetv2':
|
||||
model = timm.create_model('inception_resnet_v2', pretrained=False, num_classes=n_classes)
|
||||
if args.model in pytorch_models.keys() and not args.pretrained:
|
||||
if args.model == 'inception_v3':
|
||||
model = pytorch_models[args.model](pretrained=False, num_classes=n_classes, aux_logits=False)
|
||||
else:
|
||||
model = pytorch_models[args.model](pretrained=False, num_classes=n_classes)
|
||||
elif args.model in pytorch_models.keys() and args.pretrained:
|
||||
if args.model in {'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'wide_resnet50_2',
|
||||
'wide_resnet101_2', 'shufflenet'}:
|
||||
model = pytorch_models[args.model](pretrained=True)
|
||||
num_ftrs = model.fc.in_features
|
||||
model.fc = nn.Linear(num_ftrs, n_classes)
|
||||
elif args.model in {'alexnet', 'vgg11'}:
|
||||
model = pytorch_models[args.model](pretrained=True)
|
||||
num_ftrs = model.classifier[6].in_features
|
||||
model.classifier[6] = nn.Linear(num_ftrs, n_classes)
|
||||
elif args.model in {'densenet121', 'densenet161', 'densenet169', 'densenet201'}:
|
||||
model = pytorch_models[args.model](pretrained=True)
|
||||
num_ftrs = model.classifier.in_features
|
||||
model.classifier = nn.Linear(num_ftrs, n_classes)
|
||||
elif args.model == 'mobilenet_v2':
|
||||
model = pytorch_models[args.model](pretrained=True)
|
||||
num_ftrs = model.classifier[1].in_features
|
||||
model.classifier[1] = nn.Linear(num_ftrs, n_classes)
|
||||
elif args.model == 'inception_v3':
|
||||
model = inception_v3(pretrained=True, aux_logits=False)
|
||||
num_ftrs = model.fc.in_features
|
||||
model.fc = nn.Linear(num_ftrs, n_classes)
|
||||
elif args.model == 'squeezenet':
|
||||
model = pytorch_models[args.model](pretrained=True)
|
||||
model.classifier[1] = nn.Conv2d(512, n_classes, kernel_size=(1, 1), stride=(1, 1))
|
||||
model.num_classes = n_classes
|
||||
elif args.model == 'mobilenet_v3_large' or args.model == 'mobilenet_v3_small':
|
||||
model = pytorch_models[args.model](pretrained=True)
|
||||
num_ftrs = model.classifier[-1].in_features
|
||||
model.classifier[-1] = nn.Linear(num_ftrs, n_classes)
|
||||
|
||||
elif args.model in timm_models:
|
||||
model = timm.create_model(args.model, pretrained=args.pretrained, num_classes=n_classes)
|
||||
else:
|
||||
model = model_dict[args.model](pretrained=False, num_classes=n_classes)
|
||||
raise NotImplementedError
|
||||
|
||||
return model
|
||||
|
||||
|
@ -129,37 +177,42 @@ class Plantnet(ImageFolder):
|
|||
return os.path.join(self.root, self.split)
|
||||
|
||||
|
||||
class MaxCenterCrop:
|
||||
def __call__(self, sample):
|
||||
min_size = min(sample.size[0], sample.size[1])
|
||||
return CenterCrop(min_size)(sample)
|
||||
def get_data(root, image_size, crop_size, batch_size, num_workers, pretrained):
|
||||
|
||||
if pretrained:
|
||||
transform_train = transforms.Compose([transforms.Resize(size=image_size), transforms.RandomCrop(size=crop_size),
|
||||
transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])])
|
||||
transform_test = transforms.Compose([transforms.Resize(size=image_size), transforms.CenterCrop(size=crop_size),
|
||||
transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])])
|
||||
else:
|
||||
transform_train = transforms.Compose([transforms.Resize(size=image_size), transforms.RandomCrop(size=crop_size),
|
||||
transforms.ToTensor(), transforms.Normalize(mean=[0.4425, 0.4695, 0.3266],
|
||||
std=[0.2353, 0.2219, 0.2325])])
|
||||
transform_test = transforms.Compose([transforms.Resize(size=image_size), transforms.CenterCrop(size=crop_size),
|
||||
transforms.ToTensor(), transforms.Normalize(mean=[0.4425, 0.4695, 0.3266],
|
||||
std=[0.2353, 0.2219, 0.2325])])
|
||||
|
||||
def get_data(args):
|
||||
|
||||
transform = transforms.Compose(
|
||||
[MaxCenterCrop(), transforms.Resize(args.size_image), transforms.ToTensor()])
|
||||
|
||||
trainset = Plantnet(args.root, 'images_train', transform=transform)
|
||||
trainset = Plantnet(root, 'train', transform=transform_train)
|
||||
train_class_to_num_instances = Counter(trainset.targets)
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
|
||||
shuffle=True, num_workers=args.num_workers)
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
|
||||
shuffle=True, num_workers=num_workers)
|
||||
|
||||
valset = Plantnet(args.root, 'images_val', transform=transform)
|
||||
valset = Plantnet(root, 'val', transform=transform_test)
|
||||
|
||||
valloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size,
|
||||
shuffle=True, num_workers=args.num_workers)
|
||||
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
|
||||
shuffle=True, num_workers=num_workers)
|
||||
|
||||
testset = Plantnet(args.root, 'images_test', transform=transform)
|
||||
testset = Plantnet(root, 'test', transform=transform_test)
|
||||
test_class_to_num_instances = Counter(testset.targets)
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
|
||||
shuffle=False, num_workers=args.num_workers)
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
|
||||
shuffle=False, num_workers=num_workers)
|
||||
|
||||
val_class_to_num_instances = Counter(valset.targets)
|
||||
n_classes = len(trainset.classes)
|
||||
|
||||
dataset_attributes = {'n_train': len(trainset), 'n_val': len(valset), 'n_test': len(testset), 'n_classes': n_classes,
|
||||
'lr_schedule': [40, 50, 60],
|
||||
'class2num_instances': {'train': train_class_to_num_instances,
|
||||
'val': val_class_to_num_instances,
|
||||
'test': test_class_to_num_instances},
|
||||
|
|
Loading…
Reference in New Issue