use tf keras, fix batch size

master
root 2022-05-29 20:25:38 -06:00
parent 8c7b80083b
commit 9c6ab722e5
1 changed files with 98 additions and 29 deletions

View File

@ -24,20 +24,40 @@
"from __future__ import print_function\n", "from __future__ import print_function\n",
"import os\n", "import os\n",
"import datetime\n", "import datetime\n",
"import numpy as np\n", "import numpy as np"
"import keras\n", ]
"from keras import Sequential\n", },
"from keras.layers import Activation, Dropout, Flatten, Dense\n", {
"from keras.preprocessing.image import ImageDataGenerator\n", "cell_type": "code",
"from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D\n", "execution_count": null,
"from keras import optimizers\n", "metadata": {},
"from keras.preprocessing import image\n", "outputs": [],
"from keras.models import load_model\n", "source": [
"#from keras.preprocessing.image import load_img\n", "import tensorflow as tf\n",
"#from keras.preprocessing.image import img_to_array\n", "from tensorflow import keras\n",
"from keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D\n", "from tensorflow.keras import layers\n",
"from keras.models import Model\n", "\n",
"from keras.layers import Input, concatenate\n", "from tensorflow.keras import Sequential\n",
"from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense\n",
"from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
"from tensorflow.keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D\n",
"from tensorflow.keras import optimizers\n",
"from tensorflow.keras.preprocessing import image\n",
"from tensorflow.keras.models import load_model\n",
"from tensorflow.keras.preprocessing.image import load_img\n",
"from tensorflow.keras.preprocessing.image import img_to_array\n",
"from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.layers import Input, concatenate\n",
"from tensorflow.keras.utils import model_to_dot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualization\n", "# Visualization\n",
"%matplotlib inline\n", "%matplotlib inline\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
@ -50,9 +70,8 @@
"from ipywidgets import interact, interactive, fixed, interact_manual\n", "from ipywidgets import interact, interactive, fixed, interact_manual\n",
"import ipywidgets as widgets\n", "import ipywidgets as widgets\n",
"# Display Images\n", "# Display Images\n",
"\n", "from IPython.display import display, Image\n",
"\n", "from IPython.display import SVG"
"from IPython.display import display, Image"
] ]
}, },
{ {
@ -62,8 +81,11 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"ENCODING='GMSK'\n", "ENCODING='GMSK'\n",
"batch_size = 128\n", "batch_size = 64\n",
"epochs = 4\n", "epochs = 4\n",
"# Failing with this now:\n",
"#batch_size = 128\n",
"#epochs = 4\n",
"IMG_WIDTH = 416\n", "IMG_WIDTH = 416\n",
"IMG_HEIGHT = 803" "IMG_HEIGHT = 803"
] ]
@ -208,6 +230,17 @@
"plotImages(sample_val_images[0:3])" "plotImages(sample_val_images[0:3])"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# If you need to kill tensorboad, when it says stuff like this:\n",
"# Reusing TensorBoard on port 6006 (pid 13650), started 0:04:20 ago. (Use '!kill 13650' to kill it.)\n",
"#!rm -rf /tmp/.tensorboard-info/"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@ -218,10 +251,18 @@
"!rm -rf ./logs/\n", "!rm -rf ./logs/\n",
"os.mkdir(\"logs\")\n", "os.mkdir(\"logs\")\n",
"log_dir = \"logs\"\n", "log_dir = \"logs\"\n",
"#log_dir=\"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", "#log_dir=\"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")"
"#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n", ]
"#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir)\n", },
"tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True, embeddings_freq=1, update_freq='batch')" {
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
"#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)\n",
"tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True, embeddings_freq=1, update_freq='batch')"
] ]
}, },
{ {
@ -301,15 +342,30 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"print(train_data_gen)\n",
"print(total_train)\n",
"print(batch_size)\n",
"print(epochs)\n",
"print(val_data_gen)\n",
"print(total_val)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Need ~64 gigs RAM+, 20 gig disk\n",
"history = model.fit(\n", "history = model.fit(\n",
" train_data_gen,\n", " train_data_gen,\n",
" steps_per_epoch=total_train // batch_size,\n", " steps_per_epoch=total_train // batch_size,\n",
" epochs=epochs,\n", " epochs=epochs,\n",
" verbose=1,\n", " verbose=1,\n",
" callbacks=[tensorboard_callback],\n",
" validation_data=val_data_gen,\n", " validation_data=val_data_gen,\n",
" validation_steps=total_val // batch_size,\n", " validation_steps=total_val // batch_size,\n",
" shuffle=True,\n", " shuffle=True,\n",
" callbacks=[tensorboard_callback],\n",
" use_multiprocessing=False\n", " use_multiprocessing=False\n",
")" ")"
] ]
@ -320,9 +376,24 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"acc = history.history['accuracy']\n", "acc = history.history['accuracy']"
"val_acc = history.history['val_accuracy']\n", ]
"\n", },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val_acc = history.history['val_accuracy']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss = history.history['loss']\n", "loss = history.history['loss']\n",
"val_loss = history.history['val_loss']\n", "val_loss = history.history['val_loss']\n",
"\n", "\n",
@ -401,9 +472,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"#from IPython.display import SVG\n", "SVG(model_to_dot(model).create(prog='dot', format='svg'))"
"#from tensorflow.keras.utils import model_to_dot\n",
"#SVG(model_to_dot(model).create(prog='dot', format='svg'))"
] ]
} }
], ],