1
0
Fork 0

Add dedicated command for training data validation.

pull/3497/head
Tomasz Sobczyk 2021-05-24 19:43:07 +02:00
parent 5676a50807
commit eac1d430b4
6 changed files with 299 additions and 0 deletions

View File

@ -0,0 +1,12 @@
# validate_training_data
`validate_training_data` allows validation of training data of types `.plain`, `.bin`, and `.binpack`.
Like all commands in Stockfish, `validate_training_data` can be invoked either from the command line (as `stockfish.exe validate_training_data ...`) or from the interactive prompt.
The syntax of this command is as follows:
```
validate_training_data in_path
```
`in_path` is the path to the file to validate. The type of the data is deduced based on its extension (one of `.plain`, `.bin`, `.binpack`).

View File

@ -51,6 +51,7 @@ SRCS = benchmark.cpp bitbase.cpp bitboard.cpp endgame.cpp evaluate.cpp main.cpp
search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \
nnue/evaluate_nnue.cpp \
nnue/features/half_ka_v2.cpp \
tools/validate_training_data.cpp \
tools/sfen_packer.cpp \
tools/training_data_generator.cpp \
tools/training_data_generator_nonpv.cpp \

View File

@ -7831,4 +7831,154 @@ namespace binpack
std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n";
}
/// Validates a training data file stored in the `.plain` text format.
///
/// The file is consumed as whitespace-separated keys. For the keys `fen`,
/// `move`, `score`, `ply` and `result` the rest of the line is taken as the
/// value and stored for the entry being built. The key `e` ends an entry:
/// the stored move string is converted against the stored position and the
/// entry is checked with `isValid()`.
///
/// Progress is written to std::cout every `reportSize` positions. On the
/// first invalid entry a diagnostic is written to std::cerr and the function
/// returns without printing the final summary.
///
/// Note: numeric values are parsed with std::stoi, which throws
/// std::invalid_argument / std::out_of_range on malformed input; those
/// exceptions are not caught here.
///
/// @param inputPath path of the `.plain` file to validate.
inline void validatePlain(std::string inputPath)
{
    constexpr std::size_t reportSize = 1000000;

    std::cout << "Validating " << inputPath << '\n';

    TrainingDataEntry e;

    std::string key;
    std::string value;
    std::string move;

    std::ifstream inputFile(inputPath);
    const auto base = inputFile.tellg();
    std::size_t numProcessedPositions = 0;
    std::size_t numProcessedPositionsBatch = 0;

    for(;;)
    {
        inputFile >> key;
        if (!inputFile)
        {
            break;
        }

        if (key == "e"sv)
        {
            // End of entry: the move can only be decoded now, because it
            // needs the position that was read from the `fen` line.
            e.move = chess::uci::uciToMove(e.pos, move);
            if (!e.isValid())
            {
                std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n';
                return;
            }

            ++numProcessedPositions;
            ++numProcessedPositionsBatch;

            if (numProcessedPositionsBatch >= reportSize)
            {
                numProcessedPositionsBatch -= reportSize;
                const auto cur = inputFile.tellg();
                std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n";
            }

            continue;
        }

        // Every other key is followed by its value on the same line.
        inputFile >> std::ws;
        std::getline(inputFile, value, '\n');

        if (key == "fen"sv) e.pos = chess::Position::fromFen(value.c_str());
        if (key == "move"sv) move = value;
        if (key == "score"sv) e.score = std::stoi(value);
        if (key == "ply"sv) e.ply = std::stoi(value);
        if (key == "result"sv) e.result = std::stoi(value);
    }

    if (numProcessedPositionsBatch)
    {
        // The loop above only exits once the key extraction has failed, and
        // tellg() returns -1 while the stream is in a failed state. Clear the
        // error flags so the real end position is reported.
        inputFile.clear();
        const auto cur = inputFile.tellg();
        std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n";
    }

    std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n";
}
/// Validates a training data file stored in the `.bin` format.
///
/// The file is read as a sequence of fixed-size `nodchip::PackedSfenValue`
/// records. Each record is converted with
/// `packedSfenValueToTrainingDataEntry` and checked with `isValid()`.
///
/// Progress is written to std::cout every `reportSize` positions. On the
/// first invalid entry a diagnostic is written to std::cerr and the function
/// returns without printing the final summary. A trailing partial record is
/// reported on std::cerr and otherwise ignored.
///
/// @param inputPath path of the `.bin` file to validate.
inline void validateBin(std::string inputPath)
{
    constexpr std::size_t reportSize = 1000000;
    // Size of one on-disk record; reading fewer bytes means end of data.
    constexpr auto recordSize = static_cast<std::streamsize>(sizeof(nodchip::PackedSfenValue));

    std::cout << "Validating " << inputPath << '\n';

    std::ifstream inputFile(inputPath, std::ios_base::binary);
    const auto base = inputFile.tellg();
    std::size_t numProcessedPositions = 0;
    std::size_t numProcessedPositionsBatch = 0;

    nodchip::PackedSfenValue psv;
    std::streamsize lastReadSize = 0;

    for(;;)
    {
        inputFile.read(reinterpret_cast<char*>(&psv), sizeof(psv));
        lastReadSize = inputFile.gcount();
        if (lastReadSize != recordSize)
        {
            break;
        }

        auto e = packedSfenValueToTrainingDataEntry(psv);
        if (!e.isValid())
        {
            std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n';
            return;
        }

        ++numProcessedPositions;
        ++numProcessedPositionsBatch;

        if (numProcessedPositionsBatch >= reportSize)
        {
            numProcessedPositionsBatch -= reportSize;
            const auto cur = inputFile.tellg();
            std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n";
        }
    }

    if (lastReadSize != 0)
    {
        // The file size is not a multiple of the record size: the last
        // record is incomplete and could not be validated.
        std::cerr << "Ignoring truncated trailing record of " << lastReadSize << " bytes.\n";
    }

    if (numProcessedPositionsBatch)
    {
        // The loop above only exits after a short read, which leaves the
        // stream in a failed state, and tellg() returns -1 on a failed
        // stream. Clear the error flags so the real position is reported.
        inputFile.clear();
        const auto cur = inputFile.tellg();
        std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n";
    }

    std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n";
}
/// Validates a training data file stored in the `.binpack` format.
///
/// Entries are pulled one by one from a `CompressedTrainingDataEntryReader`
/// and checked with `isValid()`. Progress is written to std::cout every
/// `reportSize` positions. On the first invalid entry a diagnostic is
/// written to std::cerr and the function returns without printing the final
/// summary.
///
/// @param inputPath path of the `.binpack` file to validate.
inline void validateBinpack(std::string inputPath)
{
    constexpr std::size_t reportSize = 1000000;

    std::cout << "Validating " << inputPath << '\n';

    CompressedTrainingDataEntryReader reader(inputPath);

    std::size_t validatedCount = 0;
    const auto printProgress = [&validatedCount]() {
        std::cout << "Processed " << validatedCount << " positions.\n";
    };

    while (reader.hasNext())
    {
        auto entry = reader.next();
        if (!entry.isValid())
        {
            std::cerr << "Illegal move " << chess::uci::moveToUci(entry.pos, entry.move) << " for position " << entry.pos.fen() << '\n';
            return;
        }

        ++validatedCount;

        // Report each time a full batch of `reportSize` positions is done.
        if (validatedCount % reportSize == 0)
        {
            printProgress();
        }
    }

    // Report the last, partially filled batch (if any).
    if (validatedCount % reportSize != 0)
    {
        printProgress();
    }

    std::cout << "Finished. Validated " << validatedCount << " positions.\n";
}
}

