pytorch/caffe2/sgd/rowwise_adagrad_fused.h

1029 lines
31 KiB
C++

#pragma once
#include "caffe2/core/operator.h"
#include "caffe2/sgd/math_lp.h"
namespace caffe2 {
namespace internal {
/**
 * Returns the mean of the squared entries of `a`: (sum_i a[i]^2) / len.
 *
 * With AVX enabled, complete groups of 8 floats are accumulated in a vector
 * register and reduced horizontally; whatever is left over is handled by the
 * scalar loop. There is no guard for len == 0 (the final division would be
 * 0 / 0).
 *
 * @param a   pointer to `len` floats
 * @param len number of elements to average over
 */
inline float compute_square_average_inlined_(const float* a, int len) {
  float total = 0.0f;
  int pos = 0;
#ifdef __AVX__
  constexpr int kLanes = 8;
  __m256 acc = _mm256_setzero_ps();
  while (pos + kLanes <= len) {
    const __m256 v = _mm256_loadu_ps(a + pos);
    acc = _mm256_add_ps(acc, _mm256_mul_ps(v, v));
    pos += kLanes;
  }
  // Two horizontal adds leave the total of each 128-bit half in its lane 0;
  // the two halves are then added together as scalars.
  const __m256 pairs = _mm256_hadd_ps(acc, acc);
  const __m256 quads = _mm256_hadd_ps(pairs, pairs);
  total = _mm_cvtss_f32(_mm256_castps256_ps128(quads)) +
      _mm_cvtss_f32(_mm256_extractf128_ps(quads, 1));
#endif
  // Scalar tail (or the whole input when AVX is unavailable).
  while (pos < len) {
    total = std::fma(a[pos], a[pos], total);
    ++pos;
  }
  return total / len;
}
/**
 * Returns the mean over i of (a[i] + weight_decay * w[i])^2, with float
 * weights.
 *
 * With AVX enabled, complete groups of 8 elements are processed in vector
 * registers (using a fused multiply-add when __FMA__ is defined); the rest go
 * through the scalar loop. There is no guard for len == 0.
 *
 * @param a            gradient values, `len` floats
 * @param w            weight values, `len` floats
 * @param len          number of elements
 * @param weight_decay scalar multiplied into each weight before adding it to
 *                     the gradient
 */
inline float compute_square_average_with_weight_decay_inlined_(
    const float* a,
    const float* w,
    int len,
    float weight_decay) {
  float total = 0.0f;
  int pos = 0;
#ifdef __AVX__
  constexpr int kLanes = 8;
  const __m256 decay_v = _mm256_set1_ps(weight_decay);
  __m256 acc = _mm256_setzero_ps();
  while (pos + kLanes <= len) {
    __m256 grad_v = _mm256_loadu_ps(a + pos);
    const __m256 weight_v = _mm256_loadu_ps(w + pos);
#ifdef __FMA__
    grad_v = _mm256_fmadd_ps(decay_v, weight_v, grad_v);
#else
    grad_v = _mm256_add_ps(_mm256_mul_ps(decay_v, weight_v), grad_v);
#endif
    acc = _mm256_add_ps(acc, _mm256_mul_ps(grad_v, grad_v));
    pos += kLanes;
  }
  // Horizontal reduction: after two hadd steps lane 0 of each 128-bit half
  // holds that half's total.
  const __m256 pairs = _mm256_hadd_ps(acc, acc);
  const __m256 quads = _mm256_hadd_ps(pairs, pairs);
  total = _mm_cvtss_f32(_mm256_castps256_ps128(quads)) +
      _mm_cvtss_f32(_mm256_extractf128_ps(quads, 1));
#endif
  // Scalar tail (or the whole input when AVX is unavailable).
  while (pos < len) {
    const float decayed = std::fma(weight_decay, w[pos], a[pos]);
    total = std::fma(decayed, decayed, total);
    ++pos;
  }
  return total / len;
}
/**
 * Returns the mean over i of (a[i] + weight_decay * w[i])^2, with at::Half
 * weights.
 *
 * With AVX enabled, complete groups of 8 elements are processed in vector
 * registers: the 8 half weights are fetched with one unaligned 128-bit load
 * and widened with _mm256_cvtph_ps. Remaining elements go through the scalar
 * loop. There is no guard for len == 0.
 *
 * @param a            gradient values, `len` floats
 * @param w            weight values, `len` at::Half entries
 * @param len          number of elements
 * @param weight_decay scalar multiplied into each weight before adding it to
 *                     the gradient
 */
inline float compute_square_average_with_weight_decay_inlined_(
    const float* a,
    const at::Half* w,
    int len,
    float weight_decay) {
  float total = 0.0f;
  int pos = 0;
#ifdef __AVX__
  constexpr int kLanes = 8;
  const __m256 decay_v = _mm256_set1_ps(weight_decay);
  __m256 acc = _mm256_setzero_ps();
  while (pos + kLanes <= len) {
    __m256 grad_v = _mm256_loadu_ps(a + pos);
    const __m128i packed_w =
        _mm_loadu_si128(reinterpret_cast<const __m128i*>(w + pos));
    const __m256 weight_v = _mm256_cvtph_ps(packed_w);
#ifdef __FMA__
    grad_v = _mm256_fmadd_ps(decay_v, weight_v, grad_v);
#else
    grad_v = _mm256_add_ps(_mm256_mul_ps(decay_v, weight_v), grad_v);
#endif
    acc = _mm256_add_ps(acc, _mm256_mul_ps(grad_v, grad_v));
    pos += kLanes;
  }
  // Horizontal reduction: after two hadd steps lane 0 of each 128-bit half
  // holds that half's total.
  const __m256 pairs = _mm256_hadd_ps(acc, acc);
  const __m256 quads = _mm256_hadd_ps(pairs, pairs);
  total = _mm_cvtss_f32(_mm256_castps256_ps128(quads)) +
      _mm_cvtss_f32(_mm256_extractf128_ps(quads, 1));
#endif
  // Scalar tail (or the whole input when AVX is unavailable).
  while (pos < len) {
    const float decayed = std::fma(weight_decay, w[pos], a[pos]);
    total = std::fma(decayed, decayed, total);
    ++pos;
  }
  return total / len;
}
} // namespace internal
/**
 * Fused operator of
 * SparseLengthsIndicesInGradientSumGradient (gradient of SparseLengthsSum) +
 * RowWiseSparseAdagrad.
 *
 * BW saving analysis for numSegments B, L_avg = avg(lengths), block_size D,
 * assuming T = float and SIndex = int64_t:
 * Before fusion, SparseLengthsIndicesInGradientSumGradient reads B*D*4 and
 * writes B*L_avg*D*4. RowWiseSparseAdagrad reads B*2*L_avg*D*4 and writes
 * B*L_avg*D*4. So, the total memory traffic is B*(1+4*L_avg)*D*4 .
 * After fusion, we read B*(1+L_avg)*D*4 and write B*L_avg*D*4 with total
 * memory traffic B*(1+2*L_avg)*D*4.
 * Assuming L_avg >> 1, the memory BW saving is about 2x .
 *
 * See https://fb.quip.com/ldG7A55Ur5wM for more details on BW saving analysis
 * and evaluation results.
 *
 * Inputs:  PARAM, MOMENT_1, INDICES, GRAD, LR, LENGTHS, COUNTER
 *          (COUNTER is only read when counter_halflife != -1).
 * Outputs: OUTPUT_PARAM, OUTPUT_MOMENT_1.
 *
 * When `is_mean` is true, each segment's gradient is divided by the segment
 * length (when that length is > 0) before the update is applied.
 *
 * NOTE(review): for block_size > 1 the kernel is handed only the output
 * buffers, so OUTPUT_PARAM / OUTPUT_MOMENT_1 presumably have to alias PARAM /
 * MOMENT_1. Unlike the weighted variants, this class does not enforce that -
 * confirm against the operator schema.
 */
template <
    typename Tdata, // embedding types
    typename T, // everything else
    typename TLengths,
    typename rowWiseAdagradT,
    bool is_mean = false>
class RowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp final
    : public Operator<CPUContext> {
 public:
  /// Reads the "epsilon" (default 1e-5), "weight_decay" (default 0) and
  /// "counter_halflife" (default -1) arguments. Throws if a "decay" argument
  /// other than 1.0 is supplied.
  RowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<CPUContext>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5)),
        weight_decay_(
            this->template GetSingleArgument<float>("weight_decay", 0.f)),
        counter_halflife_(
            this->template GetSingleArgument<int64_t>("counter_halflife", -1)) {
    VLOG(1) << "gradient optimization operator in use: "
            << " weight_decay_=" << weight_decay_
            << " counter_halflife=" << counter_halflife_
            << " RowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp bcyuan";
    const T decay = this->template GetSingleArgument<T>("decay", 1.0);
    CAFFE_ENFORCE_EQ(
        decay, 1.0, "Decay is not supported for SparseSimdAdagradOp");
  }

  /// Validates input shapes, then dispatches on the element type of INDICES
  /// (int32_t or int64_t).
  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  /// Gathers raw pointers from the inputs/outputs, optionally rescales the
  /// gradient for the mean variant, and runs compute().
  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));

    auto& segmentGradsInput = Input(GRAD);
    auto& lengthsInput = Input(LENGTHS);

    CAFFE_ENFORCE_EQ(lengthsInput.dim(), 1, "LENGTHS must be a vector");
    auto numSegments = lengthsInput.size(0);
    CAFFE_ENFORCE_GT(segmentGradsInput.dim(), 0);
    CAFFE_ENFORCE_EQ(numSegments, segmentGradsInput.size(0));
    const auto* lengths = lengthsInput.template data<TLengths>();

    auto n = Input(INDICES).numel();
    auto numParams = Input(PARAM).numel();

    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = segmentGradsInput.template data<T>();
    const auto* paramIn = Input(PARAM).template data<Tdata>();
    const auto* momentIn = Input(MOMENT_1).template data<T>();
    // The counter input is only touched when a half-life was configured.
    const auto* count = counter_halflife_ == -1
        ? nullptr
        : Input(COUNTER).template data<double>();

    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<Tdata>();
    auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();

    if (numSegments == 0) {
      return true;
    }

    auto block_size = segmentGradsInput.size_from_dim(1);

    // Enforce:
    // Input(embedding/momentum) == outputs(embedding/momentum)
    CAFFE_ENFORCE_EQ(
        Input(PARAM).numel() / block_size,
        Input(MOMENT_1).numel(),
        "Input Param size: ",
        Input(PARAM).numel(),
        " Block size: ",
        block_size,
        " Input Moment size: ",
        Input(MOMENT_1).numel());

    if (is_mean) {
      grad_buffer_.ResizeLike(Input(GRAD));
    }
    auto* grad_buffer_data =
        is_mean ? grad_buffer_.template mutable_data<T>() : nullptr;
    if (is_mean) {
      // Mean variant: divide every segment gradient by its length; segments
      // with a non-positive length are copied through unchanged.
      for (const auto rangeIndex : c10::irange(numSegments)) {
        const auto offsetI = rangeIndex * block_size;
        for (const auto tmpIndex : c10::irange(block_size)) {
          grad_buffer_data[offsetI + tmpIndex] = lengths[rangeIndex] > 0
              ? gradIn[offsetI + tmpIndex] / lengths[rangeIndex]
              : gradIn[offsetI + tmpIndex];
        }
      }
    }

    compute<SIndex>(
        block_size,
        indices,
        n,
        lengths,
        numSegments,
        is_mean ? grad_buffer_data : gradIn,
        paramIn,
        numParams,
        momentIn,
        count,
        paramOut,
        momentOut,
        epsilon_,
        lr[0],
        weight_decay_,
        counter_halflife_,
        kernel_);

    return true;
  }

  /// Core update loop. For every segment, applies the segment gradient to each
  /// row listed in `indices` for that segment.
  ///
  /// - block_size == 1: scalar update written directly to paramOut/momentOut.
  /// - block_size > 1: `kernel` is invoked on the row inside paramOut, along
  ///   with pointers to a row 16 positions ahead for prefetching.
  ///
  /// `freq` (counter_halflife / count[idx] when both are positive, else 1)
  /// scales the weight decay per row.
  ///
  /// Throws (via CAFFE_ENFORCE) if LENGTHS sums to more or less than `n`, if
  /// an index is negative, or if a row lies outside the `numParams` elements.
  template <typename SIndex, bool HAS_WEIGHT_DECAY>
  static void compute(
      int64_t block_size,
      const SIndex* indices,
      int64_t n,
      const TLengths* lengths,
      int64_t numSegments,
      const T* gradIn,
      const Tdata* paramIn,
      int64_t numParams,
      const T* momentIn,
      const double* count,
      Tdata* paramOut,
      T* momentOut,
      float epsilon,
      T lr,
      T weight_decay,
      T counter_halflife,
      rowWiseAdagradT& kernel) {
    int dataIndex = 0;
    for (const auto rangeIndex : c10::irange(numSegments)) {
      auto offsetI = rangeIndex * block_size;
      const float* g = gradIn + offsetI;

      // Without weight decay the squared-gradient average depends only on the
      // segment gradient, so it is computed once per segment.
      float g_sq_avg = 0;
      if (block_size > 1 && !HAS_WEIGHT_DECAY) {
        g_sq_avg = internal::compute_square_average_inlined_(g, block_size);
      }

      // Check the segment extent up front: the loop below reads
      // indices[dataIndex], which would run past the end of `indices` if
      // LENGTHS summed to more than n.
      CAFFE_ENFORCE_LE(
          static_cast<int64_t>(dataIndex) + lengths[rangeIndex],
          n,
          "Sum of LENGTHS exceeds the number of INDICES at segment:",
          rangeIndex);

      for (auto start = dataIndex; dataIndex < start + lengths[rangeIndex];
           ++dataIndex) {
        // A negative index converts to a huge unsigned value; for -1 the
        // expression block_size + idx * block_size wraps around to 0 and
        // would slip through the bounds check below, so reject it here.
        CAFFE_ENFORCE_GE(
            indices[dataIndex],
            0,
            "Negative index:",
            indices[dataIndex],
            " for input dataIndex:",
            dataIndex);
        std::size_t idx = indices[dataIndex];
        auto offsetIdx = idx * block_size;

        // Enforce:
        // access within range
        // gradient access within range
        CAFFE_ENFORCE_GE(
            numParams,
            block_size + offsetIdx,
            "Accessing params out of bound,  idx:",
            idx,
            " for input dataIndex:",
            dataIndex,
            " and block size:",
            block_size,
            " max size:",
            numParams);

        float freq = (counter_halflife > 0 && count[idx] > 0)
            ? counter_halflife / count[idx]
            : 1.0;

        if (block_size == 1) {
          float gi = std::fma(weight_decay * freq, paramIn[idx], *g);
          float hi = momentOut[idx] = momentIn[idx] + gi * gi;
          paramOut[idx] = paramIn[idx] + lr / (std::sqrt(hi) + epsilon) * gi;
        } else {
          // prefetching
          const int prefdist_T0 = 16;
          int i_pref = (dataIndex < n - prefdist_T0) ? dataIndex + prefdist_T0
                                                     : dataIndex;
          std::size_t idx_pref = indices[i_pref];

          // With weight decay the effective gradient depends on the current
          // row contents, so the average is recomputed for every row.
          if (HAS_WEIGHT_DECAY) {
            g_sq_avg =
                internal::compute_square_average_with_weight_decay_inlined_(
                    g, paramOut + offsetIdx, block_size, weight_decay * freq);
          }

          kernel(
              block_size,

              paramOut + offsetIdx,
              &paramOut[idx_pref * block_size],

              g,
              g_sq_avg,

              momentOut + idx,
              momentOut + idx_pref,

              epsilon,
              lr,
              HAS_WEIGHT_DECAY ? weight_decay * freq : 0.0f);
        }
      }
    }
    CAFFE_ENFORCE_EQ(dataIndex, n);
  }

  /// Runtime-to-compile-time dispatch on whether weight decay is active.
  /// With zero weight decay, `freq` has no effect (it only multiplies
  /// weight_decay), so counter_halflife is forced to -1.
  template <typename SIndex>
  static void compute(
      int64_t block_size,
      const SIndex* indices,
      int64_t n,
      const TLengths* lengths,
      int64_t numSegments,
      const T* gradIn,
      const Tdata* paramIn,
      int64_t numParams,
      const T* momentIn,
      const double* count,
      Tdata* paramOut,
      T* momentOut,
      float epsilon,
      T lr,
      T weight_decay,
      T counter_halflife,
      rowWiseAdagradT& kernel) {
    if (weight_decay == 0.0f) {
      compute<SIndex, false>(
          block_size,
          indices,
          n,
          lengths,
          numSegments,
          gradIn,
          paramIn,
          numParams,
          momentIn,
          count,
          paramOut,
          momentOut,
          epsilon,
          lr,
          0.0f,
          /*counter_halflife=*/-1,
          kernel);
    } else {
      compute<SIndex, true>(
          block_size,
          indices,
          n,
          lengths,
          numSegments,
          gradIn,
          paramIn,
          numParams,
          momentIn,
          count,
          paramOut,
          momentOut,
          epsilon,
          lr,
          weight_decay,
          counter_halflife,
          kernel);
    }
  }

 protected:
  T epsilon_;
  T weight_decay_;
  T counter_halflife_;
  rowWiseAdagradT kernel_;

  // Scratch buffer holding the length-normalized gradient (is_mean only).
  Tensor grad_buffer_{CPU};

  INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR, LENGTHS, COUNTER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
};
/**
 * Fused operator combining the gradient of a weighted sparse-lengths sum with
 * a row-wise sparse Adagrad update.
 *
 * Inputs:  PARAM, MOMENT_1, AUX_PARAM, INDICES, GRAD, LR, LENGTHS, COUNTER
 *          (COUNTER is only read when counter_halflife != -1).
 * Outputs: OUTPUT_PARAM, OUTPUT_MOMENT_1, AUX_GRAD.
 *
 * OUTPUT_PARAM and OUTPUT_MOMENT_1 must share storage with PARAM and
 * MOMENT_1 (enforced in DoRunWithType).
 *
 * The work is done in two passes: the first fills AUX_GRAD with the dot
 * product of each segment gradient and the not-yet-updated embedding row; the
 * second applies the Adagrad update.
 */
