414 lines
9.7 KiB
Python
414 lines
9.7 KiB
Python
#!/usr/bin/env python3
|
|
#
|
|
# wut-train-cluster-fn.py
|
|
#
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# wut-train-cluster --- What U Think? SatNOGS Observation AI, training application cluster edition.
|
|
#
|
|
# https://spacecruft.org/spacecruft/satnogs-wut
|
|
#
|
|
# Builds a wut.tf file from the data/train and data/val directories.
|
|
# GPLv3+
|
|
# Built using Jupyter, Tensorflow, Keras
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
from __future__ import print_function
|
|
import os
|
|
import numpy as np
|
|
import simplejson as json
|
|
import datetime
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
import tensorflow as tf
|
|
import tensorflow.python.keras
|
|
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
|
|
from tensorflow.python.keras import optimizers
|
|
from tensorflow.python.keras import Sequential
|
|
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
|
|
from tensorflow.python.keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D
|
|
from tensorflow.python.keras.layers import Input, concatenate
|
|
from tensorflow.python.keras.models import load_model
|
|
from tensorflow.python.keras.models import Model
|
|
from tensorflow.python.keras.preprocessing import image
|
|
from tensorflow.python.keras.preprocessing.image import img_to_array
|
|
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
|
|
from tensorflow.python.keras.preprocessing.image import load_img
|
|
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Turn on inline plotting when running under IPython/Jupyter. `get_ipython`
# only exists there, so a plain `python3` run of this script skips the magic
# instead of stopping with a NameError.
try:
    _ipython = get_ipython()
except NameError:
    _ipython = None
if _ipython is not None:
    _ipython.run_line_magic('matplotlib', 'inline')

import matplotlib.pyplot as plt
import seaborn as sns
|
|
#from sklearn.decomposition import PCA
|
|
#from ipywidgets import interact, interactive, fixed, interact_manual
|
|
#import ipywidgets as widgets
|
|
#from IPython.display import display, Image
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Report the TensorFlow version this run is using.
print('tf {}'.format(tf.__version__))
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Publish the cluster layout through the TF_CONFIG environment variable:
# five worker hosts, with this process registered as worker index 0.
_tf_config = {
    "cluster": {
        "worker": [
            "ml1-int:2222",
            "ml2-int:2222",
            "ml3-int:2222",
            "ml4-int:2222",
            "ml5-int:2222",
        ],
    },
    "task": {"type": "worker", "index": 0},
    "num_workers": 5,
}
os.environ["TF_CONFIG"] = json.dumps(_tf_config)
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Size (height, width) that images are resized to before entering the network.
IMG_HEIGHT, IMG_WIDTH = 416, 804
# Generator batch size and number of training epochs.
batch_size, epochs = 32, 1
|
|
# Full size, machine barfs probably needs more RAM
|
|
#IMG_HEIGHT = 832
|
|
#IMG_WIDTH = 1606
|
|
# Good results
|
|
#batch_size = 128
|
|
#epochs = 6
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Reset Keras' global state before building a new model in this session.
tf.keras.backend.clear_session()
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# tf.data options selecting the DATA auto-shard policy (the OFF policy is kept
# below as a disabled alternative).
# NOTE(review): nothing visible in this file applies `options` to a dataset;
# the `dataset.with_options(options)` call further down is commented out.
options = tf.data.Options()
#options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.DATA
|
|
# XXX
|
|
#dataset = dataset.with_options(options)
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Multi-worker synchronous training strategy, using RING collective
# communication between the workers.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=tf.distribute.experimental.CollectiveCommunication.RING,
)
|
|
|
|
#mirrored_strategy = tf.distribute.MirroredStrategy(
|
|
# cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
root_data_dir = '/srv/satnogs'

# Training set: <root>/data/train/{good,bad}; one directory entry is counted
# per sample.
train_dir = os.path.join(root_data_dir, 'data/', 'train')
train_good_dir = os.path.join(train_dir, 'good')
train_bad_dir = os.path.join(train_dir, 'bad')
num_train_good = len(os.listdir(train_good_dir))
num_train_bad = len(os.listdir(train_bad_dir))
total_train = num_train_good + num_train_bad

