diff --git a/cli.py b/cli.py
index f1b95ea..c41bb86 100644
--- a/cli.py
+++ b/cli.py
@@ -7,26 +7,36 @@ def add_all_parsers(parser):
     _add_model_parser(parser)
     _add_hardware_parser(parser)
     _add_misc_parser(parser)
-    _add_dataset_parser(parser)


 def _add_loss_parser(parser):
     group_loss = parser.add_argument_group('Loss parameters')
-    group_loss.add_argument('--mu', type=float, default=0., help='weight decay parameter')
+    group_loss.add_argument('--mu', type=float, default=0.0001, help='weight decay parameter')


 def _add_training_parser(parser):
     group_training = parser.add_argument_group('Training parameters')
     group_training.add_argument('--lr', type=float, help='learning rate to use')
-    group_training.add_argument('--batch_size', type=int, default=256, help='default is 256')
+    group_training.add_argument('--batch_size', type=int, default=32, help='default is 32')
     group_training.add_argument('--n_epochs', type=int)
+    group_training.add_argument('--pretrained', action='store_true')
+    group_training.add_argument('--image_size', type=int, default=256)
+    group_training.add_argument('--crop_size', type=int, default=224)
+    group_training.add_argument('--epoch_decay', nargs='+', type=int, default=[])
     group_training.add_argument('--k', nargs='+', help='value of k for computing the topk loss and computing topk accuracy', required=True, type=int)


 def _add_model_parser(parser):
     group_model = parser.add_argument_group('Model parameters')
-    group_model.add_argument('--model', choices=['resnet50', 'densenet121', 'densenet169', 'mobilenet_v2', 'inception_resnetv2'],
+    group_model.add_argument('--model', choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
+                                                 'densenet121', 'densenet161', 'densenet169', 'densenet201',
+                                                 'mobilenet_v2', 'inception_v3', 'alexnet', 'squeezenet',
+                                                 'shufflenet', 'wide_resnet50_2', 'wide_resnet101_2',
+                                                 'vgg11', 'mobilenet_v3_large', 'mobilenet_v3_small',
+                                                 'inception_resnet_v2', 'inception_v4', 'efficientnet_b0',
+                                                 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3',
+                                                 'efficientnet_b4', 'vit_base_patch16_224'],
                             default='resnet50', help='choose the model you want to train on')


@@ -35,12 +45,6 @@ def _add_hardware_parser(parser):
     group_hardware.add_argument('--use_gpu', type=int, choices=[0, 1], default=torch.cuda.is_available())


-def _add_dataset_parser(parser):
-    group_dataset = parser.add_argument_group('Dataset parameters')
-    group_dataset.add_argument('--size_image', type=int, default=256,
-                               help='size you want to resize the images to')
-
-
 def _add_misc_parser(parser):
     group_misc = parser.add_argument_group('Miscellaneous parameters')
     group_misc.add_argument('--seed', type=int, help='set the seed for reproductible experiments')
diff --git a/main.py b/main.py
index c5e75b4..fac959b 100644
--- a/main.py
+++ b/main.py
@@ -14,12 +14,14 @@ from cli import add_all_parsers

 def train(args):
     set_seed(args, use_gpu=torch.cuda.is_available())
-    train_loader, val_loader, test_loader, dataset_attributes = get_data(args)
+    train_loader, val_loader, test_loader, dataset_attributes = get_data(args.root, args.image_size, args.crop_size,
+                                                                         args.batch_size, args.num_workers, args.pretrained)

     model = get_model(args, n_classes=dataset_attributes['n_classes'])
     criteria = CrossEntropyLoss()

     if args.use_gpu:
+        print('USING GPU')
         torch.cuda.set_device(0)
         model.cuda()
         criteria.cuda()
@@ -42,7 +44,7 @@ def train(args):
     for epoch in tqdm(range(args.n_epochs), desc='epoch', position=0):
         t = time.time()

