From 68b2fc2730c4e6ed200355820fbd337c5c9c9c4f Mon Sep 17 00:00:00 2001 From: ml server Date: Sat, 18 Jan 2020 21:19:56 -0700 Subject: [PATCH] still meh fit() dist --- jupyter/wut-train-cluster-fn.ipynb | 62 ++++++++++++++++++++++++------ wut-train-cluster-fn.py | 31 +++------------ 2 files changed, 57 insertions(+), 36 deletions(-) diff --git a/jupyter/wut-train-cluster-fn.ipynb b/jupyter/wut-train-cluster-fn.ipynb index 2ad5653..7f3296a 100644 --- a/jupyter/wut-train-cluster-fn.ipynb +++ b/jupyter/wut-train-cluster-fn.ipynb @@ -117,6 +117,15 @@ "tf.keras.backend.clear_session()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "options = tf.data.Options()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -210,17 +219,29 @@ "val_image_generator = ImageDataGenerator(\n", " rescale=1./255\n", ")\n", - "train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,\n", + "#train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,\n", + "train_data_gen = train_image_generator.flow_from_directory(batch_size=GLOBAL_BATCH_SIZE,\n", " directory=train_dir,\n", " shuffle=True,\n", " target_size=(IMG_HEIGHT, IMG_WIDTH),\n", " class_mode='binary')\n", - "val_data_gen = val_image_generator.flow_from_directory(batch_size=batch_size,\n", + "#val_data_gen = val_image_generator.flow_from_directory(batch_size=batch_size,\n", + "val_data_gen = val_image_generator.flow_from_directory(batch_size=GLOBAL_BATCH_SIZE,\n", " directory=val_dir,\n", " target_size=(IMG_HEIGHT, IMG_WIDTH),\n", " class_mode='binary')" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#train_dist_dataset = strategy.experimental_distribute_dataset()\n", + "#val_dist_dataset = strategy.experimental_distribute_dataset()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -264,16 +285,24 @@ "metadata": {}, "outputs": [], "source": [ - "#strategy.num_replicas_in_sync\n", + "strategy.num_replicas_in_sync" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "## Compute global batch size using number of replicas.\n", - "#BATCH_SIZE_PER_REPLICA = 5\n", - "#print(BATCH_SIZE_PER_REPLICA)\n", - "#global_batch_size = (BATCH_SIZE_PER_REPLICA *\n", - "# strategy.num_replicas_in_sync)\n", - "#print(global_batch_size)\n", - "#dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)\n", - "#dataset = dataset.batch(global_batch_size)\n", - "#LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}" + "BATCH_SIZE_PER_REPLICA = 5\n", + "print(BATCH_SIZE_PER_REPLICA)\n", + "global_batch_size = (BATCH_SIZE_PER_REPLICA *\n", + " strategy.num_replicas_in_sync)\n", + "print(global_batch_size)\n", + "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)\n", + "dataset = dataset.batch(global_batch_size)\n", + "LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}" ] }, { @@ -338,6 +367,17 @@ "#model = get_compiled_model()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a checkpoint directory to store the checkpoints.\n", + "checkpoint_dir = './training_checkpoints'\n", + "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/wut-train-cluster-fn.py b/wut-train-cluster-fn.py index bea3a56..db37826 100644 --- a/wut-train-cluster-fn.py +++ b/wut-train-cluster-fn.py @@ -41,9 +41,15 @@ IMG_WIDTH= 804 batch_size = 32 epochs = 4 +BUFFER_SIZE = 10000 +NUM_WORKERS = 6 +GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS + # XXX #tf.keras.backend.clear_session() +options = tf.data.Options() + strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy( tf.distribute.experimental.CollectiveCommunication.RING) @@ -112,31 +118,6 @@ def get_compiled_model(): metrics=['accuracy']) return model -#def get_fit_model(): -# model = get_compiled_model() -# model.fit( -# train_data_gen, -# steps_per_epoch=total_train // batch_size, -# epochs=epochs, -# validation_data=val_data_gen, -# validation_steps=total_val // batch_size, -# verbose=2 -# ) -# return model - -#with strategy.scope(): -# get_uncompiled_model() -#with strategy.scope(): -# get_compiled_model() -#with strategy.scope(): -# get_fit_model() - -#multi_worker_model = get_compiled_model() -#multi_worker_model.fit( -# x=train_data_gen, -# epochs=epochs, -# steps_per_epoch=total_train // batch_size -# ) with strategy.scope(): model = get_compiled_model()