satnogs-wut/wut-worker.py

80 lines
2.3 KiB
Python
Raw Normal View History

#!/usr/bin/env python
# coding: utf-8
#
2020-01-17 18:46:36 -07:00
# wut-worker.py --- Runs on worker nodes.
#
# Start it via the wut-worker shell script, which sets the required
# environment variables.
from __future__ import absolute_import, division, print_function, unicode_literals
import simplejson as json
2020-01-17 18:21:46 -07:00
import os
2020-01-17 18:46:36 -07:00
import numpy as np
import tensorflow as tf
import tensorflow.python.keras
from tensorflow.python.keras import Sequential
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.preprocessing import image
from tensorflow.python.keras.models import load_model
from tensorflow.python.keras.preprocessing.image import load_img
from tensorflow.python.keras.preprocessing.image import img_to_array
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.layers import Input, concatenate
2020-01-18 12:58:55 -07:00
# Training hyper-parameters, currently unused (they belong to the
# commented-out get_fit_model variant further down):
# batch_size = 32
# epochs = 4

# Input image size in pixels; get_uncompiled_model() declares its input
# as (IMG_HEIGHT, IMG_WIDTH, 3).
IMG_HEIGHT = 416
IMG_WIDTH = 804

# Multi-worker distribution strategy, created once at import time.
# NOTE(review): the cluster layout presumably comes from environment
# variables (e.g. TF_CONFIG) set by the wut-worker shell script — confirm.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

def get_uncompiled_model():
    """Build the network architecture without compiling it.

    Three Conv2D + MaxPooling2D stages (16, 32 and 64 filters, 3x3 kernels,
    'same' padding, ReLU) are followed by Flatten, a 512-unit ReLU Dense
    layer and a single sigmoid output unit. The input shape is
    (IMG_HEIGHT, IMG_WIDTH, 3).

    Returns:
        An uncompiled Keras ``Sequential`` model.
    """
    net = Sequential()
    # Only the first layer declares the input shape; Keras infers the
    # shapes of every later layer from it.
    net.add(Conv2D(16, 3, padding='same', activation='relu',
                   input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)))
    net.add(MaxPooling2D())
    for n_filters in (32, 64):
        net.add(Conv2D(n_filters, 3, padding='same', activation='relu'))
        net.add(MaxPooling2D())
    net.add(Flatten())
    net.add(Dense(512, activation='relu'))
    net.add(Dense(1, activation='sigmoid'))
    return net


def get_compiled_model():
    """Return a fresh model from get_uncompiled_model(), compiled for training.

    The model is compiled with the Adam optimizer, binary cross-entropy loss
    and an accuracy metric.
    """
    compiled = get_uncompiled_model()
    compiled.compile(
        loss='binary_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    return compiled


def get_fit_model(train_data=None, epochs=1, steps_per_epoch=None,
                  validation_data=None, validation_steps=None, verbose=2):
    """Build and compile a model, train it on ``train_data``, and return it.

    All keyword arguments are forwarded unchanged to Keras ``Model.fit``.
    NOTE(review): model creation presumably has to happen inside
    ``strategy.scope()``, as the module-level driver does — confirm.

    Args:
        train_data: Training input accepted by ``Model.fit`` (for example the
            ``train_data_gen`` generator used by the commented-out variant).
        epochs: Number of training epochs.
        steps_per_epoch: Batches per epoch, or None to let Keras decide.
        validation_data: Optional validation input for ``Model.fit``.
        validation_steps: Validation batches per epoch, or None.
        verbose: Keras verbosity level passed to ``Model.fit``.

    Returns:
        The compiled model after training.

    Raises:
        ValueError: if no training data is supplied.
    """
    if train_data is None:
        # Fail early with a clear message instead of letting Keras reject
        # a bogus input deep inside fit().
        raise ValueError("get_fit_model() requires train_data; none was supplied")
    model = get_compiled_model()
    model.fit(
        train_data,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        validation_data=validation_data,
        validation_steps=validation_steps,
        verbose=verbose,
    )
    return model

#def get_fit_model():
# model = get_compiled_model()
# model.fit(
# train_data_gen,
# steps_per_epoch=total_train // batch_size,
# epochs=epochs,
# validation_data=val_data_gen,
# validation_steps=total_val // batch_size,
# verbose=2
# )
# return model
2020-01-18 12:58:55 -07:00
# The file name contains a hyphen, so this script is only ever run directly;
# the guard keeps an accidental import (e.g. via importlib) from training.
if __name__ == "__main__":
    # Keras models must be created inside the strategy scope so their
    # variables are mirrored across workers. get_fit_model() builds and
    # compiles its own model, so calling the uncompiled/compiled builders
    # first only created throwaway models.
    with strategy.scope():
        # TODO(review): no training data is wired up here yet, so this fit
        # cannot succeed until a dataset is supplied (cf. the commented-out
        # get_fit_model variant).
        get_fit_model()