template <
    typename Tdata, // embedding types
    typename T, // everything else
    typename TLengths,
    typename rowWiseAdagradT>
class RowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientOp final
    : public Operator<CPUContext> {
 public:
  /// Reads the "epsilon" (default 1e-5), "weight_decay" (default 0) and
  /// "counter_halflife" (default -1) arguments.
  RowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<CPUContext>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5)),
        weight_decay_(
            this->template GetSingleArgument<float>("weight_decay", 0.f)),
        counter_halflife_(
            this->template GetSingleArgument<int64_t>("counter_halflife", -1)) {
    // The log line names this operator (it previously carried the name of the
    // unweighted Sum operator).
    VLOG(1)
        << "gradient optimization operator in use: "
        << " weight_decay_=" << weight_decay_
        << " counter_halflife=" << counter_halflife_
        << " RowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientOp bcyuan";
  }

  /// Validates input shapes, then dispatches on the element type of INDICES
  /// (int32_t or int64_t).
  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  /// Gathers raw pointers from the inputs/outputs, sizes AUX_GRAD to one entry
  /// per index, checks the in-place requirement, and runs compute().
  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));

    auto& segmentGradsInput = Input(GRAD);
    auto& lengthsInput = Input(LENGTHS);

    CAFFE_ENFORCE_EQ(lengthsInput.dim(), 1, "LENGTHS must be a vector");
    auto numSegments = lengthsInput.size(0);
    CAFFE_ENFORCE_GT(segmentGradsInput.dim(), 0);
    CAFFE_ENFORCE_EQ(numSegments, segmentGradsInput.size(0));
    const auto* lengths = lengthsInput.template data<TLengths>();

    auto n = Input(INDICES).numel();
    auto numParams = Input(PARAM).numel();

    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = segmentGradsInput.template data<T>();
    const auto* paramIn = Input(PARAM).template data<Tdata>();
    const auto* momentIn = Input(MOMENT_1).template data<T>();
    const auto* auxParamIn = Input(AUX_PARAM).template data<T>();
    // The counter input is only touched when a half-life was configured.
    const auto* count = counter_halflife_ == -1
        ? nullptr
        : Input(COUNTER).template data<double>();

    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<Tdata>();
    auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
    Output(AUX_GRAD)->Resize(n);
    auto* auxGrad = Output(AUX_GRAD)->template mutable_data<T>();

    CAFFE_ENFORCE_EQ(
        paramIn, paramOut, "RowWiseSparseAdagrad must use inplace param");
    CAFFE_ENFORCE_EQ(
        momentIn, momentOut, "RowWiseSparseAdagrad must use inplace momentum");

    if (numSegments == 0) {
      return true;
    }

    auto block_size = segmentGradsInput.size_from_dim(1);

    // Enforce:
    // Input(embedding/momentum) == outputs(embedding/momentum)
    CAFFE_ENFORCE_EQ(
        Input(PARAM).numel() / block_size,
        Input(MOMENT_1).numel(),
        "Input Param size: ",
        Input(PARAM).numel(),
        " Block size: ",
        block_size,
        " Input Moment size: ",
        Input(MOMENT_1).numel());

    compute<SIndex>(
        block_size,
        indices,
        n,
        lengths,
        numSegments,
        gradIn,
        paramIn,
        numParams,
        momentIn,
        count,
        auxParamIn,
        paramOut,
        momentOut,
        auxGrad,
        epsilon_,
        lr[0],
        weight_decay_,
        counter_halflife_,
        kernel_,
        &context_);

    return true;
  }

  /// Core two-pass implementation.
  ///
  /// Pass 1 validates every index and writes
  /// auxGrad[dataIndex] = dot(gradIn[segment], paramIn[row]).
  /// Pass 2 scales the segment gradient by the auxiliary weight and applies
  /// the Adagrad update (scalar path for block_size == 1, `kernel` otherwise).
  ///
  /// Throws (via CAFFE_ENFORCE) if LENGTHS sums to more or less than `n`, if
  /// an index is negative, or if a row lies outside the `numParams` elements.
  template <typename SIndex, bool HAS_WEIGHT_DECAY>
  static void compute(
      int64_t block_size,
      const SIndex* indices,
      int64_t n,
      const TLengths* lengths,
      int64_t numSegments,
      const T* gradIn,
      const Tdata* paramIn,
      int64_t numParams,
      const T* momentIn,
      const double* count,
      const T* auxParamIn,
      Tdata* paramOut,
      T* momentOut,
      T* auxGrad,
      float epsilon,
      T lr,
      T weight_decay,
      T counter_halflife,
      rowWiseAdagradT& kernel,
      CPUContext* context) {
    // Cannot fuse this loop with the loop below because paramIn is updated
    // by the second loop. Specifically, there could be dataIndex1 != dataIndex2
    // s.t. indices[dataIndex1] == indices[dataIndex2], and fusing these two
    // loops would violate dependencies w.r.t.
    // paramIn[indices[dataIndex1]:block_size].
    // The approximate version
    // (RowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientApproxOp)
    // ignores this dependency and fuses these two loops.
    std::vector<T> temp_grad(block_size);

    int dataIndex = 0;
    for (const auto rangeIndex : c10::irange(numSegments)) {
      // Check the segment extent up front: the loop below reads
      // indices[dataIndex], which would run past the end of `indices` if
      // LENGTHS summed to more than n.
      CAFFE_ENFORCE_LE(
          static_cast<int64_t>(dataIndex) + lengths[rangeIndex],
          n,
          "Sum of LENGTHS exceeds the number of INDICES at segment:",
          rangeIndex);

      auto offsetI = rangeIndex * block_size;
      for (auto start = dataIndex; dataIndex < start + lengths[rangeIndex];
           ++dataIndex) {
        // A negative index converts to a huge unsigned value; for -1 the
        // expression block_size + idx * block_size wraps around to 0 and
        // would slip through the bounds check below, so reject it here.
        CAFFE_ENFORCE_GE(
            indices[dataIndex],
            0,
            "Negative index:",
            indices[dataIndex],
            " for input dataIndex:",
            dataIndex);
        std::size_t idx = indices[dataIndex];
        auto offsetIdx = idx * block_size;

        // Enforce:
        // access within range
        // gradient access within range
        CAFFE_ENFORCE_GE(
            numParams,
            block_size + offsetIdx,
            "Accessing params out of bound,  idx:",
            idx,
            " for input dataIndex:",
            dataIndex,
            " and block size:",
            block_size,
            " max size:",
            numParams);

        // temp_aux_grad[dataIndex] = gradIn[offsetI] dot paramIn[offsetIdx]
        internal::dot<T, Tdata, T>(
            block_size,
            gradIn + offsetI,
            paramIn + offsetIdx,
            auxGrad + dataIndex,
            context);
      }
    }
    CAFFE_ENFORCE_EQ(dataIndex, n);

    // Second pass: every index was validated above, so no further bounds
    // checks are performed here.
    dataIndex = 0;
    for (const auto rangeIndex : c10::irange(numSegments)) {
      auto offsetI = rangeIndex * block_size;
      const float* g = gradIn + offsetI;

      // Initialized for consistency with the unweighted operator; it is
      // always assigned before use when block_size > 1.
      float g_sq_avg = 0;
      if (block_size > 1 && !HAS_WEIGHT_DECAY) {
        g_sq_avg = internal::compute_square_average_inlined_(g, block_size);
      }

      for (auto start = dataIndex; dataIndex < start + lengths[rangeIndex];
           ++dataIndex) {
        auto idx = indices[dataIndex];
        auto offsetIdx = idx * block_size;
        // NOTE(review): the auxiliary weight is looked up by the position
        // inside the segment (dataIndex - start), not by dataIndex - confirm
        // this matches the layout of AUX_PARAM.
        auto localOffset = dataIndex - start;

        for (const auto i : c10::irange(block_size)) {
          temp_grad[i] = auxParamIn[localOffset] * g[i];
        }

        // Per-row scaling of the weight decay.
        float freq = (counter_halflife > 0 && count[idx] > 0)
            ? counter_halflife / count[idx]
            : 1.0;

        if (block_size == 1) {
          float gi = std::fma(weight_decay * freq, paramIn[idx], temp_grad[0]);
          float hi = momentOut[idx] = momentIn[idx] + gi * gi;
          paramOut[idx] = paramIn[idx] + lr / (std::sqrt(hi) + epsilon) * gi;
        } else {
          // prefetching
          const int prefdist_T0 = 16;
          int i_pref = (dataIndex < n - prefdist_T0) ? dataIndex + prefdist_T0
                                                     : dataIndex;
          std::size_t idx_pref = indices[i_pref];

          if (HAS_WEIGHT_DECAY) {
            g_sq_avg =
                internal::compute_square_average_with_weight_decay_inlined_(
                    temp_grad.data(),
                    paramOut + offsetIdx,
                    block_size,
                    weight_decay * freq);
          }

          // Without weight decay g_sq_avg was computed from the unscaled
          // gradient, so it is multiplied by the squared auxiliary weight.
          kernel(
              block_size,

              paramOut + offsetIdx,
              &paramOut[idx_pref * block_size],

              temp_grad.data(),
              g_sq_avg *
                  (HAS_WEIGHT_DECAY
                       ? 1
                       : auxParamIn[localOffset] * auxParamIn[localOffset]),

              momentOut + idx,
              momentOut + idx_pref,

              epsilon,
              lr,
              HAS_WEIGHT_DECAY ? weight_decay * freq : 0.0f);
        }
      }
    }
  }

  /// Runtime-to-compile-time dispatch on whether weight decay is active.
  /// With zero weight decay, `freq` has no effect (it only multiplies
  /// weight_decay), so counter_halflife is forced to -1.
  template <typename SIndex>
  static void compute(
      int64_t block_size,
      const SIndex* indices,
      int64_t n,
      const TLengths* lengths,
      int64_t numSegments,
      const T* gradIn,
      const Tdata* paramIn,
      int64_t numParams,
      const T* momentIn,
      const double* count,
      const T* auxParamIn,
      Tdata* paramOut,
      T* momentOut,
      T* auxGrad,
      float epsilon,
      T lr,
      T weight_decay,
      T counter_halflife,
      rowWiseAdagradT& kernel,
      CPUContext* context) {
    if (weight_decay == 0.0f) {
      compute<SIndex, /*HAS_WEIGHT_DECAY=*/false>(
          block_size,
          indices,
          n,
          lengths,
          numSegments,
          gradIn,
          paramIn,
          numParams,
          momentIn,
          count,
          auxParamIn,
          paramOut,
          momentOut,
          auxGrad,
          epsilon,
          lr,
          0.0f,
          /*counter_halflife=*/-1,
          kernel,
          context);
    } else {
      compute<SIndex, /*HAS_WEIGHT_DECAY=*/true>(
          block_size,
          indices,
          n,
          lengths,
          numSegments,
          gradIn,
          paramIn,
          numParams,
          momentIn,
          count,
          auxParamIn,
          paramOut,
          momentOut,
          auxGrad,
          epsilon,
          lr,
          weight_decay,
          counter_halflife,
          kernel,
          context);
    }
  }

 protected:
  T epsilon_;
  T weight_decay_;
  T counter_halflife_;
  rowWiseAdagradT kernel_;

  INPUT_TAGS(PARAM, MOMENT_1, AUX_PARAM, INDICES, GRAD, LR, LENGTHS, COUNTER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, AUX_GRAD);
};
/**
 * Approximate, single-pass variant of the fused weighted sparse-lengths-sum
 * gradient + row-wise sparse Adagrad operator.
 *
 * Inputs:  PARAM, MOMENT_1, AUX_PARAM, INDICES, GRAD, LR, LENGTHS, COUNTER
 *          (COUNTER is only read when counter_halflife != -1).
 * Outputs: OUTPUT_PARAM, OUTPUT_MOMENT_1, AUX_GRAD.
 *
 * OUTPUT_PARAM and OUTPUT_MOMENT_1 must share storage with PARAM and
 * MOMENT_1 (enforced in DoRunWithType).
 *
 * For each index, AUX_GRAD (dot product of the segment gradient with the
 * embedding row) is computed and the row is updated in the same loop
 * iteration. A row that appears more than once therefore contributes its
 * already-updated values to later dot products.
 */
