diff --git a/load-model.py b/load-model.py index 75d5605..bf00db5 100755 --- a/load-model.py +++ b/load-model.py @@ -44,9 +44,11 @@ from torchvision.models import efficientnet_b1 from torchvision.models import efficientnet_b2 from torchvision.models import efficientnet_b3 from torchvision.models import efficientnet_b4 -#from torchvision.models import inception_resnet_v2 + +# from torchvision.models import inception_resnet_v2 from torchvision.models import inception_v3 -#from torchvision.models import inception_v4 + +# from torchvision.models import inception_v4 from torchvision.models import mobilenet_v2 from torchvision.models import mobilenet_v3_large from torchvision.models import mobilenet_v3_small @@ -56,7 +58,8 @@ from torchvision.models import resnet50 from torchvision.models import resnet101 from torchvision.models import resnet152 from torchvision.models import shufflenet_v2_x1_0 -#from torchvision.models import squeezenet + +# from torchvision.models import squeezenet from torchvision.models import squeezenet1_0 from torchvision.models import vgg11 from torchvision.models import vit_b_16 @@ -66,144 +69,143 @@ from torchvision.models import wide_resnet101_2 use_gpu = True ### BEGIN upstream OK ### -#filename = '/srv/ml/plantnet/models/resnet18_weights_best_acc.tar' -#model = resnet18(num_classes=1081) # 1081 classes in Pl@ntNet-300K +# filename = '/srv/ml/plantnet/models/resnet18_weights_best_acc.tar' +# model = resnet18(num_classes=1081) # 1081 classes in Pl@ntNet-300K ### END upstream ### ### BEGIN alexnet OK ### -filename = '/srv/ml/deepcrayon/plantnet/models/alexnet_weights_best_acc.tar' +filename = "/srv/ml/deepcrayon/plantnet/models/alexnet_weights_best_acc.tar" model = alexnet(num_classes=1081) ### END alexnet ### ### BEGIN densenet121 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/densenet121_weights_best_acc.tar' -#model = densenet121(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/densenet121_weights_best_acc.tar' +# model = densenet121(num_classes=1081) ### END densenet121 ### ### BEGIN densenet161 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/densenet161_weights_best_acc.tar' -#model = densenet161(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/densenet161_weights_best_acc.tar' +# model = densenet161(num_classes=1081) ### END densenet161 ### ### BEGIN densenet169 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/densenet169_weights_best_acc.tar' -#model = densenet169(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/densenet169_weights_best_acc.tar' +# model = densenet169(num_classes=1081) ### END densenet169 ### ### BEGIN densenet201 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/densenet201_weights_best_acc.tar' -#model = densenet201(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/densenet201_weights_best_acc.tar' +# model = densenet201(num_classes=1081) ### END densenet201 ### ### BEGIN efficientnet_b0 FAIL ### -#filename = '/srv/ml/deepcrayon/plantnet/models/efficientnet_b0_weights_best_acc.tar' -#model = efficientnet_b0(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/efficientnet_b0_weights_best_acc.tar' +# model = efficientnet_b0(num_classes=1081) ### END efficientnet_b0 ### ### BEGIN efficientnet_b1 FAIL ### -#filename = '/srv/ml/deepcrayon/plantnet/models/efficientnet_b1_weights_best_acc.tar' -#model = efficientnet_b1(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/efficientnet_b1_weights_best_acc.tar' +# model = efficientnet_b1(num_classes=1081) ### END efficientnet_b1 ### ### BEGIN efficientnet_b2 FAIL ### -#filename = '/srv/ml/deepcrayon/plantnet/models/efficientnet_b2_weights_best_acc.tar' -#model = efficientnet_b2(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/efficientnet_b2_weights_best_acc.tar' +# model = efficientnet_b2(num_classes=1081) ### END efficientnet_b2 ### ### BEGIN efficientnet_b3 FAIL ### -#filename = '/srv/ml/deepcrayon/plantnet/models/efficientnet_b3_weights_best_acc.tar' -#model = efficientnet_b3(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/efficientnet_b3_weights_best_acc.tar' +# model = efficientnet_b3(num_classes=1081) ### END efficientnet_b3 ### ### BEGIN efficientnet_b4 FAIL ### -#filename = '/srv/ml/deepcrayon/plantnet/models/efficientnet_b4_weights_best_acc.tar' -#model = efficientnet_b4(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/efficientnet_b4_weights_best_acc.tar' +# model = efficientnet_b4(num_classes=1081) ### END efficientnet_b4 ### ### BEGIN inception_resnet_v2 FAIL no module import ### -#filename = '/srv/ml/deepcrayon/plantnet/models/inception_resnet_v2_weights_best_acc.tar' -#model = inception_resnet_v2(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/inception_resnet_v2_weights_best_acc.tar' +# model = inception_resnet_v2(num_classes=1081) ### END inception_resnet_v2 ### ### BEGIN inception_v3 FAIL no train ### -#filename = '/srv/ml/deepcrayon/plantnet/models/inception_v3_weights_best_acc.tar' -#model = inception_v3(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/inception_v3_weights_best_acc.tar' +# model = inception_v3(num_classes=1081) ### END inception_v3 ### ### BEGIN inception_v4 FAIL no module import ### -#filename = '/srv/ml/deepcrayon/plantnet/models/inception_v4_weights_best_acc.tar' -#model = inception_v4(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/inception_v4_weights_best_acc.tar' +# model = inception_v4(num_classes=1081) ### END inception_v4 ### ### BEGIN mobilenet_v2 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/mobilenet_v2_weights_best_acc.tar' -#model = mobilenet_v2(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/mobilenet_v2_weights_best_acc.tar' +# model = mobilenet_v2(num_classes=1081) ### END mobilenet_v2 ### ### BEGIN mobilenet_v3_large OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/mobilenet_v3_large_weights_best_acc.tar' -#model = mobilenet_v3_large(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/mobilenet_v3_large_weights_best_acc.tar' +# model = mobilenet_v3_large(num_classes=1081) ### END mobilenet_v3_large ### ### BEGIN mobilenet_v3_small OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/mobilenet_v3_small_weights_best_acc.tar' -#model = mobilenet_v3_small(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/mobilenet_v3_small_weights_best_acc.tar' +# model = mobilenet_v3_small(num_classes=1081) ### END mobilenet_v3_small ### ### BEGIN resnet18 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/resnet18_weights_best_acc.tar' -#model = resnet18(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/resnet18_weights_best_acc.tar' +# model = resnet18(num_classes=1081) ### END resnet18 ### ### BEGIN resnet34 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/resnet34_weights_best_acc.tar' -#model = resnet34(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/resnet34_weights_best_acc.tar' +# model = resnet34(num_classes=1081) ### END resnet34 ### ### BEGIN resnet50 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/resnet50_weights_best_acc.tar' -#model = resnet50(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/resnet50_weights_best_acc.tar' +# model = resnet50(num_classes=1081) ### END resnet50 ### ### BEGIN resnet101 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/resnet101_weights_best_acc.tar' -#model = resnet101(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/resnet101_weights_best_acc.tar' +# model = resnet101(num_classes=1081) ### END resnet101 ### ### BEGIN resnet152 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/resnet152_weights_best_acc.tar' -#model = resnet152(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/resnet152_weights_best_acc.tar' +# model = resnet152(num_classes=1081) ### END resnet152 ### ### BEGIN shufflenet_v2_x1_0 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/shufflenet_v2_x1_0_weights_best_acc.tar' -#model = shufflenet_v2_x1_0(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/shufflenet_v2_x1_0_weights_best_acc.tar' +# model = shufflenet_v2_x1_0(num_classes=1081) ### END shufflenet_v2_x1_0 ### ### BEGIN squeezenet1_0 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/squeezenet_weights_best_acc.tar' -#model = squeezenet1_0(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/squeezenet_weights_best_acc.tar' +# model = squeezenet1_0(num_classes=1081) ### END squeezenet1_0 ### ### BEGIN vgg11 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/vgg11_weights_best_acc.tar' -#model = vgg11(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/vgg11_weights_best_acc.tar' +# model = vgg11(num_classes=1081) ### END vgg11 ### ### BEGIN vit_b_16 FAIL ### -#filename = '/srv/ml/deepcrayon/plantnet/models/vit_b_16_weights_best_acc.tar' -#model = vit_b_16(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/vit_b_16_weights_best_acc.tar' +# model = vit_b_16(num_classes=1081) ### END vit ### ### BEGIN wide_resnet50_2 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/wide_resnet50_2_weights_best_acc.tar' -#model = wide_resnet50_2(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/wide_resnet50_2_weights_best_acc.tar' +# model = wide_resnet50_2(num_classes=1081) ### END wide_resnet50_2 ### ### BEGIN wide_resnet101_2 OK ### -#filename = '/srv/ml/deepcrayon/plantnet/models/wide_resnet101_2_weights_best_acc.tar' -#model = wide_resnet101_2(num_classes=1081) +# filename = '/srv/ml/deepcrayon/plantnet/models/wide_resnet101_2_weights_best_acc.tar' +# model = wide_resnet101_2(num_classes=1081) ### END wide_resnet101_2 ### load_model(model, filename=filename, use_gpu=use_gpu) - diff --git a/train.py b/train.py index 3adb0d3..edea134 100755 --- a/train.py +++ b/train.py @@ -38,195 +38,226 @@ import argparse # # Use defaults from git repo example. # https://github.com/plantnet/PlantNet-300K -BATCH_SIZE='32' -MU='0.0001' -K='1 3 5 10' -SEED='4' -IMAGE_SIZE='256' -CROP_SIZE='224' +BATCH_SIZE = "32" +MU = "0.0001" +K = "1 3 5 10" +SEED = "4" +IMAGE_SIZE = "256" +CROP_SIZE = "224" # Root path to images test train val -ROOT_DIR='/srv/ml/plantnet/files/plantnet_300K/images' +ROOT_DIR = "/srv/ml/plantnet/files/plantnet_300K/images" # Use GPU -USE_GPU='1' +USE_GPU = "1" # Use all CPUs available on system. XXX get nproc -NUM_WORKERS='4' +NUM_WORKERS = "4" # Parse command line options parser = argparse.ArgumentParser( - prog='train.py', - description='Train PlantNet-300K models using default parameters.', - epilog='Example: ./train.py alexnet', - ) + prog="train.py", + description="Train PlantNet-300K models using default parameters.", + epilog="Example: ./train.py alexnet", +) -parser.add_argument('model', - help='Model name', - type=str, - choices=['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'inception_resnet_v2', 'inception_v3', 'inception_v4', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'shufflenet_v2_x1_0', 'squeezenet1_0', 'vgg11', 'vit_b_16', 'wide_resnet50_2', 'wide_resnet101_2'], +parser.add_argument( + "model", + help="Model name", + type=str, + choices=[ + "alexnet", + "densenet121", + "densenet161", + "densenet169", + "densenet201", + "efficientnet_b0", + "efficientnet_b1", + "efficientnet_b2", + "efficientnet_b3", + "efficientnet_b4", + "inception_resnet_v2", + "inception_v3", + "inception_v4", + "mobilenet_v2", + "mobilenet_v3_large", + "mobilenet_v3_small", + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "shufflenet_v2_x1_0", + "squeezenet1_0", + "vgg11", + "vit_b_16", + "wide_resnet50_2", + "wide_resnet101_2", + ], ) args = parser.parse_args() MODEL_NAME = args.model # Initial Learning Rate -LR='N.NNNN' +LR = "N.NNNN" # Number of Epochs -N_EPOCHS='NN' +N_EPOCHS = "NN" # First Decay -FIRST_DECAY='1N' +FIRST_DECAY = "1N" # Secon Decay -SECOND_DECAY='2N' +SECOND_DECAY = "2N" # Set LR, epochs, decay hyperparameters based on model. match MODEL_NAME: - case "alexnet": - LR='0.001' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'densenet121': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'densenet161': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'densenet169': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'densenet201': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'efficientnet_b0': - LR='0.01' - N_EPOCHS='20' - FIRST_DECAY='10' - SECOND_DECAY='15' - case 'efficientnet_b1': - LR='0.01' - N_EPOCHS='20' - FIRST_DECAY='10' - SECOND_DECAY='15' - case 'efficientnet_b2': - LR='0.01' - N_EPOCHS='20' - FIRST_DECAY='10' - SECOND_DECAY='15' - case 'efficientnet_b3': - LR='0.01' - N_EPOCHS='20' - FIRST_DECAY='10' - SECOND_DECAY='15' - case 'efficientnet_b4': - LR='0.01' - N_EPOCHS='20' - FIRST_DECAY='10' - SECOND_DECAY='15' - case 'inception_resnet_v2': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'inception_v3': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'inception_v4': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'mobilenet_v2': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'mobilenet_v3_large': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'mobilenet_v3_small': - LR='0.001' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'resnet18': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'resnet34': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'resnet50': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'resnet101': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'resnet152': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'shufflenet_v2_x1_0': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'squeezenet1_0': - LR='0.001' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'vgg11': - LR='0.001' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'vit_b_16': - LR='0.0005' - N_EPOCHS='20' - FIRST_DECAY='15' - SECOND_DECAY='' - case 'wide_resnet50_2': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' - case 'wide_resnet101_2': - LR='0.01' - N_EPOCHS='30' - FIRST_DECAY='20' - SECOND_DECAY='25' + case "alexnet": + LR = "0.001" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "densenet121": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "densenet161": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "densenet169": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "densenet201": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "efficientnet_b0": + LR = "0.01" + N_EPOCHS = "20" + FIRST_DECAY = "10" + SECOND_DECAY = "15" + case "efficientnet_b1": + LR = "0.01" + N_EPOCHS = "20" + FIRST_DECAY = "10" + SECOND_DECAY = "15" + case "efficientnet_b2": + LR = "0.01" + N_EPOCHS = "20" + FIRST_DECAY = "10" + SECOND_DECAY = "15" + case "efficientnet_b3": + LR = "0.01" + N_EPOCHS = "20" + FIRST_DECAY = "10" + SECOND_DECAY = "15" + case "efficientnet_b4": + LR = "0.01" + N_EPOCHS = "20" + FIRST_DECAY = "10" + SECOND_DECAY = "15" + case "inception_resnet_v2": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "inception_v3": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "inception_v4": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "mobilenet_v2": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "mobilenet_v3_large": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "mobilenet_v3_small": + LR = "0.001" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "resnet18": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "resnet34": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "resnet50": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "resnet101": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "resnet152": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "shufflenet_v2_x1_0": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "squeezenet1_0": + LR = "0.001" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "vgg11": + LR = "0.001" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "vit_b_16": + LR = "0.0005" + N_EPOCHS = "20" + FIRST_DECAY = "15" + SECOND_DECAY = "" + case "wide_resnet50_2": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" + case "wide_resnet101_2": + LR = "0.01" + N_EPOCHS = "30" + FIRST_DECAY = "20" + SECOND_DECAY = "25" -print('python main.py', - '--model=' + MODEL_NAME, - '--lr=' + LR, - '--n_epochs=' + N_EPOCHS, - '--epoch_decay=' + FIRST_DECAY, SECOND_DECAY, - '--batch_size=' + BATCH_SIZE, - '--mu=' + MU, - '--k=' + K, - '--pretrained', - '--seed=' + SEED, - '--image_size=' + IMAGE_SIZE, - '--crop_size=' + CROP_SIZE, - '--root=' + ROOT_DIR, - '--use_gpu=' + USE_GPU, - '--num_workers=' + NUM_WORKERS, - '--save_name_xp=' + MODEL_NAME, +print( + "python main.py", + "--model=" + MODEL_NAME, + "--lr=" + LR, + "--n_epochs=" + N_EPOCHS, + "--epoch_decay=" + FIRST_DECAY, + SECOND_DECAY, + "--batch_size=" + BATCH_SIZE, + "--mu=" + MU, + "--k=" + K, + "--pretrained", + "--seed=" + SEED, + "--image_size=" + IMAGE_SIZE, + "--crop_size=" + CROP_SIZE, + "--root=" + ROOT_DIR, + "--use_gpu=" + USE_GPU, + "--num_workers=" + NUM_WORKERS, + "--save_name_xp=" + MODEL_NAME, )