-        optimizer = update_optimizer(optimizer, lr_schedule=dataset_attributes['lr_schedule'], epoch=epoch)
+        optimizer = update_optimizer(optimizer, lr_schedule=args.epoch_decay, epoch=epoch)

         loss_epoch_train, acc_epoch_train, topk_acc_epoch_train = train_epoch(model, optimizer, train_loader,
                                                                               criteria, loss_train, acc_train,
diff --git a/plantnet_300k_env.yml b/plantnet_300k_env.yml
index c3439c4..e152849 100644
--- a/plantnet_300k_env.yml
+++ b/plantnet_300k_env.yml
@@ -4,8 +4,8 @@ channels:
 dependencies:
   - tqdm
-  - pytorch=1.4.0
-  - torchvision=0.5.0
+  - pytorch
+  - torchvision
   - pip:
     - timm
diff --git a/utils.py b/utils.py
index 09baef6..f29222c 100644
--- a/utils.py
+++ b/utils.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 import random
 import timm
 import numpy as np
@@ -6,7 +7,9 @@ import os
 from collections import Counter

-from torchvision.models import resnet50, densenet121, densenet169, mobilenet_v2
+from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152, inception_v3, mobilenet_v2, densenet121, \
+    densenet161, densenet169, densenet201, alexnet, squeezenet1_0, shufflenet_v2_x1_0, wide_resnet50_2, wide_resnet101_2,\
+    vgg11, mobilenet_v3_large, mobilenet_v3_small
 from torchvision.datasets import ImageFolder
 import torchvision.transforms as transforms
@@ -108,12 +111,57 @@ def update_optimizer(optimizer, lr_schedule, epoch):


 def get_model(args, n_classes):
-    model_dict = {'resnet50': resnet50, 'densenet121': densenet121, 'densenet169': densenet169, 'mobilenet_v2': mobilenet_v2}
+    pytorch_models = {'resnet18': resnet18, 'resnet34': resnet34, 'resnet50': resnet50, 'resnet101': resnet101,
+                      'resnet152': resnet152, 'densenet121': densenet121, 'densenet161': densenet161,
+                      'densenet169': densenet169, 'densenet201': densenet201, 'mobilenet_v2': mobilenet_v2,
+                      'inception_v3': inception_v3, 'alexnet': alexnet, 'squeezenet': squeezenet1_0,
+                      'shufflenet': shufflenet_v2_x1_0, 'wide_resnet50_2': wide_resnet50_2,
+                      'wide_resnet101_2': wide_resnet101_2, 'vgg11': vgg11, 'mobilenet_v3_large': mobilenet_v3_large,
+                      'mobilenet_v3_small': mobilenet_v3_small
+                      }
+    timm_models = {'inception_resnet_v2', 'inception_v4', 'efficientnet_b0', 'efficientnet_b1',
+                   'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'vit_base_patch16_224'}

-    if args.model == 'inception_resnetv2':
-        model = timm.create_model('inception_resnet_v2', pretrained=False, num_classes=n_classes)
+    if args.model in pytorch_models.keys() and not args.pretrained:
+        if args.model == 'inception_v3':
+            model = pytorch_models[args.model](pretrained=False, num_classes=n_classes, aux_logits=False)
+        else:
+            model = pytorch_models[args.model](pretrained=False, num_classes=n_classes)
+    elif args.model in pytorch_models.keys() and args.pretrained:
+        if args.model in {'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'wide_resnet50_2',
+                          'wide_resnet101_2', 'shufflenet'}:
+            model = pytorch_models[args.model](pretrained=True)
+            num_ftrs = model.fc.in_features
+            model.fc = nn.Linear(num_ftrs, n_classes)
+        elif args.model in {'alexnet', 'vgg11'}:
+            model = pytorch_models[args.model](pretrained=True)
+            num_ftrs = model.classifier[6].in_features
+            model.classifier[6] = nn.Linear(num_ftrs, n_classes)
+        elif args.model in {'densenet121', 'densenet161', 'densenet169', 'densenet201'}:
+            model = pytorch_models[args.model](pretrained=True)
+            num_ftrs = model.classifier.in_features
+            model.classifier = nn.Linear(num_ftrs, n_classes)
+        elif args.model == 'mobilenet_v2':
+            model = pytorch_models[args.model](pretrained=True)
+            num_ftrs = model.classifier[1].in_features
+            model.classifier[1] = nn.Linear(num_ftrs, n_classes)
+        elif args.model == 'inception_v3':
+            model = inception_v3(pretrained=True, aux_logits=False)
+            num_ftrs = model.fc.in_features
+            model.fc = nn.Linear(num_ftrs, n_classes)
+        elif args.model == 'squeezenet':
+            model = pytorch_models[args.model](pretrained=True)
+            model.classifier[1] = nn.Conv2d(512, n_classes, kernel_size=(1, 1), stride=(1, 1))
+            model.num_classes = n_classes
+        elif args.model == 'mobilenet_v3_large' or args.model == 'mobilenet_v3_small':
+            model = pytorch_models[args.model](pretrained=True)
+            num_ftrs = model.classifier[-1].in_features
+            model.classifier[-1] = nn.Linear(num_ftrs, n_classes)
+
+    elif args.model in timm_models:
+        model = timm.create_model(args.model, pretrained=args.pretrained, num_classes=n_classes)
     else:
-        model = model_dict[args.model](pretrained=False, num_classes=n_classes)
+        raise NotImplementedError

     return model
@@ -129,37 +177,42 @@ class Plantnet(ImageFolder):
         return os.path.join(self.root, self.split)


-class MaxCenterCrop:
-    def __call__(self, sample):
-        min_size = min(sample.size[0], sample.size[1])
-        return CenterCrop(min_size)(sample)
+def get_data(root, image_size, crop_size, batch_size, num_workers, pretrained):
+    if pretrained:
+        transform_train = transforms.Compose([transforms.Resize(size=image_size), transforms.RandomCrop(size=crop_size),
+                                              transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                                                                          std=[0.229, 0.224, 0.225])])
+        transform_test = transforms.Compose([transforms.Resize(size=image_size), transforms.CenterCrop(size=crop_size),
+                                             transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                                                                         std=[0.229, 0.224, 0.225])])
+    else:
+        transform_train = transforms.Compose([transforms.Resize(size=image_size), transforms.RandomCrop(size=crop_size),
+                                              transforms.ToTensor(), transforms.Normalize(mean=[0.4425, 0.4695, 0.3266],
+                                                                                          std=[0.2353, 0.2219, 0.2325])])
+        transform_test = transforms.Compose([transforms.Resize(size=image_size), transforms.CenterCrop(size=crop_size),
+                                             transforms.ToTensor(), transforms.Normalize(mean=[0.4425, 0.4695, 0.3266],
+                                                                                         std=[0.2353, 0.2219, 0.2325])])

