2020-01-18 15:43:33 -07:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# coding: utf-8
|
2020-01-18 16:06:06 -07:00
|
|
|
# wut-train-cluster-fn.py --- What U Think? SatNOGS Observation AI, training application cluster edition.
|
2020-01-18 15:43:33 -07:00
|
|
|
#
|
|
|
|
# https://spacecruft.org/spacecruft/satnogs-wut
|
|
|
|
#
|
|
|
|
# Trains a model on the images in the data/train and data/val directories to build a wut.tf file.
|
|
|
|
# GPLv3+
|
|
|
|
# Built using Jupyter, Tensorflow, Keras
|
|
|
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
|
|
import numpy as np
|
|
|
|
import simplejson as json
|
|
|
|
import datetime
|
|
|
|
import tensorflow as tf
|
|
|
|
import tensorflow.python.keras
|
|
|
|
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
|
|
|
|
from tensorflow.python.keras import optimizers
|
|
|
|
from tensorflow.python.keras import Sequential
|
|
|
|
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
|
|
|
|
from tensorflow.python.keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D
|
|
|
|
from tensorflow.python.keras.layers import Input, concatenate
|
|
|
|
from tensorflow.python.keras.models import load_model
|
|
|
|
from tensorflow.python.keras.models import Model
|
|
|
|
from tensorflow.python.keras.preprocessing import image
|
|
|
|
from tensorflow.python.keras.preprocessing.image import img_to_array
|
|
|
|
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
|
|
|
|
from tensorflow.python.keras.preprocessing.image import load_img
|
2020-01-18 17:05:53 -07:00
|
|
|
|
|
|
|
# Environment variables (e.g. TF_CONFIG, see the example below) are set by the shell script that launches this Python script.
|
2020-01-18 16:06:06 -07:00
|
|
|
#os.environ["TF_CONFIG"] = json.dumps({
|
|
|
|
# "cluster": {
|
|
|
|
# "worker": [ "10.100.100.130:2222", "ml1:2222", "ml2:2222", "ml3:2222", "ml4:2222", "ml5:2222" ]
|
|
|
|
# },
|
|
|
|
# "task": {"type": "worker", "index": 0 },
|
|
|
|
# "num_workers": 5
|
|
|
|
#})
|
2020-01-18 15:43:33 -07:00
|
|
|
# Size (pixels) every waterfall image is resized to before entering the network.
IMG_HEIGHT, IMG_WIDTH = 416, 804

# Training hyper-parameters.
batch_size = 32  # images per batch
epochs = 4       # full passes over the training set
|
|
|
|
|
2020-01-18 16:06:06 -07:00
|
|
|
# XXX
|
2020-01-18 16:14:14 -07:00
|
|
|
#tf.keras.backend.clear_session()
|
2020-01-18 15:43:33 -07:00
|
|
|
|
|
|
|
# Synchronous data-parallel training across all workers, using ring-based
# collective communication. The cluster layout is read from the environment
# (TF_CONFIG) prepared by the launching shell script.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=tf.distribute.experimental.CollectiveCommunication.RING,
)
|
|
|
|
|
|
|
|
# Dataset layout: data/{train,val}/{good,bad}/<image files>.
train_dir = os.path.join('data/', 'train')
val_dir = os.path.join('data/', 'val')

# One sub-directory per class label.
train_good_dir, train_bad_dir = (os.path.join(train_dir, label) for label in ('good', 'bad'))
val_good_dir, val_bad_dir = (os.path.join(val_dir, label) for label in ('good', 'bad'))

# Number of directory entries in each class directory (counted in this order).
num_train_good, num_train_bad, num_val_good, num_val_bad = (
    len(os.listdir(class_dir))
    for class_dir in (train_good_dir, train_bad_dir, val_good_dir, val_bad_dir)
)

total_train = num_train_good + num_train_bad
total_val = num_val_good + num_val_bad
|
|
|
|
|
|
|
|
# Report how many images were found per split and class.
for _row in (
    ('total training good images:', num_train_good),
    ('total training bad images:', num_train_bad),
    ("--",),
    ("Total training images:", total_train),
    ('total validation good images:', num_val_good),
    ('total validation bad images:', num_val_bad),
    ("--",),
    ("Total validation images:", total_val),
    ("--",),
    ("Reduce training and validation set when testing",),
):
    print(*_row)
del _row  # keep the loop variable out of the module namespace

# Uncomment to train on a small subset while testing the cluster setup.
#total_train = 16
#total_val = 16

print("Reduced training images:", total_train)
print("Reduced validation images:", total_val)
|
|
|
|
|
|
|
|
# Pixel values are only rescaled from [0, 255] to [0, 1]; no augmentation.
train_image_generator = ImageDataGenerator(rescale=1./255)
val_image_generator = ImageDataGenerator(rescale=1./255)

