pep8/black format jupyter

main
jebba 2022-01-25 09:10:20 -07:00
parent 181ea45f15
commit 9a5958ec49
3 changed files with 126 additions and 99 deletions
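Context: the diff below is the mechanical result of running Black over the repository's Jupyter notebooks; the third changed file switches the dependency from black to black[jupyter], the extra that lets Black parse .ipynb files (typically installed with pip install 'black[jupyter]' and run as black <notebook>.ipynb, though the exact invocation is not recorded in this commit). As a minimal sketch of the kind of rewrites Black applies here (quote normalization, spacing around operators, exploding long argument lists), the snippet below runs one of the reformatted lines through Black's Python API:

    import black

    # A source line taken from the first hunk below, before formatting.
    src = "model = load_model('data/models/witzit-al.tf')\n"

    # black.format_str() applies the same rewrites as the CLI; Mode() is the default style.
    print(black.format_str(src, mode=black.Mode()), end="")
    # Prints: model = load_model("data/models/witzit-al.tf")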


@@ -165,7 +165,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model = load_model('data/models/witzit-al.tf')"
+"model = load_model(\"data/models/witzit-al.tf\")"
 ]
 },
 {
@@ -174,7 +174,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"test_dir = os.path.join('data/', 'test')"
+"test_dir = os.path.join(\"data/\", \"test\")"
 ]
 },
 {
@@ -193,8 +193,8 @@
 "outputs": [],
 "source": [
 "# Good results\n",
-"#batch_size = 128\n",
-"#epochs = 6\n",
+"# batch_size = 128\n",
+"# epochs = 6\n",
 "# Testing, faster more inaccurate results\n",
 "batch_size = 32\n",
 "epochs = 3"
@@ -208,10 +208,10 @@
 "source": [
 "# Half size\n",
 "IMG_HEIGHT = 416\n",
-"IMG_WIDTH= 804\n",
+"IMG_WIDTH = 804\n",
 "# Full size, machine barfs probably needs more RAM\n",
-"#IMG_HEIGHT = 832\n",
-"#IMG_WIDTH = 1606"
+"# IMG_HEIGHT = 832\n",
+"# IMG_WIDTH = 1606"
 ]
 },
 {
@@ -220,9 +220,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"test_image_generator = ImageDataGenerator(\n",
-" rescale=1./255\n",
-")"
+"test_image_generator = ImageDataGenerator(rescale=1.0 / 255)"
 ]
 },
 {
@@ -240,10 +240,12 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"test_data_gen = test_image_generator.flow_from_directory(batch_size=batch_size,\n",
-" directory=test_dir,\n",
-" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
-" class_mode='binary')"
+"test_data_gen = test_image_generator.flow_from_directory(\n",
+" batch_size=batch_size,\n",
+" directory=test_dir,\n",
+" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
+" class_mode=\"binary\",\n",
+")"
 ]
 },
 {
@@ -263,11 +263,11 @@
 "source": [
 "# This function will plot images in the form of a grid with 1 row and 3 columns where images are placed in each column.\n",
 "def plotImages(images_arr):\n",
-" fig, axes = plt.subplots(1, 3, figsize=(20,20))\n",
+" fig, axes = plt.subplots(1, 3, figsize=(20, 20))\n",
 " axes = axes.flatten()\n",
-" for img, ax in zip( images_arr, axes):\n",
+" for img, ax in zip(images_arr, axes):\n",
 " ax.imshow(img)\n",
-" ax.axis('off')\n",
+" ax.axis(\"off\")\n",
 " plt.tight_layout()\n",
 " plt.show()"
 ]
@@ -297,9 +297,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"#pred=model.predict_generator(test_data_gen,\n",
-"#steps=1,\n",
-"#verbose=1)"
+"# pred=model.predict_generator(test_data_gen,\n",
+"# steps=1,\n",
+"# verbose=1)"
 ]
 },
 {
@@ -308,10 +308,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"prediction = model.predict(\n",
-" x=test_data_gen,\n",
-" verbose=1\n",
-")\n",
+"prediction = model.predict(x=test_data_gen, verbose=1)\n",
 "print(\"end predict\")"
 ]
 },
@@ -321,7 +318,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"predictions=[]"
+"predictions = []"
 ]
 },
 {
@@ -340,7 +337,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"prediction_bool = (prediction >0.8)\n",
+"prediction_bool = prediction > 0.8\n",
 "print(prediction_bool)"
 ]
 },
@@ -363,10 +360,10 @@
 "# Make final prediction\n",
 "# XXX, display name, display all of them with mini waterfall, etc.\n",
 "if prediction_bool[0] == False:\n",
-" rating = 'bad'\n",
+" rating = \"bad\"\n",
 "else:\n",
-" rating = 'good'\n",
-"print('Observation: %s' % (rating))"
+" rating = \"good\"\n",
+"print(\"Observation: %s\" % (rating))"
 ]
 },
 {

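Because the notebook source above is stored as JSON-escaped strings, the reformatted code is hard to read in diff form. Condensed into plain Python, the prediction flow touched in this first notebook amounts to the sketch below; the paths, image size, batch size, and 0.8 threshold are copied from the hunks above, while the import lines are assumptions (they sit outside the shown context) and the final if/else is collapsed into a conditional expression:

    import os

    from tensorflow.keras.models import load_model
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    # "Half size" dimensions and the faster test batch size from the cells above.
    IMG_HEIGHT = 416
    IMG_WIDTH = 804
    batch_size = 32

    model = load_model("data/models/witzit-al.tf")
    test_dir = os.path.join("data/", "test")

    test_image_generator = ImageDataGenerator(rescale=1.0 / 255)
    test_data_gen = test_image_generator.flow_from_directory(
        batch_size=batch_size,
        directory=test_dir,
        target_size=(IMG_HEIGHT, IMG_WIDTH),
        class_mode="binary",
    )

    # Predict on the test generator and label the first observation.
    prediction = model.predict(x=test_data_gen, verbose=1)
    prediction_bool = prediction > 0.8
    rating = "good" if prediction_bool[0] else "bad"
    print("Observation: %s" % rating)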

@@ -55,17 +55,21 @@
 "from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D\n",
 "from tensorflow.python.keras.models import Model\n",
 "from tensorflow.python.keras.layers import Input, concatenate\n",
+"\n",
 "# Visualization\n",
 "%matplotlib inline\n",
 "import matplotlib.pyplot as plt\n",
 "import numpy as np\n",
 "from sklearn.decomposition import PCA\n",
+"\n",
 "# Seaborn pip dependency\n",
 "import seaborn as sns\n",
+"\n",
 "# Interact\n",
 "# https://ipywidgets.readthedocs.io/en/stable/examples/Using%20Interact.html\n",
 "from ipywidgets import interact, interactive, fixed, interact_manual\n",
 "import ipywidgets as widgets\n",
+"\n",
 "# Display Images\n",
 "from IPython.display import display, Image"
 ]
@@ -76,7 +80,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"ELEMENT='al'\n",
+"ELEMENT = \"al\"\n",
 "batch_size = 128\n",
 "epochs = 4\n",
 "IMG_WIDTH = 416\n",
@@ -89,13 +93,13 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"train_dir = os.path.join('/srv/witzit/data/element', ELEMENT )\n",
-"train_dir = os.path.join('/srv/witzit/data/element', ELEMENT, 'train')\n",
-"val_dir = os.path.join('/srv/witzit/data/element', ELEMENT, 'val')\n",
-"train_good_dir = os.path.join(train_dir, 'good')\n",
-"train_bad_dir = os.path.join(train_dir, 'bad')\n",
-"val_good_dir = os.path.join(val_dir, 'good')\n",
-"val_bad_dir = os.path.join(val_dir, 'bad')\n",
+"train_dir = os.path.join(\"/srv/witzit/data/element\", ELEMENT)\n",
+"train_dir = os.path.join(\"/srv/witzit/data/element\", ELEMENT, \"train\")\n",
+"val_dir = os.path.join(\"/srv/witzit/data/element\", ELEMENT, \"val\")\n",
+"train_good_dir = os.path.join(train_dir, \"good\")\n",
+"train_bad_dir = os.path.join(train_dir, \"bad\")\n",
+"val_good_dir = os.path.join(val_dir, \"good\")\n",
+"val_bad_dir = os.path.join(val_dir, \"bad\")\n",
 "num_train_good = len(os.listdir(train_good_dir))\n",
 "num_train_bad = len(os.listdir(train_bad_dir))\n",
 "num_val_good = len(os.listdir(val_good_dir))\n",
@@ -110,12 +114,12 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print('total training good images:', num_train_good)\n",
-"print('total training bad images:', num_train_bad)\n",
+"print(\"total training good images:\", num_train_good)\n",
+"print(\"total training bad images:\", num_train_bad)\n",
 "print(\"--\")\n",
 "print(\"Total training images:\", total_train)\n",
-"print('total validation good images:', num_val_good)\n",
-"print('total validation bad images:', num_val_bad)\n",
+"print(\"total validation good images:\", num_val_good)\n",
+"print(\"total validation bad images:\", num_val_bad)\n",
 "print(\"--\")\n",
 "print(\"Total validation images:\", total_val)\n",
 "print(\"Reduce training and validation set when testing\")\n",
@@ -133,7 +137,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"train_image_generator = ImageDataGenerator( rescale=1./255 )"
+"train_image_generator = ImageDataGenerator(rescale=1.0 / 255)"
 ]
 },
 {
@@ -142,7 +146,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"val_image_generator = ImageDataGenerator( rescale=1./255 )"
+"val_image_generator = ImageDataGenerator(rescale=1.0 / 255)"
 ]
 },
 {
@@ -151,11 +155,13 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,\n",
-" directory=train_dir,\n",
-" shuffle=True,\n",
-" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
-" class_mode='binary')"
+"train_data_gen = train_image_generator.flow_from_directory(\n",
+" batch_size=batch_size,\n",
+" directory=train_dir,\n",
+" shuffle=True,\n",
+" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
+" class_mode=\"binary\",\n",
+")"
 ]
 },
 {
@@ -164,10 +170,12 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"val_data_gen = val_image_generator.flow_from_directory(batch_size=batch_size,\n",
-" directory=val_dir,\n",
-" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
-" class_mode='binary')"
+"val_data_gen = val_image_generator.flow_from_directory(\n",
+" batch_size=batch_size,\n",
+" directory=val_dir,\n",
+" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
+" class_mode=\"binary\",\n",
+")"
 ]
 },
 {
@@ -196,11 +204,11 @@
 "source": [
 "# This function will plot images in the form of a grid with 1 row and 3 columns where images are placed in each column.\n",
 "def plotImages(images_arr):\n",
-" fig, axes = plt.subplots(1, 3, figsize=(20,20))\n",
+" fig, axes = plt.subplots(1, 3, figsize=(20, 20))\n",
 " axes = axes.flatten()\n",
-" for img, ax in zip( images_arr, axes):\n",
+" for img, ax in zip(images_arr, axes):\n",
 " ax.imshow(img)\n",
-" ax.axis('off')\n",
+" ax.axis(\"off\")\n",
 " plt.tight_layout()\n",
 " plt.show()"
 ]
@@ -233,10 +241,17 @@
 "!rm -rf ./logs/\n",
 "os.mkdir(\"logs\")\n",
 "log_dir = \"logs\"\n",
-"#log_dir=\"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
-"#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
-"#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir)\n",
-"tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True, write_images=True, embeddings_freq=1, update_freq='batch')"
+"# log_dir=\"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
+"# tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
+"# tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir)\n",
+"tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(\n",
+" log_dir=log_dir,\n",
+" histogram_freq=1,\n",
+" write_graph=True,\n",
+" write_images=True,\n",
+" embeddings_freq=1,\n",
+" update_freq=\"batch\",\n",
+")"
 ]
 },
 {
@@ -245,17 +260,25 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model = Sequential([\n",
-" Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH ,3)),\n",
-" MaxPooling2D(),\n",
-" Conv2D(32, 3, padding='same', activation='relu'),\n",
-" MaxPooling2D(),\n",
-" Conv2D(64, 3, padding='same', activation='relu'),\n",
-" MaxPooling2D(),\n",
-" Flatten(),\n",
-" Dense(512, activation='relu'),\n",
-" Dense(1, activation='sigmoid')\n",
-"])"
+"model = Sequential(\n",
+" [\n",
+" Conv2D(\n",
+" 16,\n",
+" 3,\n",
+" padding=\"same\",\n",
+" activation=\"relu\",\n",
+" input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),\n",
+" ),\n",
+" MaxPooling2D(),\n",
+" Conv2D(32, 3, padding=\"same\", activation=\"relu\"),\n",
+" MaxPooling2D(),\n",
+" Conv2D(64, 3, padding=\"same\", activation=\"relu\"),\n",
+" MaxPooling2D(),\n",
+" Flatten(),\n",
+" Dense(512, activation=\"relu\"),\n",
+" Dense(1, activation=\"sigmoid\"),\n",
+" ]\n",
+")"
 ]
 },
 {
@@ -264,14 +287,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"#witzitoptimizer = 'adam'\n",
-"witzitoptimizer = tensorflow.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=True)\n",
-"witzitloss = 'binary_crossentropy'\n",
-"#witzitmetrics = 'accuracy'\n",
-"witzitmetrics = ['accuracy']\n",
-"model.compile(optimizer=witzitoptimizer,\n",
-" loss=witzitloss,\n",
-" metrics=[witzitmetrics])"
+"# witzitoptimizer = 'adam'\n",
+"witzitoptimizer = tensorflow.keras.optimizers.Adam(\n",
+" learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=True\n",
+")\n",
+"witzitloss = \"binary_crossentropy\"\n",
+"# witzitmetrics = 'accuracy'\n",
+"witzitmetrics = [\"accuracy\"]\n",
+"model.compile(optimizer=witzitoptimizer, loss=witzitloss, metrics=[witzitmetrics])"
 ]
 },
 {
@@ -316,7 +339,7 @@
 " validation_data=val_data_gen,\n",
 " validation_steps=total_val // batch_size,\n",
 " shuffle=True,\n",
-" use_multiprocessing=False\n",
+" use_multiprocessing=False,\n",
 ")"
 ]
 },
@@ -326,26 +349,26 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"acc = history.history['accuracy']\n",
-"val_acc = history.history['val_accuracy']\n",
+"acc = history.history[\"accuracy\"]\n",
+"val_acc = history.history[\"val_accuracy\"]\n",
 "\n",
-"loss = history.history['loss']\n",
-"val_loss = history.history['val_loss']\n",
+"loss = history.history[\"loss\"]\n",
+"val_loss = history.history[\"val_loss\"]\n",
 "\n",
 "epochs_range = range(epochs)\n",
 "\n",
 "plt.figure(figsize=(8, 8))\n",
 "plt.subplot(1, 2, 1)\n",
-"plt.plot(epochs_range, acc, label='Training Accuracy')\n",
-"plt.plot(epochs_range, val_acc, label='Validation Accuracy')\n",
-"plt.legend(loc='lower right')\n",
-"plt.title('Training and Validation Accuracy')\n",
+"plt.plot(epochs_range, acc, label=\"Training Accuracy\")\n",
+"plt.plot(epochs_range, val_acc, label=\"Validation Accuracy\")\n",
+"plt.legend(loc=\"lower right\")\n",
+"plt.title(\"Training and Validation Accuracy\")\n",
 "\n",
 "plt.subplot(1, 2, 2)\n",
-"plt.plot(epochs_range, loss, label='Training Loss')\n",
-"plt.plot(epochs_range, val_loss, label='Validation Loss')\n",
-"plt.legend(loc='upper right')\n",
-"plt.title('Training and Validation Loss')\n",
+"plt.plot(epochs_range, loss, label=\"Training Loss\")\n",
+"plt.plot(epochs_range, val_loss, label=\"Validation Loss\")\n",
+"plt.legend(loc=\"upper right\")\n",
+"plt.title(\"Training and Validation Loss\")\n",
 "plt.show()"
 ]
 },
@@ -361,7 +384,7 @@
 "print(train_bad_dir)\n",
 "print(train_image_generator)\n",
 "print(train_data_gen)\n",
-"#print(sample_train_images)\n",
+"# print(sample_train_images)\n",
 "print(history)"
 ]
 },
@@ -371,7 +394,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model.save('/srv/witzit/data/models/al/witzit-al.h5')"
+"model.save(\"/srv/witzit/data/models/al/witzit-al.h5\")"
 ]
 },
 {
@@ -380,7 +403,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model.save('/srv/witzit/data/models/al/witzit-al.tf')"
+"model.save(\"/srv/witzit/data/models/al/witzit-al.tf\")"
 ]
 },
 {
@@ -398,7 +421,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"plot_model(model, show_shapes=True, show_layer_names=True, expand_nested=True, dpi=72, to_file='/srv/witzit/data/models/al/plot_model.png')"
+"plot_model(\n",
+" model,\n",
+" show_shapes=True,\n",
+" show_layer_names=True,\n",
+" expand_nested=True,\n",
+" dpi=72,\n",
+" to_file=\"/srv/witzit/data/models/al/plot_model.png\",\n",
+")"
 ]
 },
 {
@@ -407,9 +437,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"#from IPython.display import SVG\n",
-"#from tensorflow.keras.utils import model_to_dot\n",
-"#SVG(model_to_dot(model).create(prog='dot', format='svg'))"
+"# from IPython.display import SVG\n",
+"# from tensorflow.keras.utils import model_to_dot\n",
+"# SVG(model_to_dot(model).create(prog='dot', format='svg'))"
 ]
 }
 ],

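As with the first notebook, the training code is easier to follow outside the JSON escaping. The sketch below condenses the reformatted model definition and compile step from the hunks above; the layer sizes and Adam hyperparameters are copied verbatim, while the imports and the IMG_HEIGHT value are assumptions (only IMG_WIDTH = 416 appears in the shown context), and a flat metrics list is used where the notebook passes the nested metrics=[witzitmetrics]:

    import tensorflow
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D

    # Assumed square input; the diff only shows IMG_WIDTH = 416 in context.
    IMG_HEIGHT = 416
    IMG_WIDTH = 416

    # Three Conv2D/MaxPooling2D blocks followed by a dense head with a sigmoid
    # output for the binary good/bad classification.
    model = Sequential(
        [
            Conv2D(
                16,
                3,
                padding="same",
                activation="relu",
                input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
            ),
            MaxPooling2D(),
            Conv2D(32, 3, padding="same", activation="relu"),
            MaxPooling2D(),
            Conv2D(64, 3, padding="same", activation="relu"),
            MaxPooling2D(),
            Flatten(),
            Dense(512, activation="relu"),
            Dense(1, activation="sigmoid"),
        ]
    )

    witzitoptimizer = tensorflow.keras.optimizers.Adam(
        learning_rate=0.001, beta_1=0.9, beta_2=0.999, amsgrad=True
    )
    witzitloss = "binary_crossentropy"
    witzitmetrics = ["accuracy"]
    model.compile(optimizer=witzitoptimizer, loss=witzitloss, metrics=witzitmetrics)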

@@ -1 +1 @@
-black
+black[jupyter]