-def get_data(args):
-
-    transform = transforms.Compose(
-        [MaxCenterCrop(), transforms.Resize(args.size_image), transforms.ToTensor()])
-
-    trainset = Plantnet(args.root, 'images_train', transform=transform)
+    trainset = Plantnet(root, 'train', transform=transform_train)
     train_class_to_num_instances = Counter(trainset.targets)
-    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
-                                              shuffle=True, num_workers=args.num_workers)
+    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
+                                              shuffle=True, num_workers=num_workers)

-    valset = Plantnet(args.root, 'images_val', transform=transform)
+    valset = Plantnet(root, 'val', transform=transform_test)

-    valloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size,
-                                            shuffle=True, num_workers=args.num_workers)
+    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
+                                            shuffle=True, num_workers=num_workers)

-    testset = Plantnet(args.root, 'images_test', transform=transform)
+    testset = Plantnet(root, 'test', transform=transform_test)
     test_class_to_num_instances = Counter(testset.targets)

-    testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
-                                              shuffle=False, num_workers=args.num_workers)
+    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
+                                             shuffle=False, num_workers=num_workers)

     val_class_to_num_instances = Counter(valset.targets)
     n_classes = len(trainset.classes)

     dataset_attributes = {'n_train': len(trainset), 'n_val': len(valset), 'n_test': len(testset), 'n_classes': n_classes,
-                          'lr_schedule': [40, 50, 60],
                           'class2num_instances': {'train': train_class_to_num_instances,
                                                   'val': val_class_to_num_instances,
                                                   'test': test_class_to_num_instances},