Save plot of training accuracy

master
Jeff Moe 2022-08-16 22:21:15 -06:00
parent 33adccb2cb
commit 8dd443ab30
1 changed files with 13 additions and 14 deletions

View File

@ -3,7 +3,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# wut-train --- What U Think? SatNOGS Observation AI, training application.\n",
@ -90,11 +92,10 @@
"outputs": [],
"source": [
"ENCODING='FSK9k6'\n",
"batch_size = 64\n",
"epochs = 4\n",
"# Failing with this now:\n",
"#batch_size = 128\n",
"#batch_size = 64\n",
"#epochs = 4\n",
"batch_size = 128\n",
"epochs = 4\n",
"IMG_WIDTH = 416\n",
"IMG_HEIGHT = 803"
]
@ -133,8 +134,8 @@
"print('Validation images: ', total_val)\n",
"print('')\n",
"print('Reduce training and validation set')\n",
"total_train = 1000\n",
"total_val = 1000\n",
"total_train = 5000\n",
"total_val = 5000\n",
"print('Training reduced to: ', total_train)\n",
"print('Validation reduced to: ', total_val)"
]
@ -206,7 +207,6 @@
"metadata": {},
"outputs": [],
"source": [
"# This function will plot images in the form of a grid with 1 row and 3 columns where images are placed in each column.\n",
"def plotImages(images_arr):\n",
" fig, axes = plt.subplots(1, 3, figsize=(20,20))\n",
" axes = axes.flatten()\n",
@ -265,8 +265,6 @@
"metadata": {},
"outputs": [],
"source": [
"#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
"#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)\n",
"tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True, embeddings_freq=1, update_freq='batch')"
]
},
@ -295,11 +293,8 @@
"metadata": {},
"outputs": [],
"source": [
"#wutoptimizer = 'adam'\n",
"wutoptimizer = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=True)\n",
"\n",
"wutloss = 'binary_crossentropy'\n",
"#wutmetrics = 'accuracy'\n",
"wutmetrics = ['accuracy']"
]
},
@ -361,7 +356,6 @@
"metadata": {},
"outputs": [],
"source": [
"# Need ~64 gigs RAM+, 20 gig disk\n",
"history = model.fit(\n",
" train_data_gen,\n",
" steps_per_epoch=total_train // batch_size,\n",
@ -404,6 +398,10 @@
"\n",
"epochs_range = range(epochs)\n",
"\n",
"plot_file=(\"wut-plot-\" + ENCODING + \".png\")\n",
"save_path_plot = os.path.join('/srv/satnogs/data/models/', ENCODING, plot_file)\n",
"print(save_path_plot)\n",
"\n",
"plt.figure(figsize=(8, 8))\n",
"plt.subplot(1, 2, 1)\n",
"plt.plot(epochs_range, acc, label='Training Accuracy')\n",
@ -416,6 +414,7 @@
"plt.plot(epochs_range, val_loss, label='Validation Loss')\n",
"plt.legend(loc='upper right')\n",
"plt.title('Training and Validation Loss')\n",
"plt.savefig(save_path_plot)\n",
"plt.show()"
]
},