Draft, print training parameters
parent
4ec07e418a
commit
680d9a093f
67
train.py
67
train.py
|
@ -26,3 +26,70 @@
|
|||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
|
||||
import argparse
|
||||
|
||||
# Python version used in paper ?
|
||||
# Set hyperparameters based on values in Table 3 of
|
||||
# Pl@ntNet-300K paper.
|
||||
#
|
||||
# Initial Learning Rate
|
||||
LR='N.NNNN'
|
||||
# Number of Epochs
|
||||
N_EPOCHS='NN'
|
||||
# First Decay
|
||||
FIRST_DECAY='1N'
|
||||
# Secon Decay
|
||||
SECOND_DECAY='2N'
|
||||
|
||||
# Use defaults from git repo example.
|
||||
# https://github.com/plantnet/PlantNet-300K
|
||||
BATCH_SIZE='32'
|
||||
MU='0.0001'
|
||||
K='1 3 5 10'
|
||||
SEED='4'
|
||||
IMAGE_SIZE='256'
|
||||
CROP_SIZE='224'
|
||||
|
||||
# Root path to images test train val
|
||||
ROOT_DIR='/srv/ml/plantnet/files/plantnet_300K/images'
|
||||
# Use GPU
|
||||
USE_GPU='1'
|
||||
# Use all CPUs available on system. XXX get nproc
|
||||
NUM_WORKERS='4'
|
||||
|
||||
# Parse command line options
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='train.py',
|
||||
description='Train PlantNet-300K models using default parameters.',
|
||||
epilog='Example: ./train.py alexnet',
|
||||
)
|
||||
|
||||
parser.add_argument('model',
|
||||
help='Model name',
|
||||
type=str,
|
||||
choices=['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'inception_resnet_v2', 'inception_v3', 'inception_v4', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'shufflenet_v2_x1_0', 'squeezenet', 'squeezenet1_0', 'vgg11', 'vit_b_16', 'wide_resnet50_2', 'wide_resnet101_2'],
|
||||
)
|
||||
args = parser.parse_args()
|
||||
MODEL_NAME = args.model
|
||||
|
||||
# XXX
|
||||
# Set LR, epochs, decay hyperparameters based on model.
|
||||
|
||||
print('python main.py',
|
||||
'--model=' + MODEL_NAME,
|
||||
'--lr=' + LR,
|
||||
'--n_epochs=' + N_EPOCHS,
|
||||
'--epoch_decay=' + FIRST_DECAY, SECOND_DECAY,
|
||||
'--batch_size=' + BATCH_SIZE,
|
||||
'--mu=' + MU,
|
||||
'--k=' + K,
|
||||
'--pretrained',
|
||||
'--seed=' + SEED,
|
||||
'--image_size=' + IMAGE_SIZE,
|
||||
'--crop_size=' + CROP_SIZE,
|
||||
'--root=' + ROOT_DIR,
|
||||
'--use_gpu=' + USE_GPU,
|
||||
'--num_workers=' + NUM_WORKERS,
|
||||
'--save_name_xp=' + MODEL_NAME,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue