still meh fit() dist
parent
cc31af1062
commit
68b2fc2730
|
@ -117,6 +117,15 @@
|
||||||
"tf.keras.backend.clear_session()"
|
"tf.keras.backend.clear_session()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"options = tf.data.Options()"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
@ -210,17 +219,29 @@
|
||||||
"val_image_generator = ImageDataGenerator(\n",
|
"val_image_generator = ImageDataGenerator(\n",
|
||||||
" rescale=1./255\n",
|
" rescale=1./255\n",
|
||||||
")\n",
|
")\n",
|
||||||
"train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,\n",
|
"#train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,\n",
|
||||||
|
"train_data_gen = train_image_generator.flow_from_directory(batch_size=GLOBAL_BATCH_SIZE,\n",
|
||||||
" directory=train_dir,\n",
|
" directory=train_dir,\n",
|
||||||
" shuffle=True,\n",
|
" shuffle=True,\n",
|
||||||
" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
|
" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
|
||||||
" class_mode='binary')\n",
|
" class_mode='binary')\n",
|
||||||
"val_data_gen = val_image_generator.flow_from_directory(batch_size=batch_size,\n",
|
"#val_data_gen = val_image_generator.flow_from_directory(batch_size=batch_size,\n",
|
||||||
|
"val_data_gen = val_image_generator.flow_from_directory(batch_size=GLOBAL_BATCH_SIZE,\n",
|
||||||
" directory=val_dir,\n",
|
" directory=val_dir,\n",
|
||||||
" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
|
" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
|
||||||
" class_mode='binary')"
|
" class_mode='binary')"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"#train_dist_dataset = strategy.experimental_distribute_dataset()\n",
|
||||||
|
"#val_dist_dataset = strategy.experimental_distribute_dataset()"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
@ -264,16 +285,24 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"#strategy.num_replicas_in_sync\n",
|
"strategy.num_replicas_in_sync"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
"## Compute global batch size using number of replicas.\n",
|
"## Compute global batch size using number of replicas.\n",
|
||||||
"#BATCH_SIZE_PER_REPLICA = 5\n",
|
"BATCH_SIZE_PER_REPLICA = 5\n",
|
||||||
"#print(BATCH_SIZE_PER_REPLICA)\n",
|
"print(BATCH_SIZE_PER_REPLICA)\n",
|
||||||
"#global_batch_size = (BATCH_SIZE_PER_REPLICA *\n",
|
"global_batch_size = (BATCH_SIZE_PER_REPLICA *\n",
|
||||||
"# strategy.num_replicas_in_sync)\n",
|
" strategy.num_replicas_in_sync)\n",
|
||||||
"#print(global_batch_size)\n",
|
"print(global_batch_size)\n",
|
||||||
"#dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)\n",
|
"dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)\n",
|
||||||
"#dataset = dataset.batch(global_batch_size)\n",
|
"dataset = dataset.batch(global_batch_size)\n",
|
||||||
"#LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}"
|
"LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -338,6 +367,17 @@
|
||||||
"#model = get_compiled_model()"
|
"#model = get_compiled_model()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Create a checkpoint directory to store the checkpoints.\n",
|
||||||
|
"checkpoint_dir = './training_checkpoints'\n",
|
||||||
|
"checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
|
|
@ -41,9 +41,15 @@ IMG_WIDTH= 804
|
||||||
batch_size = 32
|
batch_size = 32
|
||||||
epochs = 4
|
epochs = 4
|
||||||
|
|
||||||
|
BUFFER_SIZE = 10000
|
||||||
|
NUM_WORKERS = 6
|
||||||
|
GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS
|
||||||
|
|
||||||
# XXX
|
# XXX
|
||||||
#tf.keras.backend.clear_session()
|
#tf.keras.backend.clear_session()
|
||||||
|
|
||||||
|
options = tf.data.Options()
|
||||||
|
|
||||||
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
|
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
|
||||||
tf.distribute.experimental.CollectiveCommunication.RING)
|
tf.distribute.experimental.CollectiveCommunication.RING)
|
||||||
|
|
||||||
|
@ -112,31 +118,6 @@ def get_compiled_model():
|
||||||
metrics=['accuracy'])
|
metrics=['accuracy'])
|
||||||
return model
|
return model
|
||||||
|
|
||||||
#def get_fit_model():
|
|
||||||
# model = get_compiled_model()
|
|
||||||
# model.fit(
|
|
||||||
# train_data_gen,
|
|
||||||
# steps_per_epoch=total_train // batch_size,
|
|
||||||
# epochs=epochs,
|
|
||||||
# validation_data=val_data_gen,
|
|
||||||
# validation_steps=total_val // batch_size,
|
|
||||||
# verbose=2
|
|
||||||
# )
|
|
||||||
# return model
|
|
||||||
|
|
||||||
#with strategy.scope():
|
|
||||||
# get_uncompiled_model()
|
|
||||||
#with strategy.scope():
|
|
||||||
# get_compiled_model()
|
|
||||||
#with strategy.scope():
|
|
||||||
# get_fit_model()
|
|
||||||
|
|
||||||
#multi_worker_model = get_compiled_model()
|
|
||||||
#multi_worker_model.fit(
|
|
||||||
# x=train_data_gen,
|
|
||||||
# epochs=epochs,
|
|
||||||
# steps_per_epoch=total_train // batch_size
|
|
||||||
# )
|
|
||||||
|
|
||||||
with strategy.scope():
|
with strategy.scope():
|
||||||
model = get_compiled_model()
|
model = get_compiled_model()
|
||||||
|
|
Loading…
Reference in New Issue