still meh fit() dist
parent
cc31af1062
commit
68b2fc2730
|
@ -117,6 +117,15 @@
|
|||
"tf.keras.backend.clear_session()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"options = tf.data.Options()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
@ -210,17 +219,29 @@
|
|||
"val_image_generator = ImageDataGenerator(\n",
|
||||
" rescale=1./255\n",
|
||||
")\n",
|
||||
"train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,\n",
|
||||
"#train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,\n",
|
||||
"train_data_gen = train_image_generator.flow_from_directory(batch_size=GLOBAL_BATCH_SIZE,\n",
|
||||
" directory=train_dir,\n",
|
||||
" shuffle=True,\n",
|
||||
" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
|
||||
" class_mode='binary')\n",
|
||||
"val_data_gen = val_image_generator.flow_from_directory(batch_size=batch_size,\n",
|
||||
"#val_data_gen = val_image_generator.flow_from_directory(batch_size=batch_size,\n",
|
||||
"val_data_gen = val_image_generator.flow_from_directory(batch_size=GLOBAL_BATCH_SIZE,\n",
|
||||
" directory=val_dir,\n",
|
||||
" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
|
||||
" class_mode='binary')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#train_dist_dataset = strategy.experimental_distribute_dataset()\n",
|
||||
"#val_dist_dataset = strategy.experimental_distribute_dataset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
@ -264,16 +285,24 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#strategy.num_replicas_in_sync\n",
|
||||
"strategy.num_replicas_in_sync"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"## Compute global batch size using number of replicas.\n",
|
||||
"#BATCH_SIZE_PER_REPLICA = 5\n",
|
||||
"#print(BATCH_SIZE_PER_REPLICA)\n",
|
||||
"#global_batch_size = (BATCH_SIZE_PER_REPLICA *\n",
|
||||
"# strategy.num_replicas_in_sync)\n",
|
||||
"#print(global_batch_size)\n",
|
||||
"#dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)\n",
|
||||
"#dataset = dataset.batch(global_batch_size)\n",
|
||||
"#LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}"
|
||||
"BATCH_SIZE_PER_REPLICA = 5\n",
|
||||
"print(BATCH_SIZE_PER_REPLICA)\n",
|
||||
"global_batch_size = (BATCH_SIZE_PER_REPLICA *\n",
|
||||
" strategy.num_replicas_in_sync)\n",
|
||||
"print(global_batch_size)\n",
|
||||
"dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)\n",
|
||||
"dataset = dataset.batch(global_batch_size)\n",
|
||||
"LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -338,6 +367,17 @@
|
|||
"#model = get_compiled_model()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create a checkpoint directory to store the checkpoints.\n",
|
||||
"checkpoint_dir = './training_checkpoints'\n",
|
||||
"checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
|
|
@ -41,9 +41,15 @@ IMG_WIDTH= 804
|
|||
batch_size = 32
|
||||
epochs = 4
|
||||
|
||||
BUFFER_SIZE = 10000
|
||||
NUM_WORKERS = 6
|
||||
GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS
|
||||
|
||||
# XXX
|
||||
#tf.keras.backend.clear_session()
|
||||
|
||||
options = tf.data.Options()
|
||||
|
||||
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
|
||||
tf.distribute.experimental.CollectiveCommunication.RING)
|
||||
|
||||
|
@ -112,31 +118,6 @@ def get_compiled_model():
|
|||
metrics=['accuracy'])
|
||||
return model
|
||||
|
||||
#def get_fit_model():
|
||||
# model = get_compiled_model()
|
||||
# model.fit(
|
||||
# train_data_gen,
|
||||
# steps_per_epoch=total_train // batch_size,
|
||||
# epochs=epochs,
|
||||
# validation_data=val_data_gen,
|
||||
# validation_steps=total_val // batch_size,
|
||||
# verbose=2
|
||||
# )
|
||||
# return model
|
||||
|
||||
#with strategy.scope():
|
||||
# get_uncompiled_model()
|
||||
#with strategy.scope():
|
||||
# get_compiled_model()
|
||||
#with strategy.scope():
|
||||
# get_fit_model()
|
||||
|
||||
#multi_worker_model = get_compiled_model()
|
||||
#multi_worker_model.fit(
|
||||
# x=train_data_gen,
|
||||
# epochs=epochs,
|
||||
# steps_per_epoch=total_train // batch_size
|
||||
# )
|
||||
|
||||
with strategy.scope():
|
||||
model = get_compiled_model()
|
||||
|
|
Loading…
Reference in New Issue