#!/usr/bin/env python3
#
# wut-worker.py
#
# wut --- What U Think? SatNOGS Observation AI, training application cluster edition.
#
# https://spacecruft.org/spacecruft/satnogs-wut
#
# Builds a wut.tf model file from the data/train and data/val directories.
# GPLv3+
# Built using Jupyter, Tensorflow, Keras
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import datetime
import numpy as np
import simplejson as json
import tensorflow as tf
import tensorflow.python.keras
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.python.keras import optimizers
from tensorflow.python.keras import Sequential
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import Convolution2D, ZeroPadding2D
from tensorflow.python.keras.layers import Input, concatenate
from tensorflow.python.keras.models import load_model
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.preprocessing import image
from tensorflow.python.keras.preprocessing.image import img_to_array
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.preprocessing.image import load_img
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
# Jupyter leftover; guarded so the script also runs outside IPython.
try:
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    pass
import matplotlib.pyplot as plt
import seaborn as sns
#from sklearn.decomposition import PCA
#from ipywidgets import interact, interactive, fixed, interact_manual
#import ipywidgets as widgets
#from IPython.display import display, Image
print('tf {}'.format(tf.__version__))
os.environ["TF_CONFIG"] = json.dumps({
"cluster": {
"worker": ["ml1-int:2222", "ml2-int:2222", "ml3-int:2222", "ml4-int:2222", "ml5-int:2222" ]
},
"task": {"type": "worker", "index": 0 },
"num_workers": 5
})
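# Every worker runs this same script, so the hard-coded task index above only
# fits one host. A minimal override sketch, assuming each machine exports a
# WUT_WORKER_INDEX environment variable (hypothetical name); it defaults to 0,
# leaving behavior unchanged when the variable is unset.
tf_config = json.loads(os.environ["TF_CONFIG"])
tf_config["task"]["index"] = int(os.environ.get("WUT_WORKER_INDEX", "0"))
os.environ["TF_CONFIG"] = json.dumps(tf_config)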
IMG_HEIGHT = 416
IMG_WIDTH = 804
batch_size = 32
epochs = 1
# Full size; the machine runs out of memory, probably needs more RAM.
#IMG_HEIGHT = 832
#IMG_WIDTH = 1606
# Good results
#batch_size = 128
#epochs = 6
tf.keras.backend.clear_session()
options = tf.data.Options()
#options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.DATA
# XXX: options is built above but never applied; flow_from_directory returns
# a Keras generator rather than a tf.data.Dataset, so there is no dataset to
# call with_options() on yet.
#dataset = dataset.with_options(options)
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    tf.distribute.experimental.CollectiveCommunication.RING)
#mirrored_strategy = tf.distribute.MirroredStrategy(
# cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
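# Note: MultiWorkerMirroredStrategy reads TF_CONFIG when it is constructed,
# so TF_CONFIG must already be set (as above) and the strategy should be
# created before any Keras model or variables are built.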
root_data_dir = '/srv/satnogs'
train_dir = os.path.join(root_data_dir, 'data', 'train')
val_dir = os.path.join(root_data_dir, 'data', 'val')
train_good_dir = os.path.join(train_dir, 'good')
train_bad_dir = os.path.join(train_dir, 'bad')
val_good_dir = os.path.join(val_dir, 'good')
val_bad_dir = os.path.join(val_dir, 'bad')
num_train_good = len(os.listdir(train_good_dir))
num_train_bad = len(os.listdir(train_bad_dir))
num_val_good = len(os.listdir(val_good_dir))
num_val_bad = len(os.listdir(val_bad_dir))
total_train = num_train_good + num_train_bad
total_val = num_val_good + num_val_bad
print('total training good images:', num_train_good)
print('total training bad images:', num_train_bad)
print("--")
print("Total training images:", total_train)
print('total validation good images:', num_val_good)
print('total validation bad images:', num_val_bad)
print("--")
print("Total validation images:", total_val)
print("--")
print("Reduce training and validation set when testing")
total_train = 100
total_val = 100
print("Reduced training images:", total_train)
print("Reduced validation images:", total_val)
train_image_generator = ImageDataGenerator(
    rescale=1./255
)
val_image_generator = ImageDataGenerator(
    rescale=1./255
)
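# If augmentation is wanted later, ImageDataGenerator supports it directly;
# a sketch (not enabled here, parameter choices are illustrative only):
#
#   train_image_generator = ImageDataGenerator(
#       rescale=1./255,
#       width_shift_range=0.1,
#       zoom_range=0.1)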
#train_data_gen = train_image_generator.flow_from_directory(batch_size=GLOBAL_BATCH_SIZE,
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                           class_mode='binary')
#val_data_gen = val_image_generator.flow_from_directory(batch_size=GLOBAL_BATCH_SIZE,
val_data_gen = val_image_generator.flow_from_directory(batch_size=batch_size,
                                                       directory=val_dir,
                                                       target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                       class_mode='binary')
#train_dist_dataset = strategy.experimental_distribute_dataset()
#val_dist_dataset = strategy.experimental_distribute_dataset()
sample_train_images, _ = next(train_data_gen)
sample_val_images, _ = next(val_data_gen)
# Plot images in a grid of 1 row and 3 columns, one image per column.
def plotImages(images_arr):
    fig, axes = plt.subplots(1, 3, figsize=(20, 20))
    axes = axes.flatten()
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
#plotImages(sample_train_images[0:3])
#plotImages(sample_val_images[0:3])
#get_ipython().run_line_magic('load_ext', 'tensorboard')
#get_ipython().system('rm -rf ./clusterlogs/')
#log_dir="clusterlogs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
#log_dir="clusterlogs"
#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
#tensorboard_callback = tensorflow.keras.callbacks.TensorBoard(log_dir=log_dir)
#%tensorboard --logdir clusterlogs --port 6006
## Compute global batch size using number of replicas.
#GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS
BATCH_SIZE_PER_REPLICA = 8
print("BATCH_SIZE_PER_REPLICA", BATCH_SIZE_PER_REPLICA)
print("strategy.num_replicas_in_sync", strategy.num_replicas_in_sync)
global_batch_size = (BATCH_SIZE_PER_REPLICA *
                     strategy.num_replicas_in_sync)
print("global_batch_size", global_batch_size)
print("total_train", total_train)
print("total_val ", total_val)
print("batch_size", batch_size)
print("total_train // batch_size", total_train // batch_size)
print("total_val // batch_size", total_val // batch_size)
#.batch(global_batch_size)
#dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100)
#dataset = dataset.batch(global_batch_size)
#LEARNING_RATES_BY_BATCH_SIZE = {5: 0.1, 10: 0.15}
#learning_rate = LEARNING_RATES_BY_BATCH_SIZE[global_batch_size]
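# Assuming all five workers join, num_replicas_in_sync is 5, giving a global
# batch of 8 * 5 = 40: under tf.distribute a tf.data pipeline is batched by
# the global size and each replica receives an equal share per step. The
# Keras generators above use their own batch_size instead, so these values
# are informational here.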
def get_uncompiled_model():
    model = Sequential([
        Conv2D(16, 3, padding='same', activation='relu',
               input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
        MaxPooling2D(),
        Conv2D(32, 3, padding='same', activation='relu'),
        MaxPooling2D(),
        Conv2D(64, 3, padding='same', activation='relu'),
        MaxPooling2D(),
        Flatten(),
        Dense(512, activation='relu'),
        Dense(1, activation='sigmoid')
    ])
    return model
#get_uncompiled_model()
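# Shape walkthrough for the default 416x804 input: three 'same' convolutions,
# each followed by 2x2 max-pooling, give 208x402, 104x201, then 52x100 maps
# with 64 channels, so Flatten() emits 52*100*64 = 332,800 features and the
# Dense(512) layer alone holds roughly 170M weights; likely why the
# full-size 832x1606 input above runs out of RAM.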
def get_compiled_model():
    model = get_uncompiled_model()
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
# Create a checkpoint directory to store the checkpoints.
#checkpoint_dir = './training_checkpoints'
#checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
#callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='tmp/keras-ckpt')]
#callbacks=[tensorboard_callback,callbacks]
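# A minimal checkpointing sketch using tf.keras.callbacks.ModelCheckpoint
# (not enabled here; the ./training_checkpoints path is from the comment
# above), so an interrupted worker could be restarted from saved weights:
#
#   checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
#       filepath=os.path.join('./training_checkpoints', 'ckpt'),
#       save_weights_only=True)
#   history = model.fit(..., callbacks=[checkpoint_cb])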
#def get_fit_model():
#    model = get_compiled_model()
#    model.fit(
#        train_data_gen,
#        steps_per_epoch=total_train // batch_size,
#        epochs=epochs,
#        validation_data=val_data_gen,
#        validation_steps=total_val // batch_size,
#        verbose=2
#    )
#    return model
# Build and compile the model under the distribution strategy scope.
with strategy.scope():
    model = get_compiled_model()

history = model.fit(
    train_data_gen,
    steps_per_epoch=total_train // batch_size,
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=total_val // batch_size,
    verbose=2
)
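# With the reduced sets, total_train // batch_size = 100 // 32 = 3 steps per
# epoch, and likewise 3 validation steps. model.fit() returns a History
# object; its .history dict feeds the accuracy/loss plots below.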
#model.summary()
print("TRAINING info")
print(train_dir)
print(train_good_dir)
print(train_bad_dir)
print(train_image_generator)
print(train_data_gen)
#print(sample_train_images)
#print(history)
#model.to_json()
#history = model.fit(X, y, batch_size=32, epochs=40, validation_split=0.1)
model.save('data/models/FOO/wut-train-cluster2.tf')
model.save('data/models/FOO/wut-train-cluster2.h5')
model.save_weights('data/models/FOO/wut-weights-train-cluster2.tf')
model.save_weights('data/models/FOO/wut-weights-train-cluster2.h5')
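# In a real multi-worker run, typically only the chief (worker index 0)
# should write to a shared model directory. A minimal guard sketch, assuming
# the TF_CONFIG set at the top of this script:
#
#   task = json.loads(os.environ["TF_CONFIG"])["task"]
#   if task["type"] == "worker" and task["index"] == 0:
#       model.save('data/models/FOO/wut-train-cluster2.h5')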
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
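# On a headless worker plt.show() displays nothing; plt.savefig() with a
# chosen filename (e.g. 'wut-train.png', hypothetical) can persist the curves.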