# Iterators over the class sub-directories (good/bad), resized to the network
# input size and labelled for binary classification.
train_data_gen = train_image_generator.flow_from_directory(
    directory=train_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=batch_size,
    shuffle=True,
    class_mode='binary',
)
val_data_gen = val_image_generator.flow_from_directory(
    directory=val_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=batch_size,
    class_mode='binary',
)
|
|
|
|
|
|
|
|
def get_uncompiled_model():
    """Return the binary-classification CNN, not yet compiled.

    Architecture: three Conv2D(3x3, 'same' padding, ReLU) + MaxPooling2D
    stages with 16, 32 and 64 filters, then Flatten, a 512-unit ReLU dense
    layer and a single sigmoid output unit. Input is a 3-channel image of
    IMG_HEIGHT x IMG_WIDTH pixels.

    Every layer and the Sequential container come from the public
    ``tf.keras`` namespace. The module-level imports mix ``tensorflow.keras``
    (Conv2D) with the private ``tensorflow.python.keras`` (Sequential,
    MaxPooling2D, Flatten, Dense). That combination only works while both
    names resolve to the same Keras implementation.
    """
    layers = tf.keras.layers
    model = tf.keras.Sequential([
        layers.Conv2D(16, 3, padding='same', activation='relu',
                      input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
        layers.MaxPooling2D(),
        layers.Conv2D(32, 3, padding='same', activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, padding='same', activation='relu'),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        # One sigmoid unit: probability of the positive class.
        layers.Dense(1, activation='sigmoid'),
    ])
    return model
|
|
|
|
|
|
|
|
def get_compiled_model():
    """Build the CNN and compile it for binary classification.

    Uses the Adam optimizer, binary cross-entropy loss and reports accuracy.
    Intended to be called inside ``strategy.scope()`` so the model's
    variables are created under the distribution strategy.
    """
    net = get_uncompiled_model()
    net.compile(
        loss='binary_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    return net
|
|
|
|
|
2020-01-18 17:05:53 -07:00
|
|
|
#def get_fit_model():
|
|
|
|
# model = get_compiled_model()
|
|
|
|
# model.fit(
|
|
|
|
# train_data_gen,
|
|
|
|
# steps_per_epoch=total_train // batch_size,
|
|
|
|
# epochs=epochs,
|
|
|
|
# validation_data=val_data_gen,
|
|
|
|
# validation_steps=total_val // batch_size,
|
|
|
|
# verbose=2
|
|
|
|
# )
|
|
|
|
# return model
|
2020-01-18 15:43:33 -07:00
|
|
|
|
2020-01-18 16:08:21 -07:00
|
|
|
#with strategy.scope():
|
|
|
|
# get_uncompiled_model()
|
|
|
|
#with strategy.scope():
|
|
|
|
# get_compiled_model()
|
2020-01-18 15:43:33 -07:00
|
|
|
#with strategy.scope():
|
|
|
|
# get_fit_model()
|
2020-01-18 16:08:21 -07:00
|
|
|
|
2020-01-18 15:43:33 -07:00
|
|
|
#multi_worker_model = get_compiled_model()
|
|
|
|
#multi_worker_model.fit(
|
|
|
|
# x=train_data_gen,
|
|
|
|
# epochs=epochs,
|
|
|
|
# steps_per_epoch=total_train // batch_size
|
|
|
|
# )
|
|
|
|
|
|
|
|
with strategy.scope():
    # Build and compile inside the strategy scope so the model's variables
    # are created as distributed variables replicated across the workers.
    model = get_compiled_model()
    # Floor division drops the final partial batch. max(1, ...) guarantees at
    # least one step when a split has fewer images than batch_size; without
    # it steps_per_epoch / validation_steps would be 0 (e.g. when the
    # "reduced" 16-image test sets are used with batch_size 32).
    model.fit(
        train_data_gen,
        steps_per_epoch=max(1, total_train // batch_size),
        epochs=epochs,
        validation_data=val_data_gen,
        validation_steps=max(1, total_val // batch_size),
        verbose=2,
    )
    # NOTE(review): the file header says this script builds a wut.tf file,
    # but the trained model is never saved here -- confirm whether a
    # model.save(...) step is missing.
|
|
|
|
|
|
|
|
# Echo the training inputs used for this run.
print("TRAINING info")
for _item in (train_dir, train_good_dir, train_bad_dir,
              train_image_generator, train_data_gen):
    print(_item)
del _item  # keep the loop variable out of the module namespace
|
|
|
|
|
|
|
|
# The End
|
|
|
|
|