wut-train-cluster-fn.py from jupyter export...
parent
b945d4ab06
commit
9fa145ed19
|
@ -48,7 +48,8 @@
|
|||
"from tensorflow.python.keras.preprocessing import image\n",
|
||||
"from tensorflow.python.keras.preprocessing.image import img_to_array\n",
|
||||
"from tensorflow.python.keras.preprocessing.image import ImageDataGenerator\n",
|
||||
"from tensorflow.python.keras.preprocessing.image import load_img"
|
||||
"from tensorflow.python.keras.preprocessing.image import load_img\n",
|
||||
"from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -60,10 +61,10 @@
|
|||
"%matplotlib inline\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import seaborn as sns\n",
|
||||
"from sklearn.decomposition import PCA\n",
|
||||
"from ipywidgets import interact, interactive, fixed, interact_manual\n",
|
||||
"import ipywidgets as widgets\n",
|
||||
"from IPython.display import display, Image"
|
||||
"#from sklearn.decomposition import PCA\n",
|
||||
"#from ipywidgets import interact, interactive, fixed, interact_manual\n",
|
||||
"#import ipywidgets as widgets\n",
|
||||
"#from IPython.display import display, Image"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -86,7 +87,7 @@
|
|||
" \"worker\": [ \"ml0-int:2222\", \"ml1-int:2222\", \"ml2-int:2222\", \"ml3-int:2222\", \"ml4-int:2222\", \"ml5-int:2222\" ]\n",
|
||||
" },\n",
|
||||
" \"task\": {\"type\": \"worker\", \"index\": 0 },\n",
|
||||
" \"num_workers\": 40\n",
|
||||
" \"num_workers\": 6\n",
|
||||
"})"
|
||||
]
|
||||
},
|
||||
|
@ -99,7 +100,7 @@
|
|||
"IMG_HEIGHT = 416\n",
|
||||
"IMG_WIDTH= 804\n",
|
||||
"batch_size = 32\n",
|
||||
"epochs = 4\n",
|
||||
"epochs = 1\n",
|
||||
"# Full size, machine barfs probably needs more RAM\n",
|
||||
"#IMG_HEIGHT = 832\n",
|
||||
"#IMG_WIDTH = 1606\n",
|
||||
|
@ -123,7 +124,11 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"options = tf.data.Options()"
|
||||
"options = tf.data.Options()\n",
|
||||
"#options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF\n",
|
||||
"options.experimental_distribute.auto_shard_policy = AutoShardPolicy.DATA\n",
|
||||
"# XXX\n",
|
||||
"#dataset = dataset.with_options(options)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -133,7 +138,10 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(\n",
|
||||
" tf.distribute.experimental.CollectiveCommunication.RING)"
|
||||
" tf.distribute.experimental.CollectiveCommunication.RING)\n",
|
||||
"\n",
|
||||
"#mirrored_strategy = tf.distribute.MirroredStrategy(\n",
|
||||
"# cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -142,8 +150,9 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dir = os.path.join('data/', 'train')\n",
|
||||
"val_dir = os.path.join('data/', 'val')\n",
|
||||
"root_data_dir = ('/srv/satnogs')\n",
|
||||
"train_dir = os.path.join(root_data_dir, 'data/', 'train')\n",
|
||||
"val_dir = os.path.join(root_data_dir,'data/', 'val')\n",
|
||||
"train_good_dir = os.path.join(train_dir, 'good')\n",
|
||||
"train_bad_dir = os.path.join(train_dir, 'bad')\n",
|
||||
"val_good_dir = os.path.join(val_dir, 'good')\n",
|
||||
|
@ -180,8 +189,8 @@
|
|||
"source": [
|
||||
"print(\"--\")\n",
|
||||
"print(\"Reduce training and validation set when testing\")\n",
|
||||
"#total_train = 16\n",
|
||||
"#total_val = 16\n",
|
||||
"total_train = 100\n",
|
||||
"total_val = 100\n",
|
||||
"print(\"Reduced training images:\", total_train)\n",
|
||||
"print(\"Reduced validation images:\", total_val)"
|
||||
]
|
||||
|
@ -255,7 +264,7 @@
|
|||
"log_dir=\"clusterlogs\"\n",
|
||||
"#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
|
||||
"tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir)\n",
|
||||
"%tensorboard --logdir clusterlogs --port 6006"
|
||||
"#%tensorboard --logdir clusterlogs --port 6006"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -274,14 +283,22 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"## Compute global batch size using number of replicas.\n",
|
||||
"BATCH_SIZE_PER_REPLICA = 5\n",
|
||||
"print(BATCH_SIZE_PER_REPLICA)\n",
|
||||
"#GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS\n",
|
||||
"BATCH_SIZE_PER_REPLICA = 8\n",
|
||||
"print(\"BATCH_SIZE_PER_REPLICA\", BATCH_SIZE_PER_REPLICA)\n",
|
||||
"print(\"strategy.num_replicas_in_sync\", strategy.num_replicas_in_sync)\n",
|
||||
"global_batch_size = (BATCH_SIZE_PER_REPLICA *\n",
|
||||
" strategy.num_replicas_in_sync)\n",
|
||||
"print(global_batch_size)\n",
|
||||
"dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)\n",
|
||||
"dataset = dataset.batch(global_batch_size)\n",
|
||||
"LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}"
|
||||
"print(\"global_batch_size\", global_batch_size)\n",
|
||||
"print(\"total_train\", total_train)\n",
|
||||
"print(\"total_val \", total_val)\n",
|
||||
"print(\"batch_size\", batch_size)\n",
|
||||
"print(\"total_train // batch_size\", total_train // batch_size)\n",
|
||||
"print(\"total_val // batch_size\", total_val // batch_size)\n",
|
||||
"#.batch(global_batch_size)\n",
|
||||
"#dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)\n",
|
||||
"#dataset = dataset.batch(global_batch_size)\n",
|
||||
"#LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -385,14 +402,14 @@
|
|||
"source": [
|
||||
"with strategy.scope():\n",
|
||||
" model = get_compiled_model()\n",
|
||||
" model.fit(\n",
|
||||
" history = model.fit(\n",
|
||||
" train_data_gen,\n",
|
||||
" steps_per_epoch=total_train // batch_size,\n",
|
||||
" epochs=epochs,\n",
|
||||
" validation_data=val_data_gen,\n",
|
||||
" validation_steps=total_val // batch_size,\n",
|
||||
" verbose=2\n",
|
||||
" )"
|
||||
" ).batch(global_batch_size)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -50,10 +50,10 @@ from tensorflow.python.data.experimental.ops.distribute_options import AutoShard
|
|||
get_ipython().run_line_magic('matplotlib', 'inline')
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from sklearn.decomposition import PCA
|
||||
from ipywidgets import interact, interactive, fixed, interact_manual
|
||||
import ipywidgets as widgets
|
||||
from IPython.display import display, Image
|
||||
#from sklearn.decomposition import PCA
|
||||
#from ipywidgets import interact, interactive, fixed, interact_manual
|
||||
#import ipywidgets as widgets
|
||||
#from IPython.display import display, Image
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
|
Loading…
Reference in New Issue