still meh fit() dist

master
ml server 2020-01-18 21:19:56 -07:00
parent cc31af1062
commit 68b2fc2730
2 changed files with 57 additions and 36 deletions

View File

@ -117,6 +117,15 @@
"tf.keras.backend.clear_session()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"options = tf.data.Options()"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -210,17 +219,29 @@
"val_image_generator = ImageDataGenerator(\n",
" rescale=1./255\n",
")\n",
"train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,\n",
"#train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,\n",
"train_data_gen = train_image_generator.flow_from_directory(batch_size=GLOBAL_BATCH_SIZE,\n",
" directory=train_dir,\n",
" shuffle=True,\n",
" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
" class_mode='binary')\n",
"val_data_gen = val_image_generator.flow_from_directory(batch_size=batch_size,\n",
"#val_data_gen = val_image_generator.flow_from_directory(batch_size=batch_size,\n",
"val_data_gen = val_image_generator.flow_from_directory(batch_size=GLOBAL_BATCH_SIZE,\n",
" directory=val_dir,\n",
" target_size=(IMG_HEIGHT, IMG_WIDTH),\n",
" class_mode='binary')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#train_dist_dataset = strategy.experimental_distribute_dataset()\n",
"#val_dist_dataset = strategy.experimental_distribute_dataset()"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -264,16 +285,24 @@
"metadata": {},
"outputs": [],
"source": [
"#strategy.num_replicas_in_sync\n",
"strategy.num_replicas_in_sync"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Compute global batch size using number of replicas.\n",
"#BATCH_SIZE_PER_REPLICA = 5\n",
"#print(BATCH_SIZE_PER_REPLICA)\n",
"#global_batch_size = (BATCH_SIZE_PER_REPLICA *\n",
"# strategy.num_replicas_in_sync)\n",
"#print(global_batch_size)\n",
"#dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)\n",
"#dataset = dataset.batch(global_batch_size)\n",
"#LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}"
"BATCH_SIZE_PER_REPLICA = 5\n",
"print(BATCH_SIZE_PER_REPLICA)\n",
"global_batch_size = (BATCH_SIZE_PER_REPLICA *\n",
" strategy.num_replicas_in_sync)\n",
"print(global_batch_size)\n",
"dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)\n",
"dataset = dataset.batch(global_batch_size)\n",
"LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}"
]
},
{
@ -338,6 +367,17 @@
"#model = get_compiled_model()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a checkpoint directory to store the checkpoints.\n",
"checkpoint_dir = './training_checkpoints'\n",
"checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@ -41,9 +41,15 @@ IMG_WIDTH= 804
batch_size = 32
epochs = 4
BUFFER_SIZE = 10000
NUM_WORKERS = 6
GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS
# XXX
#tf.keras.backend.clear_session()
options = tf.data.Options()
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
tf.distribute.experimental.CollectiveCommunication.RING)
@ -112,31 +118,6 @@ def get_compiled_model():
metrics=['accuracy'])
return model
#def get_fit_model():
# model = get_compiled_model()
# model.fit(
# train_data_gen,
# steps_per_epoch=total_train // batch_size,
# epochs=epochs,
# validation_data=val_data_gen,
# validation_steps=total_val // batch_size,
# verbose=2
# )
# return model
#with strategy.scope():
# get_uncompiled_model()
#with strategy.scope():
# get_compiled_model()
#with strategy.scope():
# get_fit_model()
#multi_worker_model = get_compiled_model()
#multi_worker_model.fit(
# x=train_data_gen,
# epochs=epochs,
# steps_per_epoch=total_train // batch_size
# )
with strategy.scope():
model = get_compiled_model()