# Validation set: <root>/data/val/{good,bad}, counted the same way.
val_dir = os.path.join(root_data_dir, 'data/', 'val')
val_good_dir = os.path.join(val_dir, 'good')
val_bad_dir = os.path.join(val_dir, 'bad')
num_val_good = len(os.listdir(val_good_dir))
num_val_bad = len(os.listdir(val_bad_dir))
total_val = num_val_good + num_val_bad
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Summarize how many samples were found in each class directory.
for _caption, _count in (
    ('total training good images:', num_train_good),
    ('total training bad images:', num_train_bad),
):
    print(_caption, _count)
print("--")
print("Total training images:", total_train)

for _caption, _count in (
    ('total validation good images:', num_val_good),
    ('total validation bad images:', num_val_bad),
):
    print(_caption, _count)
print("--")
print("Total validation images:", total_val)
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Testing shortcut: override the counted totals with a small fixed number so
# the steps-per-epoch values derived from them stay small.
print("--")
print("Reduce training and validation set when testing")
total_train = total_val = 100
print("Reduced training images:", total_train)
print("Reduced validation images:", total_val)
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Both generators only rescale pixel values by 1/255; no augmentation is set.
train_image_generator = ImageDataGenerator(rescale=1./255)
val_image_generator = ImageDataGenerator(rescale=1./255)

# Stream labelled batches from the class sub-directories. Only the training
# stream passes shuffle=True explicitly; validation uses the library default.
#train_data_gen = train_image_generator.flow_from_directory(batch_size=GLOBAL_BATCH_SIZE,
train_data_gen = train_image_generator.flow_from_directory(
    directory=train_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=batch_size,
    class_mode='binary',
    shuffle=True,
)
#val_data_gen = val_image_generator.flow_from_directory(batch_size=GLOBAL_BATCH_SIZE,
val_data_gen = val_image_generator.flow_from_directory(
    directory=val_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=batch_size,
    class_mode='binary',
)
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
#train_dist_dataset = strategy.experimental_distribute_dataset()
|
|
#val_dist_dataset = strategy.experimental_distribute_dataset()
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Pull one batch from each generator for a visual sanity check; the labels
# returned alongside the images are discarded.
sample_train_images, _ = next(train_data_gen)
sample_val_images, _ = next(val_data_gen)
|
|
def plotImages(images_arr):
    """Display images side by side in a grid of 1 row and 3 columns.

    images_arr: iterable of images to draw with ``imshow``. Only three axes
    are created, so at most the first three images are shown.
    """
    _, grid = plt.subplots(1, 3, figsize=(20, 20))
    for picture, axis in zip(images_arr, grid.flatten()):
        axis.imshow(picture)
        axis.axis('off')  # hide ticks and frame; only the image matters
    plt.tight_layout()
    plt.show()
|
|
|
|
# Preview the first three images of the sampled training and validation batches.
plotImages(sample_train_images[:3])
plotImages(sample_val_images[:3])
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
import shutil

# Load the TensorBoard notebook extension when running under IPython/Jupyter.
# `get_ipython` only exists there, so a plain `python3` run skips this step
# instead of stopping with a NameError.
try:
    _ipython = get_ipython()
except NameError:
    _ipython = None
if _ipython is not None:
    _ipython.run_line_magic('load_ext', 'tensorboard')

# Start from an empty log directory. ignore_errors=True mirrors `rm -rf`,
# which does not complain when the directory is missing.
shutil.rmtree('./clusterlogs/', ignore_errors=True)

#log_dir="clusterlogs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "clusterlogs"
#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
# Reach the callback through the `tf` alias rather than the bare `tensorflow`
# name, which is only bound as a side effect of `import tensorflow.python.keras`.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
|
|
#%tensorboard --logdir clusterlogs --port 6006
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Bare expression: in a notebook cell this displays the replica count; run as
# a script it has no visible effect.
strategy.num_replicas_in_sync
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
## Compute global batch size using number of replicas.
|
|
#GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS
|
|
# Per-replica batch size, scaled by the replica count to get the global batch.
BATCH_SIZE_PER_REPLICA = 8
print("BATCH_SIZE_PER_REPLICA", BATCH_SIZE_PER_REPLICA)
print("strategy.num_replicas_in_sync", strategy.num_replicas_in_sync)

global_batch_size = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
print("global_batch_size", global_batch_size)

