Change trace with NNUE eval support
This patch adds some more output to the `eval` command. It adds a board display with estimated piece values (method is remove-piece, evaluate, put-piece), and splits the NNUE evaluation with (psqt,layers) for each bucket for the NNUE net. Example: ``` ./stockfish position fen 3Qb1k1/1r2ppb1/pN1n2q1/Pp1Pp1Pr/4P2p/4BP2/4B1R1/1R5K b - - 11 40 eval Contributing terms for the classical eval: +------------+-------------+-------------+-------------+ | Term | White | Black | Total | | | MG EG | MG EG | MG EG | +------------+-------------+-------------+-------------+ | Material | ---- ---- | ---- ---- | -0.73 -1.55 | | Imbalance | ---- ---- | ---- ---- | -0.21 -0.17 | | Pawns | 0.35 -0.00 | 0.19 -0.26 | 0.16 0.25 | | Knights | 0.04 -0.08 | 0.12 -0.01 | -0.08 -0.07 | | Bishops | -0.34 -0.87 | -0.17 -0.61 | -0.17 -0.26 | | Rooks | 0.12 0.00 | 0.08 0.00 | 0.04 0.00 | | Queens | 0.00 0.00 | -0.27 -0.07 | 0.27 0.07 | | Mobility | 0.84 1.76 | 0.01 0.66 | 0.83 1.10 | |King safety | -0.99 -0.17 | -0.72 -0.10 | -0.27 -0.07 | | Threats | 0.27 0.27 | 0.73 0.86 | -0.46 -0.59 | | Passed | 0.00 0.00 | 0.79 0.82 | -0.79 -0.82 | | Space | 0.61 0.00 | 0.24 0.00 | 0.37 0.00 | | Winnable | ---- ---- | ---- ---- | 0.00 -0.03 | +------------+-------------+-------------+-------------+ | Total | ---- ---- | ---- ---- | -1.03 -2.14 | +------------+-------------+-------------+-------------+ NNUE derived piece values: +-------+-------+-------+-------+-------+-------+-------+-------+ | | | | Q | b | | k | | | | | | +12.4 | -1.62 | | | | +-------+-------+-------+-------+-------+-------+-------+-------+ | | r | | | p | p | b | | | | -3.89 | | | -0.84 | -1.19 | -3.32 | | +-------+-------+-------+-------+-------+-------+-------+-------+ | p | N | | n | | | q | | | -1.81 | +3.71 | | -4.82 | | | -5.04 | | +-------+-------+-------+-------+-------+-------+-------+-------+ | P | p | | P | p | | P | r | | +1.16 | -0.91 | | +0.55 | +0.12 | | +0.50 | -4.02 | +-------+-------+-------+-------+-------+-------+-------+-------+ | | | | | P | | | p | | | | | | +2.33 | | | +1.17 | +-------+-------+-------+-------+-------+-------+-------+-------+ | | | | | B | P | | | | | | | | +4.79 | +1.54 | | | +-------+-------+-------+-------+-------+-------+-------+-------+ | | | | | B | | R | | | | | | | +4.54 | | +6.03 | | +-------+-------+-------+-------+-------+-------+-------+-------+ | | R | | | | | | K | | | +4.81 | | | | | | | +-------+-------+-------+-------+-------+-------+-------+-------+ NNUE network contributions (Black to move) +------------+------------+------------+------------+ | Bucket | Material | Positional | Total | | | (PSQT) | (Layers) | | +------------+------------+------------+------------+ | 0 | + 0.32 | - 1.46 | - 1.13 | | 1 | + 0.25 | - 0.68 | - 0.43 | | 2 | + 0.46 | - 1.72 | - 1.25 | | 3 | + 0.55 | - 1.80 | - 1.25 | | 4 | + 0.48 | - 1.77 | - 1.29 | | 5 | + 0.40 | - 2.00 | - 1.60 | | 6 | + 0.57 | - 2.12 | - 1.54 | <-- this bucket is used | 7 | + 3.38 | - 2.00 | + 1.37 | +------------+------------+------------+------------+ Classical evaluation -1.00 (white side) NNUE evaluation +1.54 (white side) Final evaluation +2.38 (white side) [with scaled NNUE, hybrid, ...] ``` Also renames the export_net() function to save_eval() while there. closes https://github.com/official-stockfish/Stockfish/pull/3562 No functional changepull/3570/head
parent
0171b506ec
commit
2e745956c0
|
@ -114,30 +114,6 @@ namespace Eval {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// NNUE::export_net() exports the currently loaded network to a file
|
|
||||||
void NNUE::export_net(const std::optional<std::string>& filename) {
|
|
||||||
std::string actualFilename;
|
|
||||||
|
|
||||||
if (filename.has_value())
|
|
||||||
actualFilename = filename.value();
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if (eval_file_loaded != EvalFileDefaultName)
|
|
||||||
{
|
|
||||||
sync_cout << "Failed to export a net. A non-embedded net can only be saved if the filename is specified." << sync_endl;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
actualFilename = EvalFileDefaultName;
|
|
||||||
}
|
|
||||||
|
|
||||||
ofstream stream(actualFilename, std::ios_base::binary);
|
|
||||||
|
|
||||||
if (save_eval(stream))
|
|
||||||
sync_cout << "Network saved successfully to " << actualFilename << "." << sync_endl;
|
|
||||||
else
|
|
||||||
sync_cout << "Failed to export a net." << sync_endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// NNUE::verify() verifies that the last net used was loaded successfully
|
/// NNUE::verify() verifies that the last net used was loaded successfully
|
||||||
void NNUE::verify() {
|
void NNUE::verify() {
|
||||||
|
|
||||||
|
@ -204,7 +180,7 @@ namespace Trace {
|
||||||
else
|
else
|
||||||
os << scores[t][WHITE] << " | " << scores[t][BLACK];
|
os << scores[t][WHITE] << " | " << scores[t][BLACK];
|
||||||
|
|
||||||
os << " | " << scores[t][WHITE] - scores[t][BLACK] << "\n";
|
os << " | " << scores[t][WHITE] - scores[t][BLACK] << " |\n";
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1150,7 +1126,7 @@ Value Eval::evaluate(const Position& pos) {
|
||||||
/// descriptions and values of each evaluation term. Useful for debugging.
|
/// descriptions and values of each evaluation term. Useful for debugging.
|
||||||
/// Trace scores are from white's point of view
|
/// Trace scores are from white's point of view
|
||||||
|
|
||||||
std::string Eval::trace(const Position& pos) {
|
std::string Eval::trace(Position& pos) {
|
||||||
|
|
||||||
if (pos.checkers())
|
if (pos.checkers())
|
||||||
return "Final evaluation: none (in check)";
|
return "Final evaluation: none (in check)";
|
||||||
|
@ -1167,39 +1143,48 @@ std::string Eval::trace(const Position& pos) {
|
||||||
v = Evaluation<TRACE>(pos).value();
|
v = Evaluation<TRACE>(pos).value();
|
||||||
|
|
||||||
ss << std::showpoint << std::noshowpos << std::fixed << std::setprecision(2)
|
ss << std::showpoint << std::noshowpos << std::fixed << std::setprecision(2)
|
||||||
<< " Term | White | Black | Total \n"
|
<< " Contributing terms for the classical eval:\n"
|
||||||
<< " | MG EG | MG EG | MG EG \n"
|
<< "+------------+-------------+-------------+-------------+\n"
|
||||||
<< " ------------+-------------+-------------+------------\n"
|
<< "| Term | White | Black | Total |\n"
|
||||||
<< " Material | " << Term(MATERIAL)
|
<< "| | MG EG | MG EG | MG EG |\n"
|
||||||
<< " Imbalance | " << Term(IMBALANCE)
|
<< "+------------+-------------+-------------+-------------+\n"
|
||||||
<< " Pawns | " << Term(PAWN)
|
<< "| Material | " << Term(MATERIAL)
|
||||||
<< " Knights | " << Term(KNIGHT)
|
<< "| Imbalance | " << Term(IMBALANCE)
|
||||||
<< " Bishops | " << Term(BISHOP)
|
<< "| Pawns | " << Term(PAWN)
|
||||||
<< " Rooks | " << Term(ROOK)
|
<< "| Knights | " << Term(KNIGHT)
|
||||||
<< " Queens | " << Term(QUEEN)
|
<< "| Bishops | " << Term(BISHOP)
|
||||||
<< " Mobility | " << Term(MOBILITY)
|
<< "| Rooks | " << Term(ROOK)
|
||||||
<< " King safety | " << Term(KING)
|
<< "| Queens | " << Term(QUEEN)
|
||||||
<< " Threats | " << Term(THREAT)
|
<< "| Mobility | " << Term(MOBILITY)
|
||||||
<< " Passed | " << Term(PASSED)
|
<< "|King safety | " << Term(KING)
|
||||||
<< " Space | " << Term(SPACE)
|
<< "| Threats | " << Term(THREAT)
|
||||||
<< " Winnable | " << Term(WINNABLE)
|
<< "| Passed | " << Term(PASSED)
|
||||||
<< " ------------+-------------+-------------+------------\n"
|
<< "| Space | " << Term(SPACE)
|
||||||
<< " Total | " << Term(TOTAL);
|
<< "| Winnable | " << Term(WINNABLE)
|
||||||
|
<< "+------------+-------------+-------------+-------------+\n"
|
||||||
v = pos.side_to_move() == WHITE ? v : -v;
|
<< "| Total | " << Term(TOTAL)
|
||||||
|
<< "+------------+-------------+-------------+-------------+\n";
|
||||||
ss << "\nClassical evaluation: " << to_cp(v) << " (white side)\n";
|
|
||||||
|
|
||||||
|
if (Eval::useNNUE)
|
||||||
|
ss << '\n' << NNUE::trace(pos) << '\n';
|
||||||
|
|
||||||
|
ss << std::showpoint << std::showpos << std::fixed << std::setprecision(2) << std::setw(15);
|
||||||
|
|
||||||
|
v = pos.side_to_move() == WHITE ? v : -v;
|
||||||
|
ss << "\nClassical evaluation " << to_cp(v) << " (white side)\n";
|
||||||
if (Eval::useNNUE)
|
if (Eval::useNNUE)
|
||||||
{
|
{
|
||||||
v = NNUE::evaluate(pos);
|
v = NNUE::evaluate(pos, false);
|
||||||
v = pos.side_to_move() == WHITE ? v : -v;
|
v = pos.side_to_move() == WHITE ? v : -v;
|
||||||
ss << "\nNNUE evaluation: " << to_cp(v) << " (white side)\n";
|
ss << "NNUE evaluation " << to_cp(v) << " (white side)\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
v = evaluate(pos);
|
v = evaluate(pos);
|
||||||
v = pos.side_to_move() == WHITE ? v : -v;
|
v = pos.side_to_move() == WHITE ? v : -v;
|
||||||
ss << "\nFinal evaluation: " << to_cp(v) << " (white side)\n";
|
ss << "Final evaluation " << to_cp(v) << " (white side)";
|
||||||
|
if (Eval::useNNUE)
|
||||||
|
ss << " [with scaled NNUE, hybrid, ...]";
|
||||||
|
ss << "\n";
|
||||||
|
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ class Position;
|
||||||
|
|
||||||
namespace Eval {
|
namespace Eval {
|
||||||
|
|
||||||
std::string trace(const Position& pos);
|
std::string trace(Position& pos);
|
||||||
Value evaluate(const Position& pos);
|
Value evaluate(const Position& pos);
|
||||||
|
|
||||||
extern bool useNNUE;
|
extern bool useNNUE;
|
||||||
|
@ -43,12 +43,15 @@ namespace Eval {
|
||||||
|
|
||||||
namespace NNUE {
|
namespace NNUE {
|
||||||
|
|
||||||
|
std::string trace(Position& pos);
|
||||||
Value evaluate(const Position& pos, bool adjusted = false);
|
Value evaluate(const Position& pos, bool adjusted = false);
|
||||||
|
|
||||||
|
void init();
|
||||||
|
void verify();
|
||||||
|
|
||||||
bool load_eval(std::string name, std::istream& stream);
|
bool load_eval(std::string name, std::istream& stream);
|
||||||
bool save_eval(std::ostream& stream);
|
bool save_eval(std::ostream& stream);
|
||||||
void init();
|
bool save_eval(const std::optional<std::string>& filename);
|
||||||
void export_net(const std::optional<std::string>& filename);
|
|
||||||
void verify();
|
|
||||||
|
|
||||||
} // namespace NNUE
|
} // namespace NNUE
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,9 @@
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
#include "../evaluate.h"
|
#include "../evaluate.h"
|
||||||
#include "../position.h"
|
#include "../position.h"
|
||||||
|
@ -175,6 +178,220 @@ namespace Stockfish::Eval::NNUE {
|
||||||
return static_cast<Value>( sum / OutputScale );
|
return static_cast<Value>( sum / OutputScale );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct NnueEvalTrace {
|
||||||
|
static_assert(LayerStacks == PSQTBuckets);
|
||||||
|
|
||||||
|
Value psqt[LayerStacks];
|
||||||
|
Value positional[LayerStacks];
|
||||||
|
std::size_t correctBucket;
|
||||||
|
};
|
||||||
|
|
||||||
|
static NnueEvalTrace trace_evaluate(const Position& pos) {
|
||||||
|
|
||||||
|
// We manually align the arrays on the stack because with gcc < 9.3
|
||||||
|
// overaligning stack variables with alignas() doesn't work correctly.
|
||||||
|
|
||||||
|
constexpr uint64_t alignment = CacheLineSize;
|
||||||
|
|
||||||
|
#if defined(ALIGNAS_ON_STACK_VARIABLES_BROKEN)
|
||||||
|
TransformedFeatureType transformedFeaturesUnaligned[
|
||||||
|
FeatureTransformer::BufferSize + alignment / sizeof(TransformedFeatureType)];
|
||||||
|
char bufferUnaligned[Network::BufferSize + alignment];
|
||||||
|
|
||||||
|
auto* transformedFeatures = align_ptr_up<alignment>(&transformedFeaturesUnaligned[0]);
|
||||||
|
auto* buffer = align_ptr_up<alignment>(&bufferUnaligned[0]);
|
||||||
|
#else
|
||||||
|
alignas(alignment)
|
||||||
|
TransformedFeatureType transformedFeatures[FeatureTransformer::BufferSize];
|
||||||
|
alignas(alignment) char buffer[Network::BufferSize];
|
||||||
|
#endif
|
||||||
|
|
||||||
|
ASSERT_ALIGNED(transformedFeatures, alignment);
|
||||||
|
ASSERT_ALIGNED(buffer, alignment);
|
||||||
|
|
||||||
|
NnueEvalTrace t{};
|
||||||
|
t.correctBucket = (pos.count<ALL_PIECES>() - 1) / 4;
|
||||||
|
for (std::size_t bucket = 0; bucket < LayerStacks; ++bucket) {
|
||||||
|
const auto psqt = featureTransformer->transform(pos, transformedFeatures, bucket);
|
||||||
|
const auto output = network[bucket]->propagate(transformedFeatures, buffer);
|
||||||
|
|
||||||
|
int materialist = psqt;
|
||||||
|
int positional = output[0];
|
||||||
|
|
||||||
|
t.psqt[bucket] = static_cast<Value>( materialist / OutputScale );
|
||||||
|
t.positional[bucket] = static_cast<Value>( positional / OutputScale );
|
||||||
|
}
|
||||||
|
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
static const std::string PieceToChar(" PNBRQK pnbrqk");
|
||||||
|
|
||||||
|
// Requires the buffer to have capacity for at least 5 values
|
||||||
|
static void format_cp_compact(Value v, char* buffer) {
|
||||||
|
|
||||||
|
buffer[0] = (v < 0 ? '-' : v > 0 ? '+' : ' ');
|
||||||
|
|
||||||
|
int cp = (int)(std::abs(100.0 * double(v) / PawnValueEg));
|
||||||
|
|
||||||
|
if (cp >= 10000)
|
||||||
|
{
|
||||||
|
buffer[1] = '0' + cp / 10000; cp %= 10000;
|
||||||
|
buffer[2] = '0' + cp / 1000; cp %= 1000;
|
||||||
|
buffer[3] = '0' + cp / 100; cp %= 100;
|
||||||
|
buffer[4] = ' ';
|
||||||
|
}
|
||||||
|
else if (cp >= 1000)
|
||||||
|
{
|
||||||
|
buffer[1] = '0' + cp / 1000; cp %= 1000;
|
||||||
|
buffer[2] = '0' + cp / 100; cp %= 100;
|
||||||
|
buffer[3] = '.';
|
||||||
|
buffer[4] = '0' + cp / 10;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
buffer[1] = '0' + cp / 100; cp %= 100;
|
||||||
|
buffer[2] = '.';
|
||||||
|
buffer[3] = '0' + cp / 10; cp %= 10;
|
||||||
|
buffer[4] = '0' + cp / 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Requires the buffer to have capacity for at least 7 values
|
||||||
|
static void format_cp_aligned_dot(Value v, char* buffer) {
|
||||||
|
buffer[0] = (v < 0 ? '-' : v > 0 ? '+' : ' ');
|
||||||
|
|
||||||
|
int cp = (int)(std::abs(100.0 * double(v) / PawnValueEg));
|
||||||
|
|
||||||
|
if (cp >= 10000)
|
||||||
|
{
|
||||||
|
buffer[1] = '0' + cp / 10000; cp %= 10000;
|
||||||
|
buffer[2] = '0' + cp / 1000; cp %= 1000;
|
||||||
|
buffer[3] = '0' + cp / 100; cp %= 100;
|
||||||
|
buffer[4] = '.';
|
||||||
|
buffer[5] = '0' + cp / 10; cp %= 10;
|
||||||
|
buffer[6] = '0' + cp;
|
||||||
|
}
|
||||||
|
else if (cp >= 1000)
|
||||||
|
{
|
||||||
|
buffer[1] = ' ';
|
||||||
|
buffer[2] = '0' + cp / 1000; cp %= 1000;
|
||||||
|
buffer[3] = '0' + cp / 100; cp %= 100;
|
||||||
|
buffer[4] = '.';
|
||||||
|
buffer[5] = '0' + cp / 10; cp %= 10;
|
||||||
|
buffer[6] = '0' + cp;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
buffer[1] = ' ';
|
||||||
|
buffer[2] = ' ';
|
||||||
|
buffer[3] = '0' + cp / 100; cp %= 100;
|
||||||
|
buffer[4] = '.';
|
||||||
|
buffer[5] = '0' + cp / 10; cp %= 10;
|
||||||
|
buffer[6] = '0' + cp / 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// trace() returns a string with the value of each piece on a board,
|
||||||
|
// and a table for (PSQT, Layers) values bucket by bucket.
|
||||||
|
|
||||||
|
std::string trace(Position& pos) {
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
|
||||||
|
char board[3*8+1][8*8+2];
|
||||||
|
std::memset(board, ' ', sizeof(board));
|
||||||
|
for (int row = 0; row < 3*8+1; ++row)
|
||||||
|
board[row][8*8+1] = '\0';
|
||||||
|
|
||||||
|
// A lambda to output one box of the board
|
||||||
|
auto writeSquare = [&board](File file, Rank rank, Piece pc, Value value) {
|
||||||
|
|
||||||
|
const int x = ((int)file) * 8;
|
||||||
|
const int y = (7 - (int)rank) * 3;
|
||||||
|
for (int i = 1; i < 8; ++i)
|
||||||
|
board[y][x+i] = board[y+3][x+i] = '-';
|
||||||
|
for (int i = 1; i < 3; ++i)
|
||||||
|
board[y+i][x] = board[y+i][x+8] = '|';
|
||||||
|
board[y][x] = board[y][x+8] = board[y+3][x+8] = board[y+3][x] = '+';
|
||||||
|
if (pc != NO_PIECE)
|
||||||
|
board[y+1][x+4] = PieceToChar[pc];
|
||||||
|
if (value != VALUE_NONE)
|
||||||
|
format_cp_compact(value, &board[y+2][x+2]);
|
||||||
|
};
|
||||||
|
|
||||||
|
// We estimate the value of each piece by doing a differential evaluation from
|
||||||
|
// the current base eval, simulating the removal of the piece from its square.
|
||||||
|
Value base = evaluate(pos);
|
||||||
|
base = pos.side_to_move() == WHITE ? base : -base;
|
||||||
|
|
||||||
|
for (File f = FILE_A; f <= FILE_H; ++f)
|
||||||
|
for (Rank r = RANK_1; r <= RANK_8; ++r)
|
||||||
|
{
|
||||||
|
Square sq = make_square(f, r);
|
||||||
|
Piece pc = pos.piece_on(sq);
|
||||||
|
Value v = VALUE_NONE;
|
||||||
|
|
||||||
|
if (pc != NO_PIECE && type_of(pc) != KING)
|
||||||
|
{
|
||||||
|
auto st = pos.state();
|
||||||
|
|
||||||
|
pos.remove_piece(sq);
|
||||||
|
st->accumulator.computed[WHITE] = false;
|
||||||
|
st->accumulator.computed[BLACK] = false;
|
||||||
|
|
||||||
|
Value eval = evaluate(pos);
|
||||||
|
eval = pos.side_to_move() == WHITE ? eval : -eval;
|
||||||
|
v = base - eval;
|
||||||
|
|
||||||
|
pos.put_piece(pc, sq);
|
||||||
|
st->accumulator.computed[WHITE] = false;
|
||||||
|
st->accumulator.computed[BLACK] = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
writeSquare(f, r, pc, v);
|
||||||
|
}
|
||||||
|
|
||||||
|
ss << " NNUE derived piece values:\n";
|
||||||
|
for (int row = 0; row < 3*8+1; ++row)
|
||||||
|
ss << board[row] << '\n';
|
||||||
|
ss << '\n';
|
||||||
|
|
||||||
|
auto t = trace_evaluate(pos);
|
||||||
|
|
||||||
|
ss << " NNUE network contributions "
|
||||||
|
<< (pos.side_to_move() == WHITE ? "(White to move)" : "(Black to move)") << std::endl
|
||||||
|
<< "+------------+------------+------------+------------+\n"
|
||||||
|
<< "| Bucket | Material | Positional | Total |\n"
|
||||||
|
<< "| | (PSQT) | (Layers) | |\n"
|
||||||
|
<< "+------------+------------+------------+------------+\n";
|
||||||
|
|
||||||
|
for (std::size_t bucket = 0; bucket < LayerStacks; ++bucket)
|
||||||
|
{
|
||||||
|
char buffer[3][8];
|
||||||
|
std::memset(buffer, '\0', sizeof(buffer));
|
||||||
|
|
||||||
|
format_cp_aligned_dot(t.psqt[bucket], buffer[0]);
|
||||||
|
format_cp_aligned_dot(t.positional[bucket], buffer[1]);
|
||||||
|
format_cp_aligned_dot(t.psqt[bucket] + t.positional[bucket], buffer[2]);
|
||||||
|
|
||||||
|
ss << "| " << bucket << " "
|
||||||
|
<< " | " << buffer[0] << " "
|
||||||
|
<< " | " << buffer[1] << " "
|
||||||
|
<< " | " << buffer[2] << " "
|
||||||
|
<< " |";
|
||||||
|
if (bucket == t.correctBucket)
|
||||||
|
ss << " <-- this bucket is used";
|
||||||
|
ss << '\n';
|
||||||
|
}
|
||||||
|
|
||||||
|
ss << "+------------+------------+------------+------------+\n";
|
||||||
|
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// Load eval, from a file stream or a memory stream
|
// Load eval, from a file stream or a memory stream
|
||||||
bool load_eval(std::string name, std::istream& stream) {
|
bool load_eval(std::string name, std::istream& stream) {
|
||||||
|
|
||||||
|
@ -192,4 +409,35 @@ namespace Stockfish::Eval::NNUE {
|
||||||
return write_parameters(stream);
|
return write_parameters(stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Save eval, to a file given by its name
|
||||||
|
bool save_eval(const std::optional<std::string>& filename) {
|
||||||
|
|
||||||
|
std::string actualFilename;
|
||||||
|
std::string msg;
|
||||||
|
|
||||||
|
if (filename.has_value())
|
||||||
|
actualFilename = filename.value();
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (eval_file_loaded != EvalFileDefaultName)
|
||||||
|
{
|
||||||
|
msg = "Failed to export a net. A non-embedded net can only be saved if the filename is specified";
|
||||||
|
|
||||||
|
sync_cout << msg << sync_endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
actualFilename = EvalFileDefaultName;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ofstream stream(actualFilename, std::ios_base::binary);
|
||||||
|
bool saved = save_eval(stream);
|
||||||
|
|
||||||
|
msg = saved ? "Network saved successfully to " + actualFilename
|
||||||
|
: "Failed to export a net";
|
||||||
|
|
||||||
|
sync_cout << msg << sync_endl;
|
||||||
|
return saved;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
} // namespace Stockfish::Eval::NNUE
|
} // namespace Stockfish::Eval::NNUE
|
||||||
|
|
|
@ -171,6 +171,9 @@ public:
|
||||||
// Used by NNUE
|
// Used by NNUE
|
||||||
StateInfo* state() const;
|
StateInfo* state() const;
|
||||||
|
|
||||||
|
void put_piece(Piece pc, Square s);
|
||||||
|
void remove_piece(Square s);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Initialization helpers (used while setting up a position)
|
// Initialization helpers (used while setting up a position)
|
||||||
void set_castling_right(Color c, Square rfrom);
|
void set_castling_right(Color c, Square rfrom);
|
||||||
|
@ -178,8 +181,6 @@ private:
|
||||||
void set_check_info(StateInfo* si) const;
|
void set_check_info(StateInfo* si) const;
|
||||||
|
|
||||||
// Other helpers
|
// Other helpers
|
||||||
void put_piece(Piece pc, Square s);
|
|
||||||
void remove_piece(Square s);
|
|
||||||
void move_piece(Square from, Square to);
|
void move_piece(Square from, Square to);
|
||||||
template<bool Do>
|
template<bool Do>
|
||||||
void do_castling(Color us, Square from, Square& to, Square& rfrom, Square& rto);
|
void do_castling(Color us, Square from, Square& to, Square& rfrom, Square& rto);
|
||||||
|
@ -386,7 +387,7 @@ inline void Position::remove_piece(Square s) {
|
||||||
byTypeBB[ALL_PIECES] ^= s;
|
byTypeBB[ALL_PIECES] ^= s;
|
||||||
byTypeBB[type_of(pc)] ^= s;
|
byTypeBB[type_of(pc)] ^= s;
|
||||||
byColorBB[color_of(pc)] ^= s;
|
byColorBB[color_of(pc)] ^= s;
|
||||||
/* board[s] = NO_PIECE; Not needed, overwritten by the capturing one */
|
board[s] = NO_PIECE;
|
||||||
pieceCount[pc]--;
|
pieceCount[pc]--;
|
||||||
pieceCount[make_piece(color_of(pc), ALL_PIECES)]--;
|
pieceCount[make_piece(color_of(pc), ALL_PIECES)]--;
|
||||||
psq -= PSQT::psq[pc][s];
|
psq -= PSQT::psq[pc][s];
|
||||||
|
|
10
src/uci.cpp
10
src/uci.cpp
|
@ -277,13 +277,13 @@ void UCI::loop(int argc, char* argv[]) {
|
||||||
else if (token == "d") sync_cout << pos << sync_endl;
|
else if (token == "d") sync_cout << pos << sync_endl;
|
||||||
else if (token == "eval") trace_eval(pos);
|
else if (token == "eval") trace_eval(pos);
|
||||||
else if (token == "compiler") sync_cout << compiler_info() << sync_endl;
|
else if (token == "compiler") sync_cout << compiler_info() << sync_endl;
|
||||||
else if (token == "export_net") {
|
else if (token == "export_net")
|
||||||
|
{
|
||||||
std::optional<std::string> filename;
|
std::optional<std::string> filename;
|
||||||
std::string f;
|
std::string f;
|
||||||
if (is >> skipws >> f) {
|
if (is >> skipws >> f)
|
||||||
filename = f;
|
filename = f;
|
||||||
}
|
Eval::NNUE::save_eval(filename);
|
||||||
Eval::NNUE::export_net(filename);
|
|
||||||
}
|
}
|
||||||
else if (!token.empty() && token[0] != '#')
|
else if (!token.empty() && token[0] != '#')
|
||||||
sync_cout << "Unknown command: " << cmd << sync_endl;
|
sync_cout << "Unknown command: " << cmd << sync_endl;
|
||||||
|
|
Loading…
Reference in New Issue