# PlantNetLibre-300K/utils.py
# SPDX-License-Identifier: BSD-2-Clause
#
# Copyright (c) 2021, Pl@ntNet
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import torch
import torch.nn as nn
import random
import timm
import numpy as np
import os
from collections import Counter
from torchvision.models import (
    resnet18,
    resnet34,
    resnet50,
    resnet101,
    resnet152,
    inception_v3,
    mobilenet_v2,
    densenet121,
    densenet161,
    densenet169,
    densenet201,
    alexnet,
    squeezenet1_0,
    shufflenet_v2_x1_0,
    wide_resnet50_2,
    wide_resnet101_2,
    vgg11,
    mobilenet_v3_large,
    mobilenet_v3_small,
)
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.transforms import CenterCrop


def set_seed(args, use_gpu, print_out=True):
    if print_out:
        print("Seed:\t {}".format(args.seed))
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if use_gpu:
        torch.cuda.manual_seed(args.seed)


def update_correct_per_class(batch_output, batch_y, d):
    predicted_class = torch.argmax(batch_output, dim=-1)
    for true_label, predicted_label in zip(batch_y, predicted_class):
        if true_label == predicted_label:
            d[true_label.item()] += 1
        else:
            d[true_label.item()] += 0


def update_correct_per_class_topk(batch_output, batch_y, d, k):
    topk_labels_pred = torch.argsort(batch_output, dim=-1, descending=True)[:, :k]
    for true_label, predicted_labels in zip(batch_y, topk_labels_pred):
        d[true_label.item()] += torch.sum(true_label == predicted_labels).item()


def update_correct_per_class_avgk(val_probas, val_labels, d, lmbda):
    ground_truth_probas = torch.gather(
        val_probas, dim=1, index=val_labels.unsqueeze(-1)
    )
    for true_label, predicted_label in zip(val_labels, ground_truth_probas):
        d[true_label.item()] += (predicted_label >= lmbda).item()


def count_correct_topk(scores, labels, k):
    """Given a tensor of scores of size (n_batch, n_classes) and a tensor of
    labels of size n_batch, computes the number of correctly predicted examples
    in the batch (in the top-k accuracy sense).
    """
    top_k_scores = torch.argsort(scores, dim=-1, descending=True)[:, :k]
    labels = labels.view(len(labels), 1)
    return torch.eq(labels, top_k_scores).sum()
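
# Illustration for count_correct_topk (a sketch with made-up numbers, not part of the
# original file):
#
#     scores = torch.tensor([[0.1, 0.5, 0.4]])
#     labels = torch.tensor([2])
#     count_correct_topk(scores, labels, k=2)  # -> tensor(1), class 2 is in the top 2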


def count_correct_avgk(probas, labels, lmbda):
    """Given a tensor of probabilities of size (n_batch, n_classes) and a tensor of
    labels of size n_batch, computes the number of correctly predicted examples
    in the batch (in the average-k accuracy sense, i.e. the ground-truth class
    probability reaches the threshold lmbda).
    """
    gt_probas = torch.gather(probas, dim=1, index=labels.unsqueeze(-1))
    res = torch.sum(gt_probas >= lmbda)
    return res
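
# Note on average-k accuracy (a sketch of the intuition, not part of the original file):
# a prediction counts as correct when the probability assigned to the ground-truth class
# is at least the threshold `lmbda`; in the Pl@ntNet-300K setup this threshold is
# typically calibrated on the validation set so that, on average, k classes per image
# exceed it. A hypothetical illustration with made-up numbers:
#
#     probas = torch.tensor([[0.7, 0.2, 0.1], [0.3, 0.4, 0.3]])
#     labels = torch.tensor([0, 2])
#     count_correct_avgk(probas, labels, lmbda=0.35)  # -> tensor(1)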


def load_model(model, filename, use_gpu):
    if not os.path.exists(filename):
        raise FileNotFoundError
    device = "cuda:0" if use_gpu else "cpu"
    d = torch.load(filename, map_location=device)
    model.load_state_dict(d["model"])
    return d["epoch"]


def load_optimizer(optimizer, filename, use_gpu):
    if not os.path.exists(filename):
        raise FileNotFoundError
    device = "cuda:0" if use_gpu else "cpu"
    d = torch.load(filename, map_location=device)
    optimizer.load_state_dict(d["optimizer"])


def save(model, optimizer, epoch, location):
    save_dir = os.path.dirname(location)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    d = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(d, location)


def decay_lr(optimizer):
    for param_group in optimizer.param_groups:
        param_group["lr"] *= 0.1
    print("Switching lr to {}".format(optimizer.param_groups[0]["lr"]))
    return optimizer


def update_optimizer(optimizer, lr_schedule, epoch):
    if epoch in lr_schedule:
        optimizer = decay_lr(optimizer)
    return optimizer
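
# Example (hypothetical schedule, not part of the original file): with
# lr_schedule=[40, 50, 60], the learning rate is divided by 10 at epochs 40, 50 and 60:
#
#     optimizer = update_optimizer(optimizer, lr_schedule=[40, 50, 60], epoch=epoch)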


def get_model(args, n_classes):
    pytorch_models = {
        "resnet18": resnet18,
        "resnet34": resnet34,
        "resnet50": resnet50,
        "resnet101": resnet101,
        "resnet152": resnet152,
        "densenet121": densenet121,
        "densenet161": densenet161,
        "densenet169": densenet169,
        "densenet201": densenet201,
        "mobilenet_v2": mobilenet_v2,
        "inception_v3": inception_v3,
        "alexnet": alexnet,
        "squeezenet": squeezenet1_0,
        "shufflenet": shufflenet_v2_x1_0,
        "wide_resnet50_2": wide_resnet50_2,
        "wide_resnet101_2": wide_resnet101_2,
        "vgg11": vgg11,
        "mobilenet_v3_large": mobilenet_v3_large,
        "mobilenet_v3_small": mobilenet_v3_small,
    }
    timm_models = {
        "inception_resnet_v2",
        "inception_v4",
        "efficientnet_b0",
        "efficientnet_b1",
        "efficientnet_b2",
        "efficientnet_b3",
        "efficientnet_b4",
        "vit_base_patch16_224",
    }

    if args.model in pytorch_models.keys() and not args.pretrained:
        if args.model == "inception_v3":
            model = pytorch_models[args.model](
                pretrained=False, num_classes=n_classes, aux_logits=False
            )
        else:
            model = pytorch_models[args.model](pretrained=False, num_classes=n_classes)
    elif args.model in pytorch_models.keys() and args.pretrained:
        if args.model in {
            "resnet18",
            "resnet34",
            "resnet50",
            "resnet101",
            "resnet152",
            "wide_resnet50_2",
            "wide_resnet101_2",
            "shufflenet",
        }:
            model = pytorch_models[args.model](pretrained=True)
            num_ftrs = model.fc.in_features
            model.fc = nn.Linear(num_ftrs, n_classes)
        elif args.model in {"alexnet", "vgg11"}:
            model = pytorch_models[args.model](pretrained=True)
            num_ftrs = model.classifier[6].in_features
            model.classifier[6] = nn.Linear(num_ftrs, n_classes)
        elif args.model in {"densenet121", "densenet161", "densenet169", "densenet201"}:
            model = pytorch_models[args.model](pretrained=True)
            num_ftrs = model.classifier.in_features
            model.classifier = nn.Linear(num_ftrs, n_classes)
        elif args.model == "mobilenet_v2":
            model = pytorch_models[args.model](pretrained=True)
            num_ftrs = model.classifier[1].in_features
            model.classifier[1] = nn.Linear(num_ftrs, n_classes)
        elif args.model == "inception_v3":
            model = inception_v3(pretrained=True, aux_logits=False)
            num_ftrs = model.fc.in_features
            model.fc = nn.Linear(num_ftrs, n_classes)
        elif args.model == "squeezenet":
            model = pytorch_models[args.model](pretrained=True)
            model.classifier[1] = nn.Conv2d(
                512, n_classes, kernel_size=(1, 1), stride=(1, 1)
            )
            model.num_classes = n_classes
        elif args.model == "mobilenet_v3_large" or args.model == "mobilenet_v3_small":
            model = pytorch_models[args.model](pretrained=True)
            num_ftrs = model.classifier[-1].in_features
            model.classifier[-1] = nn.Linear(num_ftrs, n_classes)
    elif args.model in timm_models:
        model = timm.create_model(
            args.model, pretrained=args.pretrained, num_classes=n_classes
        )
    else:
        raise NotImplementedError
    return model
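
# Example (a minimal sketch, not part of the original file): get_model expects an
# argparse-style namespace with `model` and `pretrained` attributes. The namespace
# below is hypothetical and only illustrates the expected interface.
#
#     from argparse import Namespace
#     args = Namespace(model="resnet18", pretrained=True)
#     model = get_model(args, n_classes=1081)  # 1081 species in Pl@ntNet-300K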


class Plantnet(ImageFolder):
    def __init__(self, root, split, **kwargs):
        self.root = root
        self.split = split
        super().__init__(self.split_folder, **kwargs)

    @property
    def split_folder(self):
        return os.path.join(self.root, self.split)


def get_data(root, image_size, crop_size, batch_size, num_workers, pretrained):
    if pretrained:
        transform_train = transforms.Compose(
            [
                transforms.Resize(size=image_size),
                transforms.RandomCrop(size=crop_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        transform_test = transforms.Compose(
            [
                transforms.Resize(size=image_size),
                transforms.CenterCrop(size=crop_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
    else:
        transform_train = transforms.Compose(
            [
                transforms.Resize(size=image_size),
                transforms.RandomCrop(size=crop_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.4425, 0.4695, 0.3266], std=[0.2353, 0.2219, 0.2325]
                ),
            ]
        )
        transform_test = transforms.Compose(
            [
                transforms.Resize(size=image_size),
                transforms.CenterCrop(size=crop_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.4425, 0.4695, 0.3266], std=[0.2353, 0.2219, 0.2325]
                ),
            ]
        )

    trainset = Plantnet(root, "train", transform=transform_train)
    train_class_to_num_instances = Counter(trainset.targets)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    valset = Plantnet(root, "val", transform=transform_test)
    val_class_to_num_instances = Counter(valset.targets)
    valloader = torch.utils.data.DataLoader(
        valset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    testset = Plantnet(root, "test", transform=transform_test)
    test_class_to_num_instances = Counter(testset.targets)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    n_classes = len(trainset.classes)

    dataset_attributes = {
        "n_train": len(trainset),
        "n_val": len(valset),
        "n_test": len(testset),
        "n_classes": n_classes,
        "class2num_instances": {
            "train": train_class_to_num_instances,
            "val": val_class_to_num_instances,
            "test": test_class_to_num_instances,
        },
        "class_to_idx": trainset.class_to_idx,
    }

    return trainloader, valloader, testloader, dataset_attributes
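
# Example (a minimal sketch, not part of the original file): the values below are
# hypothetical and only illustrate the expected call signature. `root` must point to a
# directory with `train/`, `val/` and `test/` subfolders in torchvision ImageFolder layout.
#
#     trainloader, valloader, testloader, dataset_attributes = get_data(
#         root="plantnet_300K/images",
#         image_size=256,
#         crop_size=224,
#         batch_size=32,
#         num_workers=4,
#         pretrained=True,
#     )
#     print(dataset_attributes["n_classes"])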