# Dataset totals and the step counts derived from the generator batch size.
print("total_train", total_train)
print("total_val ", total_val)
print("batch_size", batch_size)
print("total_train // batch_size", total_train // batch_size)
print("total_val // batch_size", total_val // batch_size)
|
|
#.batch(global_batch_size)
|
|
#dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)
|
|
#dataset = dataset.batch(global_batch_size)
|
|
#LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
#learning_rate = LEARNING_RATES_BY_BATCH_SIZE[global_batch_size]
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
def get_uncompiled_model():
    """Build the CNN classifier without compiling it.

    The network takes (IMG_HEIGHT, IMG_WIDTH, 3) inputs and stacks three
    3x3 'same'-padded ReLU convolutions (16, 32, 64 filters), each followed
    by max pooling, then Flatten, a 512-unit ReLU Dense layer and a single
    sigmoid output unit.

    Returns the uncompiled Sequential model.
    """
    input_shape = (IMG_HEIGHT, IMG_WIDTH, 3)
    feature_layers = [
        Conv2D(16, 3, padding='same', activation='relu', input_shape=input_shape),
        MaxPooling2D(),
        Conv2D(32, 3, padding='same', activation='relu'),
        MaxPooling2D(),
        Conv2D(64, 3, padding='same', activation='relu'),
        MaxPooling2D(),
    ]
    classifier_layers = [
        Flatten(),
        Dense(512, activation='relu'),
        Dense(1, activation='sigmoid'),
    ]
    return Sequential(feature_layers + classifier_layers)
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
#get_uncompiled_model()
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
def get_compiled_model():
    """Return the CNN from get_uncompiled_model(), compiled for training.

    Compiles with the 'adam' optimizer, binary cross-entropy loss and the
    'accuracy' metric.
    """
    compiled = get_uncompiled_model()
    compiled.compile(
        loss='binary_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    return compiled
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Create a checkpoint directory to store the checkpoints.
|
|
#checkpoint_dir = './training_checkpoints'
|
|
#checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
#callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='tmp/keras-ckpt')]
|
|
#callbacks=[tensorboard_callback,callbacks]
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
#def get_fit_model():
|
|
# model = get_compiled_model()
|
|
# model.fit(
|
|
# train_data_gen,
|
|
# steps_per_epoch=total_train // batch_size,
|
|
# epochs=epochs,
|
|
# validation_data=val_data_gen,
|
|
# validation_steps=total_val // batch_size,
|
|
# verbose=2
|
|
# )
|
|
#return model
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Build and compile the model inside the distribution strategy's scope, then
# train it from the directory generators.
with strategy.scope():
    model = get_compiled_model()
    # fit() returns a History object, which the plotting code below reads via
    # history.history. Batching is a property of the input data, so nothing
    # is chained onto this return value.
    history = model.fit(
        train_data_gen,
        steps_per_epoch=total_train // batch_size,
        epochs=epochs,
        validation_data=val_data_gen,
        validation_steps=total_val // batch_size,
        verbose=2,
    )
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
#model.summary()
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Dump the training input locations and generator objects for debugging.
print("TRAINING info")
for _item in (
    train_dir,
    train_good_dir,
    train_bad_dir,
    train_image_generator,
    train_data_gen,
):
    print(_item)
|
|
#print(sample_train_images)
|
|
#print(history)
|
|
#model.to_json()
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
#history = model.fit(X, y, batch_size=32, epochs=40, validation_split=0.1)
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Make sure the output directory exists before writing: the HDF5 writer does
# not create missing parent directories, so the .h5 saves would otherwise fail.
os.makedirs('data/models/FOO', exist_ok=True)

# Full model (architecture + weights + optimizer state), in both formats.
model.save('data/models/FOO/wut-train-cluster2.tf')
model.save('data/models/FOO/wut-train-cluster2.h5')

# Weights only, again in both formats.
model.save_weights('data/models/FOO/wut-weights-train-cluster2.tf')
model.save_weights('data/models/FOO/wut-weights-train-cluster2.h5')
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# Per-epoch metrics recorded by fit().
acc, val_acc = history.history['accuracy'], history.history['val_accuracy']
loss, val_loss = history.history['loss'], history.history['val_loss']
epochs_range = range(epochs)

plt.figure(figsize=(8, 8))

# Left panel: training vs. validation accuracy.
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# Right panel: training vs. validation loss.
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
# The End
|
|
|