From 8dd443ab304f032612bf3b13257e551ec1cea615 Mon Sep 17 00:00:00 2001 From: Jeff Moe Date: Tue, 16 Aug 2022 22:21:15 -0600 Subject: [PATCH] Save plot of training accuracy --- notebooks/wut-train.ipynb | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/notebooks/wut-train.ipynb b/notebooks/wut-train.ipynb index c3269bb..be2d819 100644 --- a/notebooks/wut-train.ipynb +++ b/notebooks/wut-train.ipynb @@ -3,7 +3,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# wut-train --- What U Think? SatNOGS Observation AI, training application.\n", @@ -90,11 +92,10 @@ "outputs": [], "source": [ "ENCODING='FSK9k6'\n", - "batch_size = 64\n", - "epochs = 4\n", - "# Failing with this now:\n", - "#batch_size = 128\n", + "#batch_size = 64\n", "#epochs = 4\n", + "batch_size = 128\n", + "epochs = 4\n", "IMG_WIDTH = 416\n", "IMG_HEIGHT = 803" ] @@ -133,8 +134,8 @@ "print('Validation images: ', total_val)\n", "print('')\n", "print('Reduce training and validation set')\n", - "total_train = 1000\n", - "total_val = 1000\n", + "total_train = 5000\n", + "total_val = 5000\n", "print('Training reduced to: ', total_train)\n", "print('Validation reduced to: ', total_val)" ] @@ -206,7 +207,6 @@ "metadata": {}, "outputs": [], "source": [ - "# This function will plot images in the form of a grid with 1 row and 3 columns where images are placed in each column.\n", "def plotImages(images_arr):\n", " fig, axes = plt.subplots(1, 3, figsize=(20,20))\n", " axes = axes.flatten()\n", @@ -265,8 +265,6 @@ "metadata": {}, "outputs": [], "source": [ - "#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n", - "#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)\n", "tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True, embeddings_freq=1, update_freq='batch')" ] }, @@ -295,11 +293,8 @@ "metadata": {}, "outputs": [], "source": [ - "#wutoptimizer = 'adam'\n", "wutoptimizer = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=True)\n", - "\n", "wutloss = 'binary_crossentropy'\n", - "#wutmetrics = 'accuracy'\n", "wutmetrics = ['accuracy']" ] }, @@ -361,7 +356,6 @@ "metadata": {}, "outputs": [], "source": [ - "# Need ~64 gigs RAM+, 20 gig disk\n", "history = model.fit(\n", " train_data_gen,\n", " steps_per_epoch=total_train // batch_size,\n", @@ -404,6 +398,10 @@ "\n", "epochs_range = range(epochs)\n", "\n", + "plot_file=(\"wut-plot-\" + ENCODING + \".png\")\n", + "save_path_plot = os.path.join('/srv/satnogs/data/models/', ENCODING, plot_file)\n", + "print(save_path_plot)\n", + "\n", "plt.figure(figsize=(8, 8))\n", "plt.subplot(1, 2, 1)\n", "plt.plot(epochs_range, acc, label='Training Accuracy')\n", @@ -416,6 +414,7 @@ "plt.plot(epochs_range, val_loss, label='Validation Loss')\n", "plt.legend(loc='upper right')\n", "plt.title('Training and Validation Loss')\n", + "plt.savefig(save_path_plot)\n", "plt.show()" ] },