use tf keras, fix batch size
parent
8c7b80083b
commit
9c6ab722e5
|
@ -24,20 +24,40 @@
|
|||
"from __future__ import print_function\n",
|
||||
"import os\n",
|
||||
"import datetime\n",
|
||||
"import numpy as np\n",
|
||||
"import keras\n",
|
||||
"from keras import Sequential\n",
|
||||
"from keras.layers import Activation, Dropout, Flatten, Dense\n",
|
||||
"from keras.preprocessing.image import ImageDataGenerator\n",
|
||||
"from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D\n",
|
||||
"from keras import optimizers\n",
|
||||
"from keras.preprocessing import image\n",
|
||||
"from keras.models import load_model\n",
|
||||
"#from keras.preprocessing.image import load_img\n",
|
||||
"#from keras.preprocessing.image import img_to_array\n",
|
||||
"from keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D\n",
|
||||
"from keras.models import Model\n",
|
||||
"from keras.layers import Input, concatenate\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tensorflow as tf\n",
|
||||
"from tensorflow import keras\n",
|
||||
"from tensorflow.keras import layers\n",
|
||||
"\n",
|
||||
"from tensorflow.keras import Sequential\n",
|
||||
"from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense\n",
|
||||
"from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
|
||||
"from tensorflow.keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D\n",
|
||||
"from tensorflow.keras import optimizers\n",
|
||||
"from tensorflow.keras.preprocessing import image\n",
|
||||
"from tensorflow.keras.models import load_model\n",
|
||||
"from tensorflow.keras.preprocessing.image import load_img\n",
|
||||
"from tensorflow.keras.preprocessing.image import img_to_array\n",
|
||||
"from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D\n",
|
||||
"from tensorflow.keras.models import Model\n",
|
||||
"from tensorflow.keras.layers import Input, concatenate\n",
|
||||
"from tensorflow.keras.utils import model_to_dot"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Visualization\n",
|
||||
"%matplotlib inline\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
|
@ -50,9 +70,8 @@
|
|||
"from ipywidgets import interact, interactive, fixed, interact_manual\n",
|
||||
"import ipywidgets as widgets\n",
|
||||
"# Display Images\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from IPython.display import display, Image"
|
||||
"from IPython.display import display, Image\n",
|
||||
"from IPython.display import SVG"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -62,8 +81,11 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"ENCODING='GMSK'\n",
|
||||
"batch_size = 128\n",
|
||||
"batch_size = 64\n",
|
||||
"epochs = 4\n",
|
||||
"# Failing with this now:\n",
|
||||
"#batch_size = 128\n",
|
||||
"#epochs = 4\n",
|
||||
"IMG_WIDTH = 416\n",
|
||||
"IMG_HEIGHT = 803"
|
||||
]
|
||||
|
@ -208,6 +230,17 @@
|
|||
"plotImages(sample_val_images[0:3])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# If you need to kill tensorboad, when it says stuff like this:\n",
|
||||
"# Reusing TensorBoard on port 6006 (pid 13650), started 0:04:20 ago. (Use '!kill 13650' to kill it.)\n",
|
||||
"#!rm -rf /tmp/.tensorboard-info/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
@ -218,10 +251,18 @@
|
|||
"!rm -rf ./logs/\n",
|
||||
"os.mkdir(\"logs\")\n",
|
||||
"log_dir = \"logs\"\n",
|
||||
"#log_dir=\"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
|
||||
"#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
|
||||
"#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir)\n",
|
||||
"tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True, embeddings_freq=1, update_freq='batch')"
|
||||
"#log_dir=\"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
|
||||
"#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)\n",
|
||||
"tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True, embeddings_freq=1, update_freq='batch')"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -301,15 +342,30 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(train_data_gen)\n",
|
||||
"print(total_train)\n",
|
||||
"print(batch_size)\n",
|
||||
"print(epochs)\n",
|
||||
"print(val_data_gen)\n",
|
||||
"print(total_val)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Need ~64 gigs RAM+, 20 gig disk\n",
|
||||
"history = model.fit(\n",
|
||||
" train_data_gen,\n",
|
||||
" steps_per_epoch=total_train // batch_size,\n",
|
||||
" epochs=epochs,\n",
|
||||
" verbose=1,\n",
|
||||
" callbacks=[tensorboard_callback],\n",
|
||||
" validation_data=val_data_gen,\n",
|
||||
" validation_steps=total_val // batch_size,\n",
|
||||
" shuffle=True,\n",
|
||||
" callbacks=[tensorboard_callback],\n",
|
||||
" use_multiprocessing=False\n",
|
||||
")"
|
||||
]
|
||||
|
@ -320,9 +376,24 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"acc = history.history['accuracy']\n",
|
||||
"val_acc = history.history['val_accuracy']\n",
|
||||
"\n",
|
||||
"acc = history.history['accuracy']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"val_acc = history.history['val_accuracy']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loss = history.history['loss']\n",
|
||||
"val_loss = history.history['val_loss']\n",
|
||||
"\n",
|
||||
|
@ -401,9 +472,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#from IPython.display import SVG\n",
|
||||
"#from tensorflow.keras.utils import model_to_dot\n",
|
||||
"#SVG(model_to_dot(model).create(prog='dot', format='svg'))"
|
||||
"SVG(model_to_dot(model).create(prog='dot', format='svg'))"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
Loading…
Reference in New Issue