View File

@ -0,0 +1,122 @@
#include "validate_training_data.h"
#include "uci.h"
#include "misc.h"
#include "thread.h"
#include "position.h"
#include "tt.h"
#include "extra/nnue_data_binpack_format.h"
#include "nnue/evaluate_nnue.h"
#include "syzygy/tbprobe.h"
#include <sstream>
#include <fstream>
#include <unordered_set>
#include <iomanip>
#include <list>
#include <cmath> // std::exp(),std::pow(),std::log()
#include <cstring> // memcpy()
#include <memory>
#include <limits>
#include <optional>
#include <chrono>
#include <random>
#include <regex>
#include <filesystem>
using namespace std;
namespace sys = std::filesystem;
namespace Stockfish::Tools
{
// File name suffixes used to select the validation routine for an input file.
static inline const std::string plain_extension{".plain"};
static inline const std::string bin_extension{".bin"};
static inline const std::string binpack_extension{".binpack"};
// Returns true when a file with the given name can be opened for reading.
static bool file_exists(const std::string& name)
{
    std::ifstream probe{name};
    const bool opened = probe.good();
    return opened;
}
// Returns true when `lhs` ends with the character sequence `end`.
// An empty `end` matches every string.
static bool ends_with(const std::string& lhs, const std::string& end)
{
    const std::size_t suffix_len = end.size();
    const std::size_t total_len = lhs.size();

    return suffix_len <= total_len
        && lhs.compare(total_len - suffix_len, suffix_len, end) == 0;
}
// Tells whether `input_path` names a file of the type identified by
// `expected_input_extension`, judged purely by the path's suffix.
static bool is_validation_of_type(
    const std::string& input_path,
    const std::string& expected_input_extension)
{
    const bool has_expected_suffix = ends_with(input_path, expected_input_extension);
    return has_expected_suffix;
}
// Signature shared by all format-specific validation routines.
using ValidateFunctionType = void(std::string inputPath);

// Picks the validation routine matching the extension of `input_path`.
// Returns nullptr when the extension is none of the supported ones.
static ValidateFunctionType* get_validate_function(const std::string& input_path)
{
    struct Handler
    {
        const std::string* extension;
        ValidateFunctionType* validate;
    };

    // Checked in this order; the first matching extension wins.
    const Handler handlers[] = {
        { &plain_extension,   binpack::validatePlain   },
        { &bin_extension,     binpack::validateBin     },
        { &binpack_extension, binpack::validateBinpack },
    };

    for (const Handler& handler : handlers)
    {
        if (is_validation_of_type(input_path, *handler.extension))
            return handler.validate;
    }

    return nullptr;
}
static void validate_training_data(const std::string& input_path)
{
if(!file_exists(input_path))
{
std::cerr << "Input file does not exist.\n";
return;
}
auto func = get_validate_function(input_path);
if (func != nullptr)
{
func(input_path);
}
else
{
std::cerr << "Validation of files of this type is not supported.\n";
}
}
// Validates the command arguments and runs the validation.
// Exactly one argument is expected: the path of the file to validate.
// Anything else prints a usage message to std::cerr and does nothing.
static void validate_training_data(const std::vector<std::string>& args)
{
    if (args.size() != 1)
    {
        std::cerr << "Invalid arguments.\n";
        // The command is registered as `validate_training_data`, so the
        // usage line must show that name.
        std::cerr << "Usage: validate_training_data in_path\n";
        return;
    }

    validate_training_data(args[0]);
}
void validate_training_data(istringstream& is)
{
std::vector<std::string> args;
while (true)
{
std::string token = "";
is >> token;
if (token == "")
break;
args.push_back(token);
}
validate_training_data(args);
}
}

