From 6a6c008fbb202edc9854c2ad93ae4613b2320de2 Mon Sep 17 00:00:00 2001 From: ml server Date: Wed, 15 Jan 2020 15:05:14 -0700 Subject: [PATCH] training with result, sorta --- jupyter/wut-ml.ipynb | 402 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 378 insertions(+), 24 deletions(-) diff --git a/jupyter/wut-ml.ipynb b/jupyter/wut-ml.ipynb index 22a1027..017031a 100644 --- a/jupyter/wut-ml.ipynb +++ b/jupyter/wut-ml.ipynb @@ -227,7 +227,8 @@ "source": [ "train_dir = os.path.join('data/', 'train')\n", "val_dir = os.path.join('data/', 'val')\n", - "test_dir = os.path.join('data/', 'test/unvetted')" + "#test_dir = os.path.join('data/', 'test/unvetted')\n", + "test_dir = os.path.join('data/', 'test')" ] }, { @@ -259,7 +260,7 @@ "outputs": [], "source": [ "#data/test/unvetted/waterfall.png\n", - "test_img = os.path.join(test_dir, 'waterfall.png')" + "#test_img = os.path.join(test_dir, 'waterfall.png')" ] }, { @@ -338,8 +339,8 @@ "outputs": [], "source": [ "print(\"Reduce training and validation set\")\n", - "total_train = 50\n", - "total_val = 50\n", + "total_train = 100\n", + "total_val = 100\n", "print(\"Train =\")\n", "print(total_train)\n", "print(\"Validation =\")\n", @@ -352,8 +353,9 @@ "metadata": {}, "outputs": [], "source": [ - "print(test_img)\n", - "display(Image(test_img))" + "#print(test_img)\n", + "#test_img = os.path.join(test_dir, 'waterfall.png')\n", + "display(Image(os.path.join(test_dir, 'unvetted/waterfall.png')))" ] }, { @@ -362,7 +364,7 @@ "metadata": {}, "outputs": [], "source": [ - "batch_size = 16\n", + "batch_size = 64\n", "epochs = 2" ] }, @@ -402,6 +404,17 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_image_generator = ImageDataGenerator(\n", + " rescale=1./255\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -427,6 +440,27 @@ " class_mode='binary')" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "print(test_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_data_gen = test_image_generator.flow_from_directory(batch_size=batch_size,\n", + " directory=test_dir,\n", + " target_size=(IMG_HEIGHT, IMG_WIDTH),\n", + " class_mode='binary')" + ] + }, { "cell_type": "code", "execution_count": null, @@ -451,7 +485,7 @@ "metadata": {}, "outputs": [], "source": [ - "#sample_test_images, _ = next(test_data_gen)" + "sample_test_images, _ = next(test_data_gen)" ] }, { @@ -477,7 +511,7 @@ "metadata": {}, "outputs": [], "source": [ - "plotImages(sample_train_images[:3])" + "plotImages(sample_train_images[0:3])" ] }, { @@ -486,7 +520,16 @@ "metadata": {}, "outputs": [], "source": [ - "plotImages(sample_val_images[:3])" + "plotImages(sample_val_images[0:3])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotImages(sample_test_images[0:1])" ] }, { @@ -564,7 +607,6 @@ "loss = history.history['loss']\n", "val_loss = history.history['val_loss']\n", "\n", - "''\n", "epochs_range = range(epochs)\n", "\n", "plt.figure(figsize=(8, 8))\n", @@ -587,7 +629,9 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "print(\"TRAINING\")" + ] }, { "cell_type": "code", @@ -632,7 +676,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(sample_train_images,)" + "print(sample_train_images)" ] }, { @@ -650,7 +694,7 @@ "metadata": {}, "outputs": [], "source": [ - "test_generator.reset()" + "#test_generator.reset()" ] }, { @@ -659,7 +703,7 @@ "metadata": {}, "outputs": [], "source": [ - "test_datagen=ImageDataGenerator(rescale=1./255.)" + "#test_datagen=ImageDataGenerator(rescale=1./255.)" ] }, { @@ -668,11 +712,11 @@ "metadata": {}, "outputs": [], "source": [ - "test_generator=test_datagen.flow_from_directory(\n", - " directory=\"data/test/\",\n", - " target_size=(IMG_HEIGHT, 
IMG_WIDTH),\n", - " class_mode='binary'\n", - ")" + "#test_generator=test_datagen.flow_from_directory(\n", + "# directory=\"data/test/\",\n", + "# target_size=(IMG_HEIGHT, IMG_WIDTH),\n", + "# class_mode='binary'\n", + "#)" ] }, { @@ -691,7 +735,7 @@ "metadata": {}, "outputs": [], "source": [ - "pred=model.predict_generator(test_generator,\n", + "pred=model.predict_generator(test_data_gen,\n", "steps=4,\n", "verbose=1)" ] @@ -703,7 +747,7 @@ "outputs": [], "source": [ "prediction = model.predict(\n", - " x=test_generator,\n", + " x=test_data_gen,\n", " verbose=2\n", ")\n", "print(\"end predict\")" @@ -721,7 +765,9 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "predictions=[]" + ] }, { "cell_type": "code", @@ -733,6 +779,313 @@ "print(prediction)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#prediction_bool = (prediction >0.5)\n", + "prediction_bool = (prediction == 1)\n", + "print(prediction_bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictions = prediction_bool.astype(int)\n", + "print(predictions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# output type 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictions = prediction_bool.astype(int)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "columns=[\"bad\", \"good\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#columns should be the same order of y_col\n", + "#results=pd.DataFrame(predictions, columns=columns)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#columns should be the 
same order of y_col\n", + "results=(predictions, columns)\n", + "print(results[0])\n", + "print(results[1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#results[\"Filenames\"]=test_data_gen.filenames" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ordered_cols=[\"Filenames\"]+columns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#results=results[ordered_cols] #To get the same column order" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#results.to_csv(\"results.csv\",index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", +
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# output type 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "labels = train_data_gen.class_indices" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "labels = dict((v,k) for k,v in labels.items())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "columns=[\"bad\", \"good\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictions=[]\n", + "for row in prediction_bool:\n", + "    l=[]\n", + "    for index,cls in enumerate(row):\n", + "        if cls: l.append(labels[index])\n", + "    predictions.append(\",\".join(l))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "filenames=test_data_gen.filenames" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results=pd.DataFrame({\"Filename\":filenames,\n", + " \"Predictions\":predictions})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#results.to_csv(\"results.csv\",index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count":
null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -740,7 +1093,8 @@ "outputs": [], "source": [ "# Make final prediction\n", - "if prediction[0][0] == 1:\n", + "#if prediction[0][0] == 0:\n", + "if prediction == [[0]]:\n", " rating = 'bad'\n", "else:\n", " rating = 'good'\n",