satnogs-wut/wut-worker.py

256 lines
9.3 KiB
Python

#!/usr/bin/env python3
#
# wut-worker.py
#
# wut --- What U Think? SatNOGS Observation AI, training application cluster edition.
#
# https://spacecruft.org/spacecruft/satnogs-wut
#
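# Trains a good/bad classifier on the data/train and data/val directories
# (under /srv/satnogs) and saves the model and its weights (.tf and .h5)
# under data/models/FOO/.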
# GPLv3+
# Built using Jupyter, Tensorflow, Keras
from __future__ import absolute_import, division, print_function, unicode_literals
from __future__ import print_function
import os
import numpy as np
import simplejson as json
import datetime
import tensorflow as tf
import tensorflow.python.keras
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.python.keras import optimizers
from tensorflow.python.keras import Sequential
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
from tensorflow.python.keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D
from tensorflow.python.keras.layers import Input, concatenate
from tensorflow.python.keras.models import load_model
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.preprocessing import image
from tensorflow.python.keras.preprocessing.image import img_to_array
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.preprocessing.image import load_img
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
# Enable inline matplotlib output only when running under IPython/Jupyter.
# get_ipython() is injected by IPython; in a plain `python3 wut-worker.py`
# run it does not exist and the unguarded call raised NameError.
try:
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    pass  # not running under IPython: keep matplotlib's default backend
import matplotlib.pyplot as plt
import seaborn as sns
#from sklearn.decomposition import PCA
#from ipywidgets import interact, interactive, fixed, interact_manual
#import ipywidgets as widgets
#from IPython.display import display, Image
# Log the TensorFlow version in use (output: "tf <version>").
print('tf', tf.__version__)
# Cluster description consumed by MultiWorkerMirroredStrategy via TF_CONFIG.
# Every host listed below runs this same script, and each one needs its own
# task index. With a hard-coded 0, every host claimed to be worker 0 (the
# chief). The index can now be set per host with the WUT_WORKER_INDEX
# environment variable. It defaults to 0, the previous hard-coded value.
_cluster_workers = [
    "ml1-int:2222",
    "ml2-int:2222",
    "ml3-int:2222",
    "ml4-int:2222",
    "ml5-int:2222",
]
_worker_index = int(os.environ.get("WUT_WORKER_INDEX", "0"))
if not 0 <= _worker_index < len(_cluster_workers):
    raise ValueError(
        "WUT_WORKER_INDEX must be in [0, %d), got %d"
        % (len(_cluster_workers), _worker_index))
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": _cluster_workers
    },
    "task": {"type": "worker", "index": _worker_index},
    # NOTE(review): not a standard TF_CONFIG field as far as I know; kept for
    # compatibility and now derived from the worker list.
    "num_workers": len(_cluster_workers),
})
# Network input resolution; flow_from_directory() resizes images to this.
IMG_HEIGHT, IMG_WIDTH = 416, 804
# Training settings for this (short, test-sized) run.
batch_size = 32
epochs = 1
# Alternatives kept for reference:
#   full resolution IMG_HEIGHT = 832, IMG_WIDTH = 1606 -- the machine ran
#     out of memory; probably needs more RAM.
#   batch_size = 128, epochs = 6 -- gave good results.
# Start from a fresh Keras session state.
tf.keras.backend.clear_session()

# tf.data options: auto-shard distributed input with the DATA policy.
# NOTE(review): `options` is never attached to a dataset. Training input comes
# from ImageDataGenerator, and the `dataset.with_options(options)` hook is
# commented out, so this setting currently has no effect.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.DATA

# Multi-worker strategy over the cluster described in TF_CONFIG, using RING
# collective communication between workers.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    tf.distribute.experimental.CollectiveCommunication.RING
)
# Expected layout: /srv/satnogs/data/{train,val}/{good,bad}/ -- one
# sub-directory per class label.
root_data_dir = '/srv/satnogs'
train_dir = os.path.join(root_data_dir, 'data/', 'train')
val_dir = os.path.join(root_data_dir, 'data/', 'val')
train_good_dir, train_bad_dir = (os.path.join(train_dir, label) for label in ('good', 'bad'))
val_good_dir, val_bad_dir = (os.path.join(val_dir, label) for label in ('good', 'bad'))

# Count directory entries per class and split.
num_train_good, num_train_bad, num_val_good, num_val_bad = (
    len(os.listdir(class_dir))
    for class_dir in (train_good_dir, train_bad_dir, val_good_dir, val_bad_dir)
)
total_train = num_train_good + num_train_bad
total_val = num_val_good + num_val_bad

print('total training good images:', num_train_good)
print('total training bad images:', num_train_bad)
print("--")
print("Total training images:", total_train)
print('total validation good images:', num_val_good)
print('total validation bad images:', num_val_bad)
print("--")
print("Total validation images:", total_val)
print("--")

# Testing shortcut: override the real totals so each epoch only runs a few
# steps (steps are computed from these totals when training).
print("Reduce training and validation set when testing")
total_train = total_val = 100
print("Reduced training images:", total_train)
print("Reduced validation images:", total_val)
# Both splits only rescale pixel values by 1/255; no augmentation is applied.
train_image_generator = ImageDataGenerator(rescale=1./255)
val_image_generator = ImageDataGenerator(rescale=1./255)

# Stream batches from the per-class sub-directories. Images are resized to
# the network input size and get binary labels. Only the training stream
# passes shuffle=True explicitly.
train_data_gen = train_image_generator.flow_from_directory(
    directory=train_dir,
    batch_size=batch_size,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='binary',
    shuffle=True,
)
val_data_gen = val_image_generator.flow_from_directory(
    directory=val_dir,
    batch_size=batch_size,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='binary',
)

# Draw one batch of images (labels discarded) from each stream. These are
# kept for the optional, currently commented-out, plotImages() preview.
sample_train_images, _ = next(train_data_gen)
sample_val_images, _ = next(val_data_gen)
def plotImages(images_arr):
    """Show images side by side in a single row of three panels.

    Args:
        images_arr: sequence of images accepted by ``Axes.imshow``. Only as
            many images as there are panels (three) are drawn.
    """
    _fig, panels = plt.subplots(1, 3, figsize=(20, 20))
    for picture, panel in zip(images_arr, panels.flatten()):
        panel.imshow(picture)
        panel.axis('off')
    plt.tight_layout()
    plt.show()
#plotImages(sample_train_images[0:3])
#plotImages(sample_val_images[0:3])
#get_ipython().run_line_magic('load_ext', 'tensorboard')
#get_ipython().system('rm -rf ./clusterlogs/')
#log_dir="clusterlogs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
#log_dir="clusterlogs"
#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir)
#%tensorboard --logdir clusterlogs --port 6006
# Batch-size bookkeeping for distributed training. The global batch is the
# per-replica batch times the number of replicas kept in sync by `strategy`.
# (The bare `strategy.num_replicas_in_sync` expression left over from the
# notebook did nothing in a script and has been removed.)
# NOTE(review): the image generators are built with `batch_size`, not
# `global_batch_size` -- confirm which one is intended.
BATCH_SIZE_PER_REPLICA = 8
global_batch_size = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

print("BATCH_SIZE_PER_REPLICA", BATCH_SIZE_PER_REPLICA)
print("strategy.num_replicas_in_sync", strategy.num_replicas_in_sync)
print("global_batch_size", global_batch_size)
print("total_train", total_train)
print("total_val ", total_val)
print("batch_size", batch_size)
print("total_train // batch_size", total_train // batch_size)
print("total_val // batch_size", total_val // batch_size)
def get_uncompiled_model():
    """Build the (uncompiled) binary image classifier.

    Architecture:
    - Three stages of Conv2D (3x3 kernel, 'same' padding, ReLU) followed by
      MaxPooling2D, with 16, 32 and 64 filters.
    - Flatten -> Dense(512, ReLU) -> Dense(1, sigmoid).

    The input is an (IMG_HEIGHT, IMG_WIDTH, 3) image.

    Returns:
        A Sequential model that has not been compiled yet.
    """
    stack = [
        Conv2D(16, 3, padding='same', activation='relu',
               input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
        MaxPooling2D(),
    ]
    for filters in (32, 64):
        stack += [Conv2D(filters, 3, padding='same', activation='relu'),
                  MaxPooling2D()]
    stack += [
        Flatten(),
        Dense(512, activation='relu'),
        Dense(1, activation='sigmoid'),
    ]
    return Sequential(stack)
#get_uncompiled_model()
def get_compiled_model():
    """Return a fresh classifier from get_uncompiled_model(), ready to train.

    Compiled with the Adam optimizer and binary cross-entropy loss, and
    tracks accuracy as a metric.
    """
    net = get_uncompiled_model()
    net.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return net
# Create a checkpoint directory to store the checkpoints.
#checkpoint_dir = './training_checkpoints'
#checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
#callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='tmp/keras-ckpt')]
#callbacks=[tensorboard_callback,callbacks]
#def get_fit_model():
# model = get_compiled_model()
# model.fit(
# train_data_gen,
# steps_per_epoch=total_train // batch_size,
# epochs=epochs,
# validation_data=val_data_gen,
# validation_steps=total_val // batch_size,
# verbose=2
# )
#return model
# Create and compile the model inside the strategy scope so its variables are
# created under the distribution strategy, then train it.
with strategy.scope():
    model = get_compiled_model()
    # fit() returns a History object; its per-epoch metrics are read later
    # from history.history. History has no .batch() method, so the former
    # trailing `.batch(global_batch_size)` raised AttributeError. Batching
    # already comes from the image generators.
    history = model.fit(
        train_data_gen,
        steps_per_epoch=total_train // batch_size,
        epochs=epochs,
        validation_data=val_data_gen,
        validation_steps=total_val // batch_size,
        verbose=2,
    )
# Echo the training data locations and the generator objects that were used,
# one per line.
print("TRAINING info")
print(train_dir, train_good_dir, train_bad_dir,
      train_image_generator, train_data_gen, sep='\n')
# Persist the trained network twice: the whole model via model.save(), and
# the weights only via model.save_weights(). Each is written in both the .tf
# and .h5 formats.
# NOTE(review): presumably every worker in the cluster runs this with the same
# relative paths. Confirm they do not resolve to shared storage, or the
# concurrent writes may collide.
for _save, _path in (
    (model.save, 'data/models/FOO/wut-train-cluster2.tf'),
    (model.save, 'data/models/FOO/wut-train-cluster2.h5'),
    (model.save_weights, 'data/models/FOO/wut-weights-train-cluster2.tf'),
    (model.save_weights, 'data/models/FOO/wut-weights-train-cluster2.h5'),
):
    _save(_path)
# Per-epoch metrics recorded by model.fit().
acc, val_acc = history.history['accuracy'], history.history['val_accuracy']
loss, val_loss = history.history['loss'], history.history['val_loss']
epochs_range = range(epochs)

plt.figure(figsize=(8, 8))

# Left panel: training vs. validation accuracy.
plt.subplot(121)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# Right panel: training vs. validation loss.
plt.subplot(122)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()