Avoid unnecessary stores in the affine transform. See https://godbolt.org/z/59aTKbbYc
parent
237ed1ef8f
commit
db0f46844a
|
@ -251,9 +251,6 @@ namespace Stockfish::Eval::NNUE::Layers {
|
|||
#endif
|
||||
|
||||
#if defined (USE_SSSE3)
|
||||
// Different layout, we process 4 inputs at a time, always.
|
||||
static_assert(InputDimensions % 4 == 0);
|
||||
|
||||
const auto output = reinterpret_cast<OutputType*>(buffer);
|
||||
const auto inputVector = reinterpret_cast<const vec_t*>(input);
|
||||
|
||||
|
@ -263,13 +260,18 @@ namespace Stockfish::Eval::NNUE::Layers {
|
|||
// because then it is also an input dimension.
|
||||
if constexpr (OutputDimensions % OutputSimdWidth == 0)
|
||||
{
|
||||
static_assert(InputDimensions % 16 == 0);
|
||||
|
||||
constexpr IndexType NumChunks = InputDimensions / 4;
|
||||
constexpr IndexType NumRegs = OutputDimensions / OutputSimdWidth;
|
||||
|
||||
const auto input32 = reinterpret_cast<const std::int32_t*>(input);
|
||||
vec_t* outptr = reinterpret_cast<vec_t*>(output);
|
||||
std::memcpy(output, biases, OutputDimensions * sizeof(OutputType));
|
||||
const vec_t* biasvec = reinterpret_cast<const vec_t*>(biases);
|
||||
vec_t outs[NumRegs];
|
||||
for (IndexType k = 0; k < NumRegs; ++k)
|
||||
outs[k] = biasvec[k];
|
||||
|
||||
for (int i = 0; i < (int)NumChunks - 3; i += 4)
|
||||
for (IndexType i = 0; i < NumChunks; i += 4)
|
||||
{
|
||||
const vec_t in0 = vec_set_32(input32[i + 0]);
|
||||
const vec_t in1 = vec_set_32(input32[i + 1]);
|
||||
|
@ -279,12 +281,18 @@ namespace Stockfish::Eval::NNUE::Layers {
|
|||
const auto col1 = reinterpret_cast<const vec_t*>(&weights[(i + 1) * OutputDimensions * 4]);
|
||||
const auto col2 = reinterpret_cast<const vec_t*>(&weights[(i + 2) * OutputDimensions * 4]);
|
||||
const auto col3 = reinterpret_cast<const vec_t*>(&weights[(i + 3) * OutputDimensions * 4]);
|
||||
for (int j = 0; j * OutputSimdWidth < OutputDimensions; ++j)
|
||||
vec_add_dpbusd_32x4(outptr[j], in0, col0[j], in1, col1[j], in2, col2[j], in3, col3[j]);
|
||||
for (IndexType k = 0; k < NumRegs; ++k)
|
||||
vec_add_dpbusd_32x4(outs[k], in0, col0[k], in1, col1[k], in2, col2[k], in3, col3[k]);
|
||||
}
|
||||
|
||||
vec_t* outptr = reinterpret_cast<vec_t*>(output);
|
||||
for (IndexType k = 0; k < NumRegs; ++k)
|
||||
outptr[k] = outs[k];
|
||||
}
|
||||
else if constexpr (OutputDimensions == 1)
|
||||
{
|
||||
static_assert(InputDimensions % 4 == 0);
|
||||
|
||||
#if defined (USE_AVX512)
|
||||
if constexpr (PaddedInputDimensions % (SimdWidth * 2) != 0)
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue