diff --git a/notebooks/wut-web-dev.ipynb b/notebooks/wut-web-dev.ipynb index 09c950f..4cd1d44 100644 --- a/notebooks/wut-web-dev.ipynb +++ b/notebooks/wut-web-dev.ipynb @@ -93,6 +93,63 @@ "model = load_model(model_file)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "def gen_image(test_data_gen,test_dir):\n", + " test_image_gen = ImageDataGenerator(rescale=1./255);\n", + " test_data_gen = test_image_gen.flow_from_directory(batch_size=1,\n", + " directory=test_dir,\n", + " target_size=(IMG_HEIGHT, IMG_WIDTH),\n", + " shuffle=True,\n", + " class_mode='binary');\n", + " return test_data_gen" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture --no-stderr --no-stdout\n", + "def gen_image_tmp(obs_waterfall_path):\n", + " tmp_dir = tempfile.mkdtemp()\n", + " test_dir = os.path.join(tmp_dir)\n", + " os.makedirs(test_dir + '/unvetted', exist_ok=True)\n", + " shutil.copy(obs_waterfall_path, test_dir + '/unvetted/') \n", + " \n", + " img = im.open(obs_waterfall_path).resize( (100,200))\n", + " display(img)\n", + "\n", + " # XXX delete tmp dir down below\n", + " #print(test_dir)\n", + " #shutil.rmtree(test_dir) \n", + " return test_dir" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "def obs_wutsay(test_data_gen):\n", + " prediction = model.predict(\n", + " x=test_data_gen,\n", + " verbose=0)\n", + " predictions=[]\n", + " prediction_bool = (prediction >0.8)\n", + " predictions = prediction_bool.astype(int)\n", + " \n", + " return prediction_bool" + ] + }, { "cell_type": "code", "execution_count": null, @@ -118,26 +175,12 @@ " \n", " obs_waterfall=os.path.basename(obs_waterfallurl)\n", " obs_waterfall_path = os.path.join('/srv/satnogs/download', str(obs_id), obs_waterfall)\n", + " \n", + " test_dir=gen_image_tmp(obs_waterfall_path)\n", + " test_data_gen=gen_image(obs_waterfall_path, test_dir)\n", " \n", - " tmp_dir = tempfile.mkdtemp()\n", - " test_dir = os.path.join(tmp_dir)\n", - " os.makedirs(test_dir + '/unvetted', exist_ok=True)\n", - " shutil.copy(obs_waterfall_path, test_dir + '/unvetted/')\n", - " img = im.open(obs_waterfall_path).resize( (100,200) )\n", - " test_image_generator = ImageDataGenerator(\n", - " rescale=1./255\n", - " )\n", - " test_data_gen = test_image_generator.flow_from_directory(batch_size=1,\n", - " directory=test_dir,\n", - " target_size=(IMG_HEIGHT, IMG_WIDTH),\n", - " shuffle=True,\n", - " class_mode='binary')\n", - " prediction = model.predict(\n", - " x=test_data_gen,\n", - " verbose=0)\n", - " predictions=[]\n", - " prediction_bool = (prediction >0.8)\n", - " predictions = prediction_bool.astype(int)\n", + " prediction_bool=obs_wutsay(test_data_gen);\n", + "\n", " print()\n", " print('Observation ID: ', obs_id)\n", " print('Encoding: ', obs_transmitter_mode)\n", @@ -146,19 +189,13 @@ " rating = 'bad'\n", " else:\n", " rating = 'good'\n", - " print('wut AI rating: %s' % (rating))\n", - " \n", + " print('wut AI rating: %s' % (rating)) \n", " print()\n", " if obs_transmitter_mode == 'DUV':\n", " print(\"Using DUV training model.\")\n", " else:\n", " print(\"NOTE: wut has not been trained on\", obs_transmitter_mode, \"encodings.\")\n", - "\n", " print('https://network.satnogs.org/observations/' + str(obs_id))\n", - " print()\n", - " \n", - " display(img)\n", - " shutil.rmtree(test_dir)\n", " #!cat $obsjsonfile" ] }, @@ -168,7 +205,7 @@ "metadata": {}, "outputs": [], "source": [ - "print('Enter an Observation ID between', minobsid, 'and', maxobsid)\n", + "print('Enter an Observation ID between', minobsid, 'and', maxobsid);\n", "wutObs_slide = wg.IntText(value='1292461');\n", "wg.interact(wutObs, datObs=wutObs_slide);" ]