Add dedicated command for training data validation.
parent
5676a50807
commit
eac1d430b4
|
@ -0,0 +1,12 @@
|
|||
# validate_training_data
|
||||
|
||||
`validate_training_data` allows validation of training data of types `.plain`, `.bin`, and `.binpack`.
|
||||
|
||||
As all commands in stockfish `validate_training_data` can be invoked either from command line (as `stockfish.exe validate_training_data ...`) or in the interactive prompt.
|
||||
|
||||
The syntax of this command is as follows:
|
||||
```
|
||||
validate_training_data in_path
|
||||
```
|
||||
|
||||
`in_path` is the path to the file to validate. The type of the data is deduced based on its extension (one of `.plain`, `.bin`, `.binpack`).
|
|
@ -51,6 +51,7 @@ SRCS = benchmark.cpp bitbase.cpp bitboard.cpp endgame.cpp evaluate.cpp main.cpp
|
|||
search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \
|
||||
nnue/evaluate_nnue.cpp \
|
||||
nnue/features/half_ka_v2.cpp \
|
||||
tools/validate_training_data.cpp \
|
||||
tools/sfen_packer.cpp \
|
||||
tools/training_data_generator.cpp \
|
||||
tools/training_data_generator_nonpv.cpp \
|
||||
|
|
|
@ -7831,4 +7831,154 @@ namespace binpack
|
|||
|
||||
std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n";
|
||||
}
|
||||
|
||||
inline void validatePlain(std::string inputPath)
|
||||
{
|
||||
constexpr std::size_t reportSize = 1000000;
|
||||
|
||||
std::cout << "Validating " << inputPath << '\n';
|
||||
|
||||
TrainingDataEntry e;
|
||||
|
||||
std::string key;
|
||||
std::string value;
|
||||
std::string move;
|
||||
|
||||
std::ifstream inputFile(inputPath);
|
||||
const auto base = inputFile.tellg();
|
||||
std::size_t numProcessedPositions = 0;
|
||||
std::size_t numProcessedPositionsBatch = 0;
|
||||
|
||||
for(;;)
|
||||
{
|
||||
inputFile >> key;
|
||||
if (!inputFile)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
if (key == "e"sv)
|
||||
{
|
||||
e.move = chess::uci::uciToMove(e.pos, move);
|
||||
if (!e.isValid())
|
||||
{
|
||||
std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n';
|
||||
return;
|
||||
}
|
||||
|
||||
++numProcessedPositions;
|
||||
++numProcessedPositionsBatch;
|
||||
|
||||
if (numProcessedPositionsBatch >= reportSize)
|
||||
{
|
||||
numProcessedPositionsBatch -= reportSize;
|
||||
const auto cur = inputFile.tellg();
|
||||
std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n";
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
inputFile >> std::ws;
|
||||
std::getline(inputFile, value, '\n');
|
||||
|
||||
if (key == "fen"sv) e.pos = chess::Position::fromFen(value.c_str());
|
||||
if (key == "move"sv) move = value;
|
||||
if (key == "score"sv) e.score = std::stoi(value);
|
||||
if (key == "ply"sv) e.ply = std::stoi(value);
|
||||
if (key == "result"sv) e.result = std::stoi(value);
|
||||
}
|
||||
|
||||
if (numProcessedPositionsBatch)
|
||||
{
|
||||
const auto cur = inputFile.tellg();
|
||||
std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n";
|
||||
}
|
||||
|
||||
std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n";
|
||||
}
|
||||
|
||||
inline void validateBin(std::string inputPath)
|
||||
{
|
||||
constexpr std::size_t reportSize = 1000000;
|
||||
|
||||
std::cout << "Validating " << inputPath << '\n';
|
||||
|
||||
std::ifstream inputFile(inputPath, std::ios_base::binary);
|
||||
const auto base = inputFile.tellg();
|
||||
std::size_t numProcessedPositions = 0;
|
||||
std::size_t numProcessedPositionsBatch = 0;
|
||||
|
||||
nodchip::PackedSfenValue psv;
|
||||
for(;;)
|
||||
{
|
||||
inputFile.read(reinterpret_cast<char*>(&psv), sizeof(psv));
|
||||
if (inputFile.gcount() != 40)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
auto e = packedSfenValueToTrainingDataEntry(psv);
|
||||
if (!e.isValid())
|
||||
{
|
||||
std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n';
|
||||
return;
|
||||
}
|
||||
|
||||
++numProcessedPositions;
|
||||
++numProcessedPositionsBatch;
|
||||
|
||||
if (numProcessedPositionsBatch >= reportSize)
|
||||
{
|
||||
numProcessedPositionsBatch -= reportSize;
|
||||
const auto cur = inputFile.tellg();
|
||||
std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (numProcessedPositionsBatch)
|
||||
{
|
||||
const auto cur = inputFile.tellg();
|
||||
std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n";
|
||||
}
|
||||
|
||||
std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n";
|
||||
}
|
||||
|
||||
inline void validateBinpack(std::string inputPath)
|
||||
{
|
||||
constexpr std::size_t reportSize = 1000000;
|
||||
|
||||
std::cout << "Validating " << inputPath << '\n';
|
||||
|
||||
CompressedTrainingDataEntryReader reader(inputPath);
|
||||
std::size_t numProcessedPositions = 0;
|
||||
std::size_t numProcessedPositionsBatch = 0;
|
||||
|
||||
while(reader.hasNext())
|
||||
{
|
||||
auto e = reader.next();
|
||||
if (!e.isValid())
|
||||
{
|
||||
std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n';
|
||||
return;
|
||||
}
|
||||
|
||||
++numProcessedPositions;
|
||||
++numProcessedPositionsBatch;
|
||||
|
||||
if (numProcessedPositionsBatch >= reportSize)
|
||||
{
|
||||
numProcessedPositionsBatch -= reportSize;
|
||||
std::cout << "Processed " << numProcessedPositions << " positions.\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (numProcessedPositionsBatch)
|
||||
{
|
||||
std::cout << "Processed " << numProcessedPositions << " positions.\n";
|
||||
}
|
||||
|
||||
std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
#include "validate_training_data.h"
|
||||
|
||||
#include "uci.h"
|
||||
#include "misc.h"
|
||||
#include "thread.h"
|
||||
#include "position.h"
|
||||
#include "tt.h"
|
||||
|
||||
#include "extra/nnue_data_binpack_format.h"
|
||||
|
||||
#include "nnue/evaluate_nnue.h"
|
||||
|
||||
#include "syzygy/tbprobe.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <unordered_set>
|
||||
#include <iomanip>
|
||||
#include <list>
|
||||
#include <cmath> // std::exp(),std::pow(),std::log()
|
||||
#include <cstring> // memcpy()
|
||||
#include <memory>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <chrono>
|
||||
#include <random>
|
||||
#include <regex>
|
||||
#include <filesystem>
|
||||
|
||||
using namespace std;
|
||||
namespace sys = std::filesystem;
|
||||
|
||||
namespace Stockfish::Tools
|
||||
{
|
||||
static inline const std::string plain_extension = ".plain";
|
||||
static inline const std::string bin_extension = ".bin";
|
||||
static inline const std::string binpack_extension = ".binpack";
|
||||
|
||||
static bool file_exists(const std::string& name)
|
||||
{
|
||||
std::ifstream f(name);
|
||||
return f.good();
|
||||
}
|
||||
|
||||
static bool ends_with(const std::string& lhs, const std::string& end)
|
||||
{
|
||||
if (end.size() > lhs.size()) return false;
|
||||
|
||||
return std::equal(end.rbegin(), end.rend(), lhs.rbegin());
|
||||
}
|
||||
|
||||
static bool is_validation_of_type(
|
||||
const std::string& input_path,
|
||||
const std::string& expected_input_extension)
|
||||
{
|
||||
return ends_with(input_path, expected_input_extension);
|
||||
}
|
||||
|
||||
using ValidateFunctionType = void(std::string inputPath);
|
||||
|
||||
static ValidateFunctionType* get_validate_function(const std::string& input_path)
|
||||
{
|
||||
if (is_validation_of_type(input_path, plain_extension))
|
||||
return binpack::validatePlain;
|
||||
|
||||
if (is_validation_of_type(input_path, bin_extension))
|
||||
return binpack::validateBin;
|
||||
|
||||
if (is_validation_of_type(input_path, binpack_extension))
|
||||
return binpack::validateBinpack;
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static void validate_training_data(const std::string& input_path)
|
||||
{
|
||||
if(!file_exists(input_path))
|
||||
{
|
||||
std::cerr << "Input file does not exist.\n";
|
||||
return;
|
||||
}
|
||||
|
||||
auto func = get_validate_function(input_path);
|
||||
if (func != nullptr)
|
||||
{
|
||||
func(input_path);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Validation of files of this type is not supported.\n";
|
||||
}
|
||||
}
|
||||
|
||||
static void validate_training_data(const std::vector<std::string>& args)
|
||||
{
|
||||
if (args.size() != 1)
|
||||
{
|
||||
std::cerr << "Invalid arguments.\n";
|
||||
std::cerr << "Usage: validate in_path\n";
|
||||
return;
|
||||
}
|
||||
|
||||
validate_training_data(args[0]);
|
||||
}
|
||||
|
||||
void validate_training_data(istringstream& is)
|
||||
{
|
||||
std::vector<std::string> args;
|
||||
|
||||
while (true)
|
||||
{
|
||||
std::string token = "";
|
||||
is >> token;
|
||||
if (token == "")
|
||||
break;
|
||||
|
||||
args.push_back(token);
|
||||
}
|
||||
|
||||
validate_training_data(args);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
#ifndef _VALIDATE_TRAINING_DATA_H_
|
||||
#define _VALIDATE_TRAINING_DATA_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
namespace Stockfish::Tools {
|
||||
void validate_training_data(std::istringstream& is);
|
||||
}
|
||||
|
||||
#endif
|
|
@ -33,6 +33,7 @@
|
|||
#include "tt.h"
|
||||
#include "uci.h"
|
||||
|
||||
#include "tools/validate_training_data.h"
|
||||
#include "tools/training_data_generator.h"
|
||||
#include "tools/training_data_generator_nonpv.h"
|
||||
#include "tools/convert.h"
|
||||
|
@ -330,6 +331,7 @@ void UCI::loop(int argc, char* argv[]) {
|
|||
else if (token == "generate_training_data") Tools::generate_training_data(is);
|
||||
else if (token == "generate_training_data") Tools::generate_training_data_nonpv(is);
|
||||
else if (token == "convert") Tools::convert(is);
|
||||
else if (token == "validate_training_data") Tools::validate_training_data(is);
|
||||
else if (token == "convert_bin") Tools::convert_bin(is);
|
||||
else if (token == "convert_plain") Tools::convert_plain(is);
|
||||
else if (token == "convert_bin_from_pgn_extract") Tools::convert_bin_from_pgn_extract(is);
|
||||
|
|
Loading…
Reference in New Issue