Save plot of training accuracy
parent
33adccb2cb
commit
8dd443ab30
|
@ -3,7 +3,9 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# wut-train --- What U Think? SatNOGS Observation AI, training application.\n",
|
||||
|
@ -90,11 +92,10 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"ENCODING='FSK9k6'\n",
|
||||
"batch_size = 64\n",
|
||||
"epochs = 4\n",
|
||||
"# Failing with this now:\n",
|
||||
"#batch_size = 128\n",
|
||||
"#batch_size = 64\n",
|
||||
"#epochs = 4\n",
|
||||
"batch_size = 128\n",
|
||||
"epochs = 4\n",
|
||||
"IMG_WIDTH = 416\n",
|
||||
"IMG_HEIGHT = 803"
|
||||
]
|
||||
|
@ -133,8 +134,8 @@
|
|||
"print('Validation images: ', total_val)\n",
|
||||
"print('')\n",
|
||||
"print('Reduce training and validation set')\n",
|
||||
"total_train = 1000\n",
|
||||
"total_val = 1000\n",
|
||||
"total_train = 5000\n",
|
||||
"total_val = 5000\n",
|
||||
"print('Training reduced to: ', total_train)\n",
|
||||
"print('Validation reduced to: ', total_val)"
|
||||
]
|
||||
|
@ -206,7 +207,6 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This function will plot images in the form of a grid with 1 row and 3 columns where images are placed in each column.\n",
|
||||
"def plotImages(images_arr):\n",
|
||||
" fig, axes = plt.subplots(1, 3, figsize=(20,20))\n",
|
||||
" axes = axes.flatten()\n",
|
||||
|
@ -265,8 +265,6 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
|
||||
"#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)\n",
|
||||
"tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True, embeddings_freq=1, update_freq='batch')"
|
||||
]
|
||||
},
|
||||
|
@ -295,11 +293,8 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#wutoptimizer = 'adam'\n",
|
||||
"wutoptimizer = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=True)\n",
|
||||
"\n",
|
||||
"wutloss = 'binary_crossentropy'\n",
|
||||
"#wutmetrics = 'accuracy'\n",
|
||||
"wutmetrics = ['accuracy']"
|
||||
]
|
||||
},
|
||||
|
@ -361,7 +356,6 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Need ~64 gigs RAM+, 20 gig disk\n",
|
||||
"history = model.fit(\n",
|
||||
" train_data_gen,\n",
|
||||
" steps_per_epoch=total_train // batch_size,\n",
|
||||
|
@ -404,6 +398,10 @@
|
|||
"\n",
|
||||
"epochs_range = range(epochs)\n",
|
||||
"\n",
|
||||
"plot_file=(\"wut-plot-\" + ENCODING + \".png\")\n",
|
||||
"save_path_plot = os.path.join('/srv/satnogs/data/models/', ENCODING, plot_file)\n",
|
||||
"print(save_path_plot)\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(8, 8))\n",
|
||||
"plt.subplot(1, 2, 1)\n",
|
||||
"plt.plot(epochs_range, acc, label='Training Accuracy')\n",
|
||||
|
@ -416,6 +414,7 @@
|
|||
"plt.plot(epochs_range, val_loss, label='Validation Loss')\n",
|
||||
"plt.legend(loc='upper right')\n",
|
||||
"plt.title('Training and Validation Loss')\n",
|
||||
"plt.savefig(save_path_plot)\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
|
|
Loading…
Reference in New Issue