template <
    typename Tdata, // embedding types
    typename T, // everything else
    typename TLengths,
    typename rowWiseAdagradT>
class RowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientApproxOp
    final : public Operator<CPUContext> {
 public:
  /// Reads the "epsilon" (default 1e-5), "weight_decay" (default 0) and
  /// "counter_halflife" (default -1) arguments. Throws if a "decay" argument
  /// other than 1.0 is supplied.
  RowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientApproxOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<CPUContext>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5)),
        weight_decay_(
            this->template GetSingleArgument<float>("weight_decay", 0.f)),
        counter_halflife_(
            this->template GetSingleArgument<int64_t>("counter_halflife", -1)) {
    // The log line names this operator (it previously carried the name of the
    // unweighted Sum operator).
    VLOG(1)
        << "gradient optimization operator in use: "
        << " weight_decay_=" << weight_decay_
        << " counter_halflife=" << counter_halflife_
        << " RowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientApproxOp bcyuan";
    const T decay = this->template GetSingleArgument<T>("decay", 1.0);
    CAFFE_ENFORCE_EQ(
        decay, 1.0, "Decay is not supported for SparseSimdAdagradOp");
  }

  /// Validates input shapes, then dispatches on the element type of INDICES
  /// (int32_t or int64_t).
  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  /// Selects the weight-decay / no-weight-decay specialization at runtime.
  template <typename SIndex>
  bool DoRunWithType() {
    if (weight_decay_ == 0.0f) {
      return DoRunWithType<SIndex, false>();
    } else {
      return DoRunWithType<SIndex, true>();
    }
  }

  /// Single-pass implementation: per index, computes the AUX_GRAD dot product
  /// and the weight-scaled gradient, then applies the Adagrad update.
  ///
  /// Throws (via CAFFE_ENFORCE) if LENGTHS sums to more or less than the
  /// number of indices, if an index is negative, or if a row lies outside
  /// PARAM.
  template <typename SIndex, bool HAS_WEIGHT_DECAY>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));

    auto& segmentGradsInput = Input(GRAD);
    auto& lengthsInput = Input(LENGTHS);

    CAFFE_ENFORCE_EQ(lengthsInput.dim(), 1, "LENGTHS must be a vector");
    auto numSegments = lengthsInput.size(0);
    CAFFE_ENFORCE_GT(segmentGradsInput.dim(), 0);
    CAFFE_ENFORCE_EQ(numSegments, segmentGradsInput.size(0));
    const auto* lengths = lengthsInput.template data<TLengths>();

    auto n = Input(INDICES).numel();

    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = segmentGradsInput.template data<T>();
    const auto* paramIn = Input(PARAM).template data<Tdata>();
    const auto* momentIn = Input(MOMENT_1).template data<T>();
    // The counter input is only touched when a half-life was configured.
    const auto* count = counter_halflife_ == -1
        ? nullptr
        : Input(COUNTER).template data<double>();
    const auto* auxParamIn = Input(AUX_PARAM).template data<T>();

    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<Tdata>();
    auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
    Output(AUX_GRAD)->Resize(n);
    auto* auxGrad = Output(AUX_GRAD)->template mutable_data<T>();

    CAFFE_ENFORCE_EQ(
        paramIn, paramOut, "RowWiseSparseAdagrad must use inplace param");
    CAFFE_ENFORCE_EQ(
        momentIn, momentOut, "RowWiseSparseAdagrad must use inplace momentum");

    if (numSegments == 0) {
      return true;
    }

    auto block_size = segmentGradsInput.size_from_dim(1);

    // Enforce:
    // Input(embedding/momentum) == outputs(embedding/momentum)
    CAFFE_ENFORCE_EQ(
        Input(PARAM).numel() / block_size,
        Input(MOMENT_1).numel(),
        "Input Param size: ",
        Input(PARAM).numel(),
        " Block size: ",
        block_size,
        " Input Moment size: ",
        Input(MOMENT_1).numel());

    std::vector<T> temp_grad(block_size);

    int dataIndex = 0;
    for (const auto rangeIndex : c10::irange(numSegments)) {
      auto offsetI = rangeIndex * block_size;
      const float* g = gradIn + offsetI;

      // Initialized for consistency with the unweighted operator; it is
      // always assigned before use when block_size > 1.
      float g_sq_avg = 0;
      if (block_size > 1 && !HAS_WEIGHT_DECAY) {
        g_sq_avg = internal::compute_square_average_inlined_(g, block_size);
      }

      // Check the segment extent up front: the loop below reads
      // indices[dataIndex], which would run past the end of `indices` if
      // LENGTHS summed to more than n.
      CAFFE_ENFORCE_LE(
          static_cast<int64_t>(dataIndex) + lengths[rangeIndex],
          n,
          "Sum of LENGTHS exceeds the number of INDICES at segment:",
          rangeIndex);

      for (auto start = dataIndex; dataIndex < start + lengths[rangeIndex];
           ++dataIndex) {
        // A negative index converts to a huge unsigned value; for -1 the
        // expression block_size + idx * block_size wraps around to 0 and
        // would slip through the bounds check below, so reject it here.
        CAFFE_ENFORCE_GE(
            indices[dataIndex],
            0,
            "Negative index:",
            indices[dataIndex],
            " for input dataIndex:",
            dataIndex);
        std::size_t idx = indices[dataIndex];
        auto offsetIdx = idx * block_size;
        // NOTE(review): the auxiliary weight is looked up by the position
        // inside the segment (dataIndex - start), not by dataIndex - confirm
        // this matches the layout of AUX_PARAM.
        auto localOffset = dataIndex - start;

        // Enforce:
        // access within range
        // gradient access within range
        CAFFE_ENFORCE_GE(
            Input(PARAM).numel(),
            block_size + offsetIdx,
            this->debug_def().input(PARAM),
            ", out of bound,  idx:",
            idx,
            " for input dataIndex:",
            dataIndex,
            " and block size:",
            block_size,
            " max size:",
            Input(PARAM).numel());

        // In one sweep over the row: acc accumulates dot(g, param row) and
        // temp_grad receives g scaled by the auxiliary weight.
        int i = 0;
        float acc = 0.0f;

#ifdef __AVX__
        constexpr int VLEN = 8;
        __m256 acc_v = _mm256_setzero_ps();
        __m256 scalar_v = _mm256_set1_ps(auxParamIn[localOffset]);

        if (std::is_same<Tdata, float>::value) {
          for (; i < block_size / VLEN * VLEN; i += VLEN) {
            __m256 a_v = _mm256_loadu_ps(g + i);
            __m256 b_v = _mm256_loadu_ps(
                reinterpret_cast<const float*>(paramIn + offsetIdx + i));
            __m256 c_v = _mm256_mul_ps(a_v, b_v);
            acc_v = _mm256_add_ps(acc_v, c_v);
            _mm256_storeu_ps(&temp_grad[i], _mm256_mul_ps(a_v, scalar_v));
          }
        } else if (std::is_same<Tdata, at::Half>::value) {
          for (; i < block_size / VLEN * VLEN; i += VLEN) {
            __m256 a_v = _mm256_loadu_ps(g + i);
            // Unaligned load: a row starts at idx * block_size elements, which
            // is not guaranteed to fall on a 16-byte boundary.
            __m256 b_v = _mm256_cvtph_ps(_mm_loadu_si128(
                reinterpret_cast<const __m128i*>(paramIn + offsetIdx + i)));
            __m256 c_v = _mm256_mul_ps(a_v, b_v);
            acc_v = _mm256_add_ps(acc_v, c_v);
            _mm256_storeu_ps(&temp_grad[i], _mm256_mul_ps(a_v, scalar_v));
          }
        } else {
          CAFFE_THROW("Unsupported type for Embedding");
        }

        alignas(64) float temp[VLEN];
        _mm256_store_ps(temp, acc_v);
        for (const auto j : c10::irange(VLEN)) {
          acc += temp[j];
        }
#endif

        // Scalar tail (or the whole row when AVX is unavailable).
        for (; i < block_size; ++i) {
          float a = g[i];
          acc += a * paramIn[offsetIdx + i];
          temp_grad[i] = a * auxParamIn[localOffset];
        }
        auxGrad[dataIndex] = acc;

        // Per-row scaling of the weight decay.
        float freq = (counter_halflife_ > 0 && count[idx] > 0)
            ? counter_halflife_ / count[idx]
            : 1.0;

        if (block_size == 1) {
          float gi = std::fma(weight_decay_ * freq, paramIn[idx], temp_grad[0]);
          float hi = momentOut[idx] = momentIn[idx] + gi * gi;
          paramOut[idx] =
              paramIn[idx] + lr[0] / (std::sqrt(hi) + epsilon_) * gi;
        } else {
          // prefetching
          const int prefdist_T0 = 16;
          int i_pref = (dataIndex < n - prefdist_T0) ? dataIndex + prefdist_T0
                                                     : dataIndex;
          std::size_t idx_pref = indices[i_pref];

          if (HAS_WEIGHT_DECAY) {
            g_sq_avg =
                internal::compute_square_average_with_weight_decay_inlined_(
                    temp_grad.data(),
                    paramOut + offsetIdx,
                    block_size,
                    weight_decay_ * freq);
          }

          // Without weight decay g_sq_avg was computed from the unscaled
          // gradient, so it is multiplied by the squared auxiliary weight.
          kernel_(
              block_size,

              paramOut + offsetIdx,
              &paramOut[idx_pref * block_size],

              temp_grad.data(),
              g_sq_avg *
                  (HAS_WEIGHT_DECAY
                       ? 1
                       : auxParamIn[localOffset] * auxParamIn[localOffset]),

              momentOut + idx,
              momentOut + idx_pref,

              epsilon_,
              lr[0],
              HAS_WEIGHT_DECAY ? weight_decay_ * freq : 0.0f);
        }
      }
    }
    CAFFE_ENFORCE_EQ(dataIndex, n);

    return true;
  }

 protected:
  T epsilon_;
  T weight_decay_;
  T counter_halflife_;
  rowWiseAdagradT kernel_;

  INPUT_TAGS(PARAM, MOMENT_1, AUX_PARAM, INDICES, GRAD, LR, LENGTHS, COUNTER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, AUX_GRAD);
};
/**
 * Functor applying one row-wise Adagrad step to a single row, in place.
 *
 * The row's moment is updated as *h += g_sq_avg, a single step size
 * lr / (sqrt(*h) + epsilon) is derived from it, and every element is updated
 * as w[i] += (g[i] + weight_decay * w[i]) * step (the weight-decay term is
 * skipped when weight_decay == 0).
 *
 * @param N            number of elements in the row
 * @param w            row to update (read and written)
 * @param w_n          row used only as a prefetch target (AVX builds)
 * @param g            gradient for the row, N floats
 * @param g_sq_avg     value added to the row's moment
 * @param h            moment of the row (read and written)
 * @param h_n          moment used only as a prefetch target (AVX builds)
 * @param epsilon      added to sqrt(moment) in the step-size denominator
 * @param lr           learning rate; the update is added to w as-is
 * @param weight_decay weight-decay coefficient
 */
