1
0
Fork 0
PlantNetLibre-300K/train.py

293 lines
7.7 KiB
Python
Executable File

#!/usr/bin/env python3
#
# train.py
#
# SPDX-License-Identifier: BSD-2-Clause
#
# Copyright (c) 2023, Jeff Moe
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import argparse
import subprocess
# Open questions: which PyTorch version and which Python version were used
# in the paper?
#
# Fixed hyperparameters, set from the values in Table 3 of the
# Pl@ntNet-300K paper, using the defaults of the upstream example:
# https://github.com/plantnet/PlantNet-300K
#
# Every value is a string because it is concatenated straight into the
# main.py command line further down in this script.
SEED = "4"
BATCH_SIZE = "32"
MU = "0.0001"
# The four values handed to the --k option, in order.
KONE, KTWO, KTHREE, KFOUR = "1", "3", "5", "10"
IMAGE_SIZE, CROP_SIZE = "256", "224"
# Root path to the images (test, train and val).
ROOT_DIR = "/srv/ml/plantnet/files/plantnet_300K/images"
# "1" asks for the GPU to be used.
USE_GPU = "1"
# Meant to cover all CPUs available on the system. XXX get nproc
NUM_WORKERS = "12"
# Command line interface: one required positional argument naming the model
# to train. argparse rejects any name that is not listed here.
_MODEL_CHOICES = [
    "alexnet",
    "densenet121", "densenet161", "densenet169", "densenet201",
    "efficientnet_b0", "efficientnet_b1", "efficientnet_b2",
    "efficientnet_b3", "efficientnet_b4",
    "inception_resnet_v2", "inception_v3", "inception_v4",
    "mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small",
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "shufflenet_v2_x1_0",
    "squeezenet1_0",
    "vgg11",
    "vit_b_16",
    "wide_resnet50_2", "wide_resnet101_2",
]

parser = argparse.ArgumentParser(
    prog="train.py",
    description="Train PlantNet-300K models using default parameters.",
    epilog="Example: ./train.py alexnet",
)
parser.add_argument("model", type=str, choices=_MODEL_CHOICES, help="Model name")

# Parsing happens at module level and reads sys.argv.
args = parser.parse_args()
MODEL_NAME = args.model
# Per-model training schedule.
#
# Each schedule is the tuple
#   (initial learning rate, number of epochs, first decay, second decay)
# and every field is a string because it is concatenated into the main.py
# command line later in this script.
_SCHEDULE_LR_0_001 = ("0.001", "30", "20", "25")
_SCHEDULE_LR_0_01 = ("0.01", "30", "20", "25")
_SCHEDULE_EFFICIENTNET = ("0.01", "20", "10", "15")
# vit_b_16 is the only model whose second decay is the empty string.
_SCHEDULE_VIT = ("0.0005", "20", "15", "")
# Placeholder values that stay in effect when MODEL_NAME has no table entry.
_SCHEDULE_PLACEHOLDER = ("N.NNNN", "NN", "1N", "2N")

_MODEL_SCHEDULES = {
    "alexnet": _SCHEDULE_LR_0_001,
    "densenet121": _SCHEDULE_LR_0_01,
    "densenet161": _SCHEDULE_LR_0_01,
    "densenet169": _SCHEDULE_LR_0_01,
    "densenet201": _SCHEDULE_LR_0_01,
    "efficientnet_b0": _SCHEDULE_EFFICIENTNET,
    "efficientnet_b1": _SCHEDULE_EFFICIENTNET,
    "efficientnet_b2": _SCHEDULE_EFFICIENTNET,
    "efficientnet_b3": _SCHEDULE_EFFICIENTNET,
    "efficientnet_b4": _SCHEDULE_EFFICIENTNET,
    "inception_resnet_v2": _SCHEDULE_LR_0_01,
    "inception_v3": _SCHEDULE_LR_0_01,
    "inception_v4": _SCHEDULE_LR_0_01,
    "mobilenet_v2": _SCHEDULE_LR_0_01,
    "mobilenet_v3_large": _SCHEDULE_LR_0_01,
    "mobilenet_v3_small": _SCHEDULE_LR_0_001,
    "resnet18": _SCHEDULE_LR_0_01,
    "resnet34": _SCHEDULE_LR_0_01,
    "resnet50": _SCHEDULE_LR_0_01,
    "resnet101": _SCHEDULE_LR_0_01,
    "resnet152": _SCHEDULE_LR_0_01,
    "shufflenet_v2_x1_0": _SCHEDULE_LR_0_01,
    "squeezenet1_0": _SCHEDULE_LR_0_001,
    "vgg11": _SCHEDULE_LR_0_001,
    "vit_b_16": _SCHEDULE_VIT,
    "wide_resnet50_2": _SCHEDULE_LR_0_01,
    "wide_resnet101_2": _SCHEDULE_LR_0_01,
}

# Set LR, epochs and both decay hyperparameters based on the chosen model.
LR, N_EPOCHS, FIRST_DECAY, SECOND_DECAY = _MODEL_SCHEDULES.get(
    MODEL_NAME, _SCHEDULE_PLACEHOLDER
)
# Build the main.py command line exactly once, so the command that is
# printed and the command that is executed can never drift apart.
#
# --epoch_decay and --k each take several values. The printed command
# separates those values with spaces, so in an argv list every value must be
# its own element; gluing them together would hand main.py the single tokens
# "--epoch_decay=2025" and "--k=13510".
# An empty decay value (vit_b_16 has an empty second decay) is left out
# rather than passed along as an empty argument.
decay_epochs = [epoch for epoch in (FIRST_DECAY, SECOND_DECAY) if epoch]
command = [
    "python",
    "main.py",
    "--model=" + MODEL_NAME,
    "--lr=" + LR,
    "--n_epochs=" + N_EPOCHS,
    "--epoch_decay",
    *decay_epochs,
    "--k",
    KONE,
    KTWO,
    KTHREE,
    KFOUR,
    "--batch_size=" + BATCH_SIZE,
    "--mu=" + MU,
    "--pretrained",
    "--seed=" + SEED,
    "--image_size=" + IMAGE_SIZE,
    "--crop_size=" + CROP_SIZE,
    "--root=" + ROOT_DIR,
    "--use_gpu=" + USE_GPU,
    "--num_workers=" + NUM_WORKERS,
    "--save_name_xp=" + MODEL_NAME,
]

# Print command to run
print(" ".join(command))

# Dry run: stop after printing the command. SystemExit is raised directly
# because quit() is a helper installed by the site module and is not
# available when the interpreter runs without it.
raise SystemExit

# Run model with corresponding hyperparameters.
# NOTE: unreachable until the dry-run exit above is removed.
# check=True makes a failing main.py raise instead of being ignored.
subprocess.run(command, check=True)