use tf.keras, fix batch size

master
root 2022-05-29 20:25:38 -06:00
parent 8c7b80083b
commit 9c6ab722e5
1 changed file with 98 additions and 29 deletions

@@ -24,20 +24,40 @@
"from __future__ import print_function\n",
"import os\n",
"import datetime\n",
"import numpy as np\n",
"import keras\n",
"from keras import Sequential\n",
"from keras.layers import Activation, Dropout, Flatten, Dense\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D\n",
"from keras import optimizers\n",
"from keras.preprocessing import image\n",
"from keras.models import load_model\n",
"#from keras.preprocessing.image import load_img\n",
"#from keras.preprocessing.image import img_to_array\n",
"from keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D\n",
"from keras.models import Model\n",
"from keras.layers import Input, concatenate\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"\n",
"from tensorflow.keras import Sequential\n",
"from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense\n",
"from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
"from tensorflow.keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D\n",
"from tensorflow.keras import optimizers\n",
"from tensorflow.keras.preprocessing import image\n",
"from tensorflow.keras.models import load_model\n",
"from tensorflow.keras.preprocessing.image import load_img\n",
"from tensorflow.keras.preprocessing.image import img_to_array\n",
"from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.layers import Input, concatenate\n",
"from tensorflow.keras.utils import model_to_dot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualization\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
@@ -50,9 +70,8 @@
"from ipywidgets import interact, interactive, fixed, interact_manual\n",
"import ipywidgets as widgets\n",
"# Display Images\n",
"\n",
"\n",
"from IPython.display import display, Image"
"from IPython.display import display, Image\n",
"from IPython.display import SVG"
]
},
{
@@ -62,8 +81,11 @@
"outputs": [],
"source": [
"ENCODING='GMSK'\n",
"batch_size = 128\n",
"batch_size = 64\n",
"epochs = 4\n",
"# Failing with this now:\n",
"#batch_size = 128\n",
"#epochs = 4\n",
"IMG_WIDTH = 416\n",
"IMG_HEIGHT = 803"
]
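
The commented-out values above are the settings that were failing; the commit halves the batch size. A rough sketch of why the batch dimension matters at this input size, assuming the 416x803 images are decoded as float32 RGB (an assumption, the diff itself does not say):

# Back-of-envelope input-tensor sizing; float32 RGB is an assumption, not stated in the notebook.
IMG_WIDTH, IMG_HEIGHT, CHANNELS = 416, 803, 3
BYTES_PER_FLOAT32 = 4

def batch_input_mib(batch_size):
    return batch_size * IMG_WIDTH * IMG_HEIGHT * CHANNELS * BYTES_PER_FLOAT32 / 2**20

print(batch_input_mib(128))  # ~489 MiB of raw input per batch
print(batch_input_mib(64))   # ~245 MiB; conv-layer activations shrink proportionally as well
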
@@ -208,6 +230,17 @@
"plotImages(sample_val_images[0:3])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# If you need to kill tensorboad, when it says stuff like this:\n",
"# Reusing TensorBoard on port 6006 (pid 13650), started 0:04:20 ago. (Use '!kill 13650' to kill it.)\n",
"#!rm -rf /tmp/.tensorboard-info/"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -218,10 +251,18 @@
"!rm -rf ./logs/\n",
"os.mkdir(\"logs\")\n",
"log_dir = \"logs\"\n",
"#log_dir=\"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
"#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir)\n",
"tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True, embeddings_freq=1, update_freq='batch')"
"#log_dir=\"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
"#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)\n",
"tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True, embeddings_freq=1, update_freq='batch')"
]
},
{
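
The "Reusing TensorBoard on port 6006" message quoted a few cells up comes from TensorBoard's notebook extension, so the logs written by this callback are presumably viewed with the magics below (assumed usage; the invocation itself is not part of this diff):

# Assumed TensorBoard-in-notebook invocation; not shown in the diff.
%load_ext tensorboard
%tensorboard --logdir logs
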
@@ -301,15 +342,30 @@
"metadata": {},
"outputs": [],
"source": [
"print(train_data_gen)\n",
"print(total_train)\n",
"print(batch_size)\n",
"print(epochs)\n",
"print(val_data_gen)\n",
"print(total_val)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Need ~64 gigs RAM+, 20 gig disk\n",
"history = model.fit(\n",
" train_data_gen,\n",
" steps_per_epoch=total_train // batch_size,\n",
" epochs=epochs,\n",
" verbose=1,\n",
" callbacks=[tensorboard_callback],\n",
" validation_data=val_data_gen,\n",
" validation_steps=total_val // batch_size,\n",
" shuffle=True,\n",
" callbacks=[tensorboard_callback],\n",
" use_multiprocessing=False\n",
")"
]
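
Worth noting about the fit call above: steps_per_epoch and validation_steps use floor division, so any remainder smaller than one batch is skipped each epoch. With made-up counts (the real total_train / total_val come from the generators earlier in the notebook):

# Illustrative numbers only; total_train is computed elsewhere in the notebook.
total_train, batch_size = 1000, 64
steps_per_epoch = total_train // batch_size            # 15 full batches per epoch
skipped = total_train - steps_per_epoch * batch_size   # 40 images not used this epoch
print(steps_per_epoch, skipped)
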
@@ -320,9 +376,24 @@
"metadata": {},
"outputs": [],
"source": [
"acc = history.history['accuracy']\n",
"val_acc = history.history['val_accuracy']\n",
"\n",
"acc = history.history['accuracy']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val_acc = history.history['val_accuracy']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss = history.history['loss']\n",
"val_loss = history.history['val_loss']\n",
"\n",
@@ -401,9 +472,7 @@
"metadata": {},
"outputs": [],
"source": [
"#from IPython.display import SVG\n",
"#from tensorflow.keras.utils import model_to_dot\n",
"#SVG(model_to_dot(model).create(prog='dot', format='svg'))"
"SVG(model_to_dot(model).create(prog='dot', format='svg'))"
]
}
],
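
The final cell now renders the architecture inline via model_to_dot; that path needs pydot and a system Graphviz install, which the diff does not mention. A minimal, assumed setup for reference:

# Assumed prerequisites for SVG(model_to_dot(...)):
#   pip install pydot
#   apt-get install graphviz   (or the platform equivalent)
from tensorflow.keras.utils import model_to_dot
from IPython.display import SVG

def render_model(model):
    # Inline SVG of the model graph; requires pydot plus Graphviz's 'dot' binary.
    return SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))
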