tensorboard train cluster

master
ml server 2020-01-17 23:20:36 -07:00
parent b085401e6e
commit cb693d6cc1
1 changed files with 17 additions and 21 deletions

View File

@ -26,7 +26,7 @@
"import os\n",
"import numpy as np\n",
"import simplejson as json\n",
"from datetime import datetime"
"import datetime"
]
},
{
@ -89,8 +89,8 @@
"source": [
"IMG_HEIGHT = 416\n",
"IMG_WIDTH= 804\n",
"batch_size = 1\n",
"epochs = 1\n",
"batch_size = 4\n",
"epochs = 4\n",
"# Full size, machine barfs probably needs more RAM\n",
"#IMG_HEIGHT = 832\n",
"#IMG_WIDTH = 1606\n",
@ -163,8 +163,8 @@
"source": [
"print(\"--\")\n",
"print(\"Reduce training and validation set when testing\")\n",
"total_train = 1\n",
"total_val = 1\n",
"total_train = 16\n",
"total_val = 16\n",
"print(\"Reduced training images:\", total_train)\n",
"print(\"Reduced validation images:\", total_val)"
]
@ -221,9 +221,18 @@
"outputs": [],
"source": [
"%load_ext tensorboard\n",
"#os.mkdir(\"cluster-logs\")\n",
"logdir = \"cluster-logs\"\n",
"tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=logdir)"
"!rm -rf ./cluster-logs/\n",
"log_dir=\"cluster-logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%tensorboard --logdir cluster-logs/fit"
]
},
{
@ -286,15 +295,6 @@
"Image.LOAD_TRUNCATED_IMAGES = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%tensorboard --logdir cluster-logs/"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -303,19 +303,15 @@
"source": [
"acc = history.history['accuracy']\n",
"val_acc = history.history['val_accuracy']\n",
"\n",
"loss = history.history['loss']\n",
"val_loss = history.history['val_loss']\n",
"\n",
"epochs_range = range(epochs)\n",
"\n",
"plt.figure(figsize=(8, 8))\n",
"plt.subplot(1, 2, 1)\n",
"plt.plot(epochs_range, acc, label='Training Accuracy')\n",
"plt.plot(epochs_range, val_acc, label='Validation Accuracy')\n",
"plt.legend(loc='lower right')\n",
"plt.title('Training and Validation Accuracy')\n",
"\n",
"plt.subplot(1, 2, 2)\n",
"plt.plot(epochs_range, loss, label='Training Loss')\n",
"plt.plot(epochs_range, val_loss, label='Validation Loss')\n",