View File

@ -0,0 +1,12 @@
#ifndef _VALIDATE_TRAINING_DATA_H_
#define _VALIDATE_TRAINING_DATA_H_
#include <vector>
#include <string>
#include <sstream>
namespace Stockfish::Tools {
// Entry point of the `validate_training_data` command.
// `is` supplies the command arguments that follow the command name; a single
// argument, the path of the training data file to validate, is expected.
void validate_training_data(std::istringstream& is);
}
#endif

View File

@ -33,6 +33,7 @@
#include "tt.h"
#include "uci.h"
#include "tools/validate_training_data.h"
#include "tools/training_data_generator.h"
#include "tools/training_data_generator_nonpv.h"
#include "tools/convert.h"
@ -330,6 +331,7 @@ void UCI::loop(int argc, char* argv[]) {
else if (token == "generate_training_data") Tools::generate_training_data(is);
else if (token == "generate_training_data") Tools::generate_training_data_nonpv(is);
else if (token == "convert") Tools::convert(is);
else if (token == "validate_training_data") Tools::validate_training_data(is);
else if (token == "convert_bin") Tools::convert_bin(is);
else if (token == "convert_plain") Tools::convert_plain(is);
else if (token == "convert_bin_from_pgn_extract") Tools::convert_bin_from_pgn_extract(is);