diff --git a/wut-ml b/wut-ml index c8494b6..6d0ae4b 100755 --- a/wut-ml +++ b/wut-ml @@ -49,12 +49,13 @@ model.add(Dense(64)) model.add(Activation('relu')) model.add(Dropout(0.5)) model.add(Dense(1)) -model.add(Activation('sigmoid')) +model.add(Activation('softmax')) +#model.compile(loss='categorical_crossentropy', model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy']) -model.fit(x=train_it, validation_data=val_it, epochs=16, verbose=2, workers=16, use_multiprocessing=True) +model.fit(x=train_it, validation_data=val_it, epochs=2, verbose=2, workers=16, use_multiprocessing=True) prediction = model.predict(x=test_it, batch_size=None, verbose=0, steps=None, use_multiprocessing=True) print(prediction)