struct rowwise_adagrad_update_inlined {
  void operator()(
      int N,
      float* w,
      float* w_n, // prefetch ptr
      const float* g,
      float g_sq_avg,
      float* h,
      float* h_n, // prefetch ptr
      float epsilon,
      float lr,
      float weight_decay) {
#ifdef __AVX__
    constexpr int kLanes = 8;
    _mm_prefetch(reinterpret_cast<const char*>(h_n), _MM_HINT_T0);
#endif
    // Fold the new squared-gradient average into the moment first; the step
    // size is shared by every element of the row.
    *h = *h + g_sq_avg;
    const float moment = *h;
    const float step_size = lr / (std::sqrt(moment) + epsilon);
    const bool decay_enabled = weight_decay != 0.0f;

    int col = 0;
#ifdef __AVX__
    const __m256 step_v = _mm256_set1_ps(step_size);
    const __m256 decay_v = _mm256_set1_ps(weight_decay);
    while (col + kLanes <= N) {
      _mm_prefetch(reinterpret_cast<const char*>(&w_n[col]), _MM_HINT_T0);
      __m256 grad_v = _mm256_loadu_ps(g + col);
      const __m256 weight_v = _mm256_loadu_ps(w + col);
      if (decay_enabled) {
#ifdef __FMA__
        grad_v = _mm256_fmadd_ps(decay_v, weight_v, grad_v);
#else
        grad_v = _mm256_add_ps(_mm256_mul_ps(decay_v, weight_v), grad_v);
#endif
      }
      _mm256_storeu_ps(
          w + col, _mm256_add_ps(weight_v, _mm256_mul_ps(grad_v, step_v)));
      col += kLanes;
    }
#endif
    // Scalar tail (or the whole row when AVX is unavailable).
    while (col < N) {
      float grad = g[col];
      if (decay_enabled) {
        grad = std::fma(weight_decay, w[col], g[col]);
      }
      w[col] = w[col] + grad * step_size;
      ++col;
    }
  }
};
} // namespace caffe2