1
0
Fork 0

Avoid unnecessary stores in the affine transform. See https://godbolt.org/z/59aTKbbYc

pull/3634/head
Tomasz Sobczyk 2021-07-27 22:12:14 +02:00
parent 237ed1ef8f
commit db0f46844a
1 changed files with 16 additions and 8 deletions

View File

@ -251,9 +251,6 @@ namespace Stockfish::Eval::NNUE::Layers {
#endif
#if defined (USE_SSSE3)
// Different layout, we process 4 inputs at a time, always.
static_assert(InputDimensions % 4 == 0);
const auto output = reinterpret_cast<OutputType*>(buffer);
const auto inputVector = reinterpret_cast<const vec_t*>(input);
@ -263,13 +260,18 @@ namespace Stockfish::Eval::NNUE::Layers {
// because then it is also an input dimension.
if constexpr (OutputDimensions % OutputSimdWidth == 0)
{
static_assert(InputDimensions % 16 == 0);
constexpr IndexType NumChunks = InputDimensions / 4;
constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
const auto input32 = reinterpret_cast<const std::int32_t*>(input);
vec_t* outptr = reinterpret_cast<vec_t*>(output);
std::memcpy(output, biases, OutputDimensions * sizeof(OutputType));
const vec_t* biasvec = reinterpret_cast<const vec_t*>(biases);
vec_t outs[NumRegs];
for (IndexType k = 0; k < NumRegs; ++k)
outs[k] = biasvec[k];
for (int i = 0; i < (int)NumChunks - 3; i += 4)
for (IndexType i = 0; i < NumChunks; i += 4)
{
const vec_t in0 = vec_set_32(input32[i + 0]);
const vec_t in1 = vec_set_32(input32[i + 1]);
@ -279,12 +281,18 @@ namespace Stockfish::Eval::NNUE::Layers {
const auto col1 = reinterpret_cast<const vec_t*>(&weights[(i + 1) * OutputDimensions * 4]);
const auto col2 = reinterpret_cast<const vec_t*>(&weights[(i + 2) * OutputDimensions * 4]);
const auto col3 = reinterpret_cast<const vec_t*>(&weights[(i + 3) * OutputDimensions * 4]);
for (int j = 0; j * OutputSimdWidth < OutputDimensions; ++j)
vec_add_dpbusd_32x4(outptr[j], in0, col0[j], in1, col1[j], in2, col2[j], in3, col3[j]);
for (IndexType k = 0; k < NumRegs; ++k)
vec_add_dpbusd_32x4(outs[k], in0, col0[k], in1, col1[k], in2, col2[k], in3, col3[k]);
}
vec_t* outptr = reinterpret_cast<vec_t*>(output);
for (IndexType k = 0; k < NumRegs; ++k)
outptr[k] = outs[k];
}
else if constexpr (OutputDimensions == 1)
{
static_assert(InputDimensions % 4 == 0);
#if defined (USE_AVX512)
if constexpr (PaddedInputDimensions % (SimdWidth * 2) != 0)
{