1029 lines
31 KiB
C++
1029 lines
31 KiB
C++
#pragma once
|
|
|
|
#include "caffe2/core/operator.h"
|
|
#include "caffe2/sgd/math_lp.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
namespace internal {
|
|
// Returns the mean of the squared entries of `a`, i.e. sum(a[i]^2) / len.
// Uses an AVX-vectorized partial sum when compiled with AVX support and
// finishes the remainder (or the entire range otherwise) with scalar fused
// multiply-adds. Caller must pass len > 0.
inline float compute_square_average_inlined_(const float* a, int len) {
  float acc = 0.0f;
  int idx = 0;

#ifdef __AVX__
  constexpr int kVecLen = 8;
  __m256 vec_acc = _mm256_setzero_ps();
  while (idx + kVecLen <= len) {
    const __m256 v = _mm256_loadu_ps(a + idx);
    vec_acc = _mm256_add_ps(vec_acc, _mm256_mul_ps(v, v));
    idx += kVecLen;
  }
  // Horizontal reduction of the 8 lanes down to a single float.
  const __m256 half_sums = _mm256_hadd_ps(vec_acc, vec_acc);
  const __m256 quarter_sums = _mm256_hadd_ps(half_sums, half_sums);
  acc = _mm_cvtss_f32(_mm256_castps256_ps128(quarter_sums)) +
      _mm_cvtss_f32(_mm256_extractf128_ps(quarter_sums, 1));
#endif

  // Scalar tail (everything when AVX is unavailable).
  while (idx < len) {
    acc = std::fma(a[idx], a[idx], acc);
    ++idx;
  }

  return acc / len;
}
|
|
|
|
// Returns the mean of the squared weight-decay-adjusted gradient, i.e.
// sum((a[i] + weight_decay * w[i])^2) / len, where `a` is the gradient row
// and `w` the corresponding fp32 embedding row. AVX path when available,
// scalar FMA tail otherwise. Caller must pass len > 0.
inline float compute_square_average_with_weight_decay_inlined_(
    const float* a,
    const float* w,
    int len,
    float weight_decay) {
  float acc = 0.0f;
  int idx = 0;

#ifdef __AVX__
  constexpr int kVecLen = 8;
  __m256 vec_acc = _mm256_setzero_ps();
  const __m256 decay_v = _mm256_set1_ps(weight_decay);
  while (idx + kVecLen <= len) {
    __m256 grad_v = _mm256_loadu_ps(a + idx);
    const __m256 weight_v = _mm256_loadu_ps(w + idx);
#ifdef __FMA__
    grad_v = _mm256_fmadd_ps(decay_v, weight_v, grad_v);
#else
    grad_v = _mm256_add_ps(_mm256_mul_ps(decay_v, weight_v), grad_v);
#endif
    vec_acc = _mm256_add_ps(vec_acc, _mm256_mul_ps(grad_v, grad_v));
    idx += kVecLen;
  }
  // Horizontal reduction of the 8 lanes down to a single float.
  const __m256 half_sums = _mm256_hadd_ps(vec_acc, vec_acc);
  const __m256 quarter_sums = _mm256_hadd_ps(half_sums, half_sums);
  acc = _mm_cvtss_f32(_mm256_castps256_ps128(quarter_sums)) +
      _mm_cvtss_f32(_mm256_extractf128_ps(quarter_sums, 1));
#endif

  // Scalar tail (everything when AVX is unavailable).
  while (idx < len) {
    const float adjusted = std::fma(weight_decay, w[idx], a[idx]);
    acc = std::fma(adjusted, adjusted, acc);
    ++idx;
  }

  return acc / len;
}
|
|
|
|
inline float compute_square_average_with_weight_decay_inlined_(
|
|
const float* a,
|
|
const at::Half* w,
|
|
int len,
|
|
float weight_decay) {
|
|
float sum = 0.0f;
|
|
|
|
int i = 0;
|
|
#ifdef __AVX__
|
|
constexpr int kSize = 8;
|
|
__m256 partial_sum = _mm256_setzero_ps();
|
|
__m256 weight_decay_v = _mm256_set1_ps(weight_decay);
|
|
for (; i + kSize <= len; i += kSize) {
|
|
__m256 ai = _mm256_loadu_ps(a + i);
|
|
__m128i whi = _mm_loadu_si128(reinterpret_cast<const __m128i*>(w + i));
|
|
__m256 wi = _mm256_cvtph_ps(whi);
|
|
#ifdef __FMA__
|
|
ai = _mm256_fmadd_ps(weight_decay_v, wi, ai);
|
|
#else
|
|
ai = _mm256_add_ps(_mm256_mul_ps(weight_decay_v, wi), ai);
|
|
#endif
|
|
partial_sum = _mm256_add_ps(partial_sum, _mm256_mul_ps(ai, ai));
|
|
}
|
|
// Reduce sum to 1 value
|
|
__m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
|
|
__m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
|
|
sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
|
|
_mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
|
|
#endif
|
|
|
|
for (; i < len; ++i) {
|
|
float ai = std::fma(weight_decay, w[i], a[i]);
|
|
sum = std::fma(ai, ai, sum);
|
|
}
|
|
|
|
return sum / len;
|
|
}
|
|
|
|
} // namespace internal
|
|
|
|
/**
|
|
* Fused operator of
|
|
* SparseLengthsIndicesInGradientSumGradient (gradient of SparseLengthsSum) +
|
|
* RowWiseSparseAdagrad.
|
|
*
|
|
* BW saving analysis for numSegments B, L_avg = avg(lengths), block_size D,
|
|
* assuming T = float and SIndex = int64_t:
|
|
* Before fusion, SparseLengthsIndicesInGradientSumGradient reads B*D*4 and
|
|
* writes B*L_avg*D*4. RowWiseSparseAdagrad reads B*2*L_avg*D*4 and writes
|
|
* B*L_avg*D*4. So, the total memory traffic is B*(1+4*L_avg)*D*4 .
|
|
* After fusion, we read B*(1+L_avg)*D*4 and write B*L_avg*D*4 with total
|
|
* memory traffic B*(1+2*L_avg)*D*4.
|
|
 * Assuming L_avg >> 1, the memory BW saving is about 2x.
|
|
*
|
|
* See https://fb.quip.com/ldG7A55Ur5wM for more details on BW saving analysis
|
|
* and evaluation results.
|
|
*/
|
|
template <
    typename Tdata, // embedding types
    typename T, // everything else
    typename TLengths,
    typename rowWiseAdagradT,
    bool is_mean = false>
class RowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp final
    : public Operator<CPUContext> {
 public:
  // Arguments:
  //   epsilon          - numerical stabilizer added to sqrt(moment) in the
  //                      Adagrad step (default 1e-5)
  //   weight_decay     - L2 coefficient applied to the embedding rows being
  //                      updated; 0 disables it (default)
  //   counter_halflife - when != -1 the COUNTER input is read and each row's
  //                      weight decay is rescaled by counter_halflife /
  //                      count[row]
  //   decay            - accepted for API compatibility, but only 1.0 (no
  //                      decay of the accumulated moment) is supported
  RowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<CPUContext>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5)),
        weight_decay_(
            this->template GetSingleArgument<float>("weight_decay", 0.f)),
        counter_halflife_(
            this->template GetSingleArgument<int64_t>("counter_halflife", -1)) {
    VLOG(1) << "gradient optimization operator in use: "
            << " weight_decay_=" << weight_decay_
            << " counter_halflife=" << counter_halflife_
            << " RowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp bcyuan";
    const T decay = this->template GetSingleArgument<T>("decay", 1.0);
    CAFFE_ENFORCE_EQ(
        decay, 1.0, "Decay is not supported for SparseSimdAdagradOp");
  }

  bool RunOnDevice() override {
    // Enforce shapes
    // Row-wise Adagrad keeps exactly one moment value per embedding row.
    CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));

    // Dispatch on the index dtype (int32 or int64).
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    // Outputs alias the inputs in shape; the kernels below update rows
    // through paramOut/momentOut.
    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));

    auto& segmentGradsInput = Input(GRAD);
    auto& lengthsInput = Input(LENGTHS);

    CAFFE_ENFORCE_EQ(lengthsInput.dim(), 1, "LENGTHS must be a vector");
    auto numSegments = lengthsInput.size(0);
    CAFFE_ENFORCE_GT(segmentGradsInput.dim(), 0);
    CAFFE_ENFORCE_EQ(numSegments, segmentGradsInput.size(0));
    const auto* lengths = lengthsInput.template data<TLengths>();

    // n = total number of gathered rows; sum(lengths) must equal n
    // (verified at the end of compute()).
    auto n = Input(INDICES).numel();
    auto numParams = Input(PARAM).numel();

    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = segmentGradsInput.template data<T>();
    const auto* paramIn = Input(PARAM).template data<Tdata>();
    const auto* momentIn = Input(MOMENT_1).template data<T>();
    // COUNTER is only consumed when counter_halflife_ was supplied.
    const auto* count = counter_halflife_ == -1
        ? nullptr
        : Input(COUNTER).template data<double>();

    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<Tdata>();
    auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();

    if (numSegments == 0) {
      return true;
    }

    // Number of scalars per embedding row.
    auto block_size = segmentGradsInput.size_from_dim(1);

    // Enforce:
    // Input(embedding/momentum) == outputs(embedding/momentum)
    CAFFE_ENFORCE_EQ(
        Input(PARAM).numel() / block_size,
        Input(MOMENT_1).numel(),
        "Input Param size: ",
        Input(PARAM).numel(),
        " Block size: ",
        block_size,
        " Input Moment size: ",
        Input(MOMENT_1).numel());

    // For the "mean" variant, pre-divide each segment's gradient by the
    // segment length (skipping empty segments to avoid division by zero)
    // so that compute() can treat it exactly like the sum variant.
    if (is_mean) {
      grad_buffer_.ResizeLike(Input(GRAD));
    }
    auto* grad_buffer_data =
        is_mean ? grad_buffer_.template mutable_data<T>() : NULL;
    if (is_mean) {
      for (const auto rangeIndex : c10::irange(numSegments)) {
        for (const auto tmpIndex : c10::irange(block_size)) {
          auto offsetI = rangeIndex * block_size;
          grad_buffer_data[offsetI + tmpIndex] = lengths[rangeIndex] > 0
              ? gradIn[offsetI + tmpIndex] / lengths[rangeIndex]
              : gradIn[offsetI + tmpIndex];
        }
      }
    }

    compute<SIndex>(
        block_size,
        indices,
        n,
        lengths,
        numSegments,
        is_mean ? grad_buffer_data : gradIn,
        paramIn,
        numParams,
        momentIn,
        count,
        paramOut,
        momentOut,
        epsilon_,
        lr[0],
        weight_decay_,
        counter_halflife_,
        kernel_);

    return true;
  }

  // Core update loop. HAS_WEIGHT_DECAY is a compile-time flag so the
  // weight-decay math is removed entirely from the no-decay instantiation.
  // For each segment, each referenced row idx gets a row-wise Adagrad step
  // driven by the mean squared gradient of that segment.
  template <typename SIndex, bool HAS_WEIGHT_DECAY>
  static void compute(
      int64_t block_size,
      const SIndex* indices,
      int64_t n,
      const TLengths* lengths,
      int64_t numSegments,
      const T* gradIn,
      const Tdata* paramIn,
      int64_t numParams,
      const T* momentIn,
      const double* count,
      Tdata* paramOut,
      T* momentOut,
      float epsilon,
      T lr,
      T weight_decay,
      T counter_halflife,
      rowWiseAdagradT& kernel) {
    int dataIndex = 0;
    for (const auto rangeIndex : c10::irange(numSegments)) {
      auto offsetI = rangeIndex * block_size;
      const float* g = gradIn + offsetI;

      float g_sq_avg = 0;
      // Without weight decay the mean squared gradient is the same for every
      // row in the segment, so compute it once here. With weight decay it
      // depends on the row being updated and is computed inside the loop.
      if (block_size > 1 && !HAS_WEIGHT_DECAY) {
        g_sq_avg = internal::compute_square_average_inlined_(g, block_size);
      }

      for (auto start = dataIndex; dataIndex < start + lengths[rangeIndex];
           ++dataIndex) {
        std::size_t idx = indices[dataIndex];
        auto offsetIdx = idx * block_size;

        // Enforce:
        // access within range
        // gradient access within range
        CAFFE_ENFORCE_GE(
            numParams,
            block_size + offsetIdx,
            "Accessing params out of bound, idx:",
            idx,
            " for input dataIndex:",
            dataIndex,
            " and block size:",
            block_size,
            " max size:",
            numParams);

        // Frequency-adaptive scaling of the decay: rows seen often
        // (large count) get proportionally less weight decay. count is
        // only dereferenced when counter_halflife > 0, i.e. when the
        // COUNTER input was provided.
        float freq = (counter_halflife > 0 && count[idx] > 0)
            ? counter_halflife / count[idx]
            : 1.0;

        if (block_size == 1) {
          // Scalar fast path: plain (row-wise == element-wise) Adagrad.
          float gi = std::fma(weight_decay * freq, paramIn[idx], *g);
          float hi = momentOut[idx] = momentIn[idx] + gi * gi;
          paramOut[idx] = paramIn[idx] + lr / (std::sqrt(hi) + epsilon) * gi;
        } else {
          // prefetching
          const int prefdist_T0 = 16;
          // Prefetch the row 16 iterations ahead (clamped at the end).
          int i_pref = (dataIndex < n - prefdist_T0) ? dataIndex + prefdist_T0
                                                     : dataIndex;
          std::size_t idx_pref = indices[i_pref];

          if (HAS_WEIGHT_DECAY) {
            // Decay-adjusted mean squared gradient against the *current*
            // (possibly already updated this pass) row values.
            g_sq_avg =
                internal::compute_square_average_with_weight_decay_inlined_(
                    g, paramOut + offsetIdx, block_size, weight_decay * freq);
          }

          // Moment is row-wise (one value per row), hence indexed by idx
          // rather than offsetIdx.
          kernel(
              block_size,

              paramOut + offsetIdx,
              &paramOut[idx_pref * block_size],

              g,
              g_sq_avg,

              momentOut + idx,
              momentOut + idx_pref,

              epsilon,
              lr,
              HAS_WEIGHT_DECAY ? weight_decay * freq : 0.0f);
        }
      }
    }
    // Every gathered row must have been consumed exactly once.
    CAFFE_ENFORCE_EQ(dataIndex, n);
  }

  // Runtime dispatcher: picks the HAS_WEIGHT_DECAY instantiation based on
  // whether weight_decay is non-zero.
  template <typename SIndex>
  static void compute(
      int64_t block_size,
      const SIndex* indices,
      int64_t n,
      const TLengths* lengths,
      int64_t numSegments,
      const T* gradIn,
      const Tdata* paramIn,
      int64_t numParams,
      const T* momentIn,
      const double* count,
      Tdata* paramOut,
      T* momentOut,
      float epsilon,
      T lr,
      T weight_decay,
      T counter_halflife,
      rowWiseAdagradT& kernel) {
    if (weight_decay == 0.0f) {
      compute<SIndex, false>(
          block_size,
          indices,
          n,
          lengths,
          numSegments,
          gradIn,
          paramIn,
          numParams,
          momentIn,
          count,
          paramOut,
          momentOut,
          epsilon,
          lr,
          0.0f,
          /*counter_halflife=*/-1,
          kernel);
    } else {
      compute<SIndex, true>(
          block_size,
          indices,
          n,
          lengths,
          numSegments,
          gradIn,
          paramIn,
          numParams,
          momentIn,
          count,
          paramOut,
          momentOut,
          epsilon,
          lr,
          weight_decay,
          counter_halflife,
          kernel);
    }
  }

 protected:
  T epsilon_;
  T weight_decay_;
  T counter_halflife_;
  rowWiseAdagradT kernel_;
  // Scratch buffer holding length-normalized gradients for the mean variant.
  Tensor grad_buffer_{CPU};

  INPUT_TAGS(PARAM, MOMENT_1, INDICES, GRAD, LR, LENGTHS, COUNTER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1);
};
|
|
|
|
template <
    typename Tdata, // embedding types
    typename T, // everything else
    typename TLengths,
    typename rowWiseAdagradT>
class RowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientOp final
    : public Operator<CPUContext> {
 public:
  // Weighted variant: each gathered row carries a per-position weight
  // (AUX_PARAM). Besides the fused Adagrad update this op also produces the
  // gradient w.r.t. those weights (AUX_GRAD).
  //
  // Arguments:
  //   epsilon          - numerical stabilizer added to sqrt(moment)
  //   weight_decay     - L2 coefficient on the embedding rows; 0 disables it
  //   counter_halflife - when != -1 the COUNTER input is read and each row's
  //                      weight decay is rescaled by counter_halflife /
  //                      count[row]
  RowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<CPUContext>(operator_def, ws),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5)),
        weight_decay_(
            this->template GetSingleArgument<float>("weight_decay", 0.f)),
        counter_halflife_(
            this->template GetSingleArgument<int64_t>("counter_halflife", -1)) {
    // NOTE(review): the log tag below names the Sum variant — appears to be
    // copied over; harmless since it is only a VLOG string.
    VLOG(1) << "gradient optimization operator in use: "
            << " weight_decay_=" << weight_decay_
            << " counter_halflife=" << counter_halflife_
            << " RowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp bcyuan";
  }

  bool RunOnDevice() override {
    // Enforce shapes
    // Row-wise Adagrad keeps exactly one moment value per embedding row.
    CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));

    // Dispatch on the index dtype (int32 or int64).
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));

    auto& segmentGradsInput = Input(GRAD);
    auto& lengthsInput = Input(LENGTHS);

    CAFFE_ENFORCE_EQ(lengthsInput.dim(), 1, "LENGTHS must be a vector");
    auto numSegments = lengthsInput.size(0);
    CAFFE_ENFORCE_GT(segmentGradsInput.dim(), 0);
    CAFFE_ENFORCE_EQ(numSegments, segmentGradsInput.size(0));
    const auto* lengths = lengthsInput.template data<TLengths>();

    // n = total number of gathered rows; sum(lengths) must equal n.
    auto n = Input(INDICES).numel();
    auto numParams = Input(PARAM).numel();

    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = segmentGradsInput.template data<T>();
    const auto* paramIn = Input(PARAM).template data<Tdata>();
    const auto* momentIn = Input(MOMENT_1).template data<T>();
    // Per-position scalar weights of the weighted pooling.
    const auto* auxParamIn = Input(AUX_PARAM).template data<T>();
    // COUNTER is only consumed when counter_halflife_ was supplied.
    const auto* count = counter_halflife_ == -1
        ? nullptr
        : Input(COUNTER).template data<double>();

    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<Tdata>();
    auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
    // One weight-gradient scalar per gathered row.
    Output(AUX_GRAD)->Resize(n);
    auto* auxGrad = Output(AUX_GRAD)->template mutable_data<T>();

    // The update is performed in place; enforce that the output tensors
    // share storage with the inputs.
    CAFFE_ENFORCE_EQ(
        paramIn, paramOut, "RowWiseSparseAdagrad must use inplace param");
    CAFFE_ENFORCE_EQ(
        momentIn, momentOut, "RowWiseSparseAdagrad must use inplace momentum");

    if (numSegments == 0) {
      return true;
    }

    // Number of scalars per embedding row.
    auto block_size = segmentGradsInput.size_from_dim(1);

    // Enforce:
    // Input(embedding/momentum) == outputs(embedding/momentum)
    CAFFE_ENFORCE_EQ(
        Input(PARAM).numel() / block_size,
        Input(MOMENT_1).numel(),
        "Input Param size: ",
        Input(PARAM).numel(),
        " Block size: ",
        block_size,
        " Input Moment size: ",
        Input(MOMENT_1).numel());

    compute<SIndex>(
        block_size,
        indices,
        n,
        lengths,
        numSegments,
        gradIn,
        paramIn,
        numParams,
        momentIn,
        count,
        auxParamIn,
        paramOut,
        momentOut,
        auxGrad,
        epsilon_,
        lr[0],
        weight_decay_,
        counter_halflife_,
        kernel_,
        &context_);

    return true;
  }

  // Core update. First pass computes the weight gradients (dot products of
  // segment gradient and the *pre-update* rows); second pass applies the
  // row-wise Adagrad step with the weighted gradients.
  template <typename SIndex, bool HAS_WEIGHT_DECAY>
  static void compute(
      int64_t block_size,
      const SIndex* indices,
      int64_t n,
      const TLengths* lengths,
      int64_t numSegments,
      const T* gradIn,
      const Tdata* paramIn,
      int64_t numParams,
      const T* momentIn,
      const double* count,
      const T* auxParamIn,
      Tdata* paramOut,
      T* momentOut,
      T* auxGrad,
      float epsilon,
      T lr,
      T weight_decay,
      T counter_halflife,
      rowWiseAdagradT& kernel,
      CPUContext* context) {
    // Cannot fuse this loop with the loop below because paramIn is updated
    // by the second loop. Specifically, there could be dataIndex1 != dataIndex2
    // s.t. indices[dataIndex1] == indices[dataIndex2], and fusing these two
    // loops would violate dependencies w.r.t.
    // paramIn[indices[dataIndex1]:block_size] The approximate version.
    // (RowWiseSparseSimdAdagradFusedWithSparseLengthsWeightedSumGradientApproxOp)
    // ignores this dependency and fuses these two loops.
    std::vector<T> temp_grad(block_size);
    int dataIndex = 0;
    for (const auto rangeIndex : c10::irange(numSegments)) {
      for (auto start = dataIndex; dataIndex < start + lengths[rangeIndex];
           ++dataIndex) {
        std::size_t idx = indices[dataIndex];
        auto offsetI = rangeIndex * block_size;
        auto offsetIdx = idx * block_size;

        // Enforce:
        // access within range
        // gradient access within range
        CAFFE_ENFORCE_GE(
            numParams,
            block_size + offsetIdx,
            "Accessing params out of bound, idx:",
            idx,
            " for input dataIndex:",
            dataIndex,
            " and block size:",
            block_size,
            " max size:",
            numParams);

        // temp_aux_grad[dataIndex] = gradIn[offsetI] dot paramIn[offsetIdx]
        internal::dot<T, Tdata, T>(
            block_size,
            gradIn + offsetI,
            paramIn + offsetIdx,
            auxGrad + dataIndex,
            context);
      }
    }
    CAFFE_ENFORCE_EQ(dataIndex, n);

    dataIndex = 0;
    for (const auto rangeIndex : c10::irange(numSegments)) {
      auto offsetI = rangeIndex * block_size;
      const float* g = gradIn + offsetI;

      // Assigned before every use: here when decay is off, inside the loop
      // when decay is on (it then depends on the row being updated).
      float g_sq_avg;
      if (block_size > 1 && !HAS_WEIGHT_DECAY) {
        g_sq_avg = internal::compute_square_average_inlined_(g, block_size);
      }

      for (auto start = dataIndex; dataIndex < start + lengths[rangeIndex];
           ++dataIndex) {
        auto idx = indices[dataIndex];
        auto offsetIdx = idx * block_size;
        auto localOffset = dataIndex - start;

        // Scale the segment gradient by this position's pooling weight.
        for (const auto i : c10::irange(block_size)) {
          temp_grad[i] = auxParamIn[localOffset] * g[i];
        }

        // Frequency-adaptive decay scaling; count is only dereferenced when
        // counter_halflife > 0, i.e. the COUNTER input was provided.
        float freq = (counter_halflife > 0 && count[idx] > 0)
            ? counter_halflife / count[idx]
            : 1.0;

        if (block_size == 1) {
          // Scalar fast path.
          float gi = std::fma(weight_decay * freq, paramIn[idx], temp_grad[0]);
          float hi = momentOut[idx] = momentIn[idx] + gi * gi;
          paramOut[idx] = paramIn[idx] + lr / (std::sqrt(hi) + epsilon) * gi;
        } else {
          // prefetching
          const int prefdist_T0 = 16;
          // Prefetch the row 16 iterations ahead (clamped at the end).
          int i_pref = (dataIndex < n - prefdist_T0) ? dataIndex + prefdist_T0
                                                     : dataIndex;
          std::size_t idx_pref = indices[i_pref];

          if (HAS_WEIGHT_DECAY) {
            g_sq_avg =
                internal::compute_square_average_with_weight_decay_inlined_(
                    temp_grad.data(),
                    paramOut + offsetIdx,
                    block_size,
                    weight_decay * freq);
          }

          // Without decay, g_sq_avg was computed from the *unweighted*
          // segment gradient, so rescale by the squared pooling weight.
          // Moment is row-wise (one value per row), hence indexed by idx.
          kernel(
              block_size,

              paramOut + offsetIdx,
              &paramOut[idx_pref * block_size],

              temp_grad.data(),
              g_sq_avg *
                  (HAS_WEIGHT_DECAY
                       ? 1
                       : auxParamIn[localOffset] * auxParamIn[localOffset]),

              momentOut + idx,
              momentOut + idx_pref,

              epsilon,
              lr,
              HAS_WEIGHT_DECAY ? weight_decay * freq : 0.0f);
        }
      }
    }
  }

  // Runtime dispatcher: picks the HAS_WEIGHT_DECAY instantiation based on
  // whether weight_decay is non-zero.
  template <typename SIndex>
  static void compute(
      int64_t block_size,
      const SIndex* indices,
      int64_t n,
      const TLengths* lengths,
      int64_t numSegments,
      const T* gradIn,
      const Tdata* paramIn,
      int64_t numParams,
      const T* momentIn,
      const double* count,
      const T* auxParamIn,
      Tdata* paramOut,
      T* momentOut,
      T* auxGrad,
      float epsilon,
      T lr,
      T weight_decay,
      T counter_halflife,
      rowWiseAdagradT& kernel,
      CPUContext* context) {
    if (weight_decay == 0.0f) {
      compute<SIndex, /*HAS_WEIGHT_DECAY=*/false>(
          block_size,
          indices,
          n,
          lengths,
          numSegments,
          gradIn,
          paramIn,
          numParams,
          momentIn,
          count,
          auxParamIn,
          paramOut,
          momentOut,
          auxGrad,
          epsilon,
          lr,
          0.0f,
          /*counter_halflife=*/-1,
          kernel,
          context);
    } else {
      compute<SIndex, /*HAS_WEIGHT_DECAY=*/true>(
          block_size,
          indices,
          n,
          lengths,
          numSegments,
          gradIn,
          paramIn,
          numParams,
          momentIn,
          count,
          auxParamIn,
          paramOut,
          momentOut,
          auxGrad,
          epsilon,
          lr,
          weight_decay,
          counter_halflife,
          kernel,
          context);
    }
  }

 protected:
  T epsilon_;
  T weight_decay_;
  T counter_halflife_;
  rowWiseAdagradT kernel_;

  INPUT_TAGS(PARAM, MOMENT_1, AUX_PARAM, INDICES, GRAD, LR, LENGTHS, COUNTER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, AUX_GRAD);
};
|
|
|
|
template <
|
|
typename Tdata, // embedding types
|
|
typename T, // everything else
|
|
typename TLengths,
|
|
typename rowWiseAdagradT>
|
|
class RowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientApproxOp
|
|
final : public Operator<CPUContext> {
|
|
public:
|
|
RowWiseSparseAdagradFusedWithSparseLengthsWeightedSumGradientApproxOp(
|
|
const OperatorDef& operator_def,
|
|
Workspace* ws)
|
|
: Operator<CPUContext>(operator_def, ws),
|
|
epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5)),
|
|
weight_decay_(
|
|
this->template GetSingleArgument<float>("weight_decay", 0.f)),
|
|
counter_halflife_(
|
|
this->template GetSingleArgument<int64_t>("counter_halflife", -1)) {
|
|
VLOG(1) << "gradient optimization operator in use: "
|
|
<< " weight_decay_=" << weight_decay_
|
|
<< " counter_halflife=" << counter_halflife_
|
|
<< " RowWiseSparseAdagradFusedWithSparseLengthsSumGradientOp bcyuan";
|
|
const T decay = this->template GetSingleArgument<T>("decay", 1.0);
|
|
CAFFE_ENFORCE_EQ(
|
|
decay, 1.0, "Decay is not supported for SparseSimdAdagradOp");
|
|
}
|
|
|
|
bool RunOnDevice() override {
|
|
// Enforce shapes
|
|
CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_1).numel());
|
|
CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
|
|
CAFFE_ENFORCE_EQ(
|
|
Input(PARAM).size_from_dim(1),
|
|
Input(GRAD).size_from_dim(Input(INDICES).dim()));
|
|
|
|
return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
|
|
this, Input(INDICES));
|
|
}
|
|
|
|
template <typename SIndex>
|
|
bool DoRunWithType() {
|
|
if (weight_decay_ == 0.0f) {
|
|
return DoRunWithType<SIndex, false>();
|
|
} else {
|
|
return DoRunWithType<SIndex, true>();
|
|
}
|
|
}
|
|
|
|
template <typename SIndex, bool HAS_WEIGHT_DECAY>
|
|
bool DoRunWithType() {
|
|
const auto* lr = Input(LR).template data<T>();
|
|
Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
|
|
Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));
|
|
|
|
auto& segmentGradsInput = Input(GRAD);
|
|
auto& lengthsInput = Input(LENGTHS);
|
|
|
|
CAFFE_ENFORCE_EQ(lengthsInput.dim(), 1, "LENGTHS must be a vector");
|
|
auto numSegments = lengthsInput.size(0);
|
|
CAFFE_ENFORCE_GT(segmentGradsInput.dim(), 0);
|
|
CAFFE_ENFORCE_EQ(numSegments, segmentGradsInput.size(0));
|
|
const auto* lengths = lengthsInput.template data<TLengths>();
|
|
|
|
auto n = Input(INDICES).numel();
|
|
|
|
const auto* indices = Input(INDICES).template data<SIndex>();
|
|
const auto* gradIn = segmentGradsInput.template data<T>();
|
|
const auto* paramIn = Input(PARAM).template data<Tdata>();
|
|
const auto* momentIn = Input(MOMENT_1).template data<T>();
|
|
const auto* count = counter_halflife_ == -1
|
|
? nullptr
|
|
: Input(COUNTER).template data<double>();
|
|
const auto* auxParamIn = Input(AUX_PARAM).template data<T>();
|
|
|
|
auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<Tdata>();
|
|
auto* momentOut = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
|
|
Output(AUX_GRAD)->Resize(n);
|
|
auto* auxGrad = Output(AUX_GRAD)->template mutable_data<T>();
|
|
|
|
CAFFE_ENFORCE_EQ(
|
|
paramIn, paramOut, "RowWiseSparseAdagrad must use inplace param");
|
|
CAFFE_ENFORCE_EQ(
|
|
momentIn, momentOut, "RowWiseSparseAdagrad must use inplace momentum");
|
|
|
|
if (numSegments == 0) {
|
|
return true;
|
|
}
|
|
|
|
auto block_size = segmentGradsInput.size_from_dim(1);
|
|
|
|
// Enforce:
|
|
// Input(embedding/momentum) == outputs(embedding/momentum)
|
|
CAFFE_ENFORCE_EQ(
|
|
Input(PARAM).numel() / block_size,
|
|
Input(MOMENT_1).numel(),
|
|
"Input Param size: ",
|
|
Input(PARAM).numel(),
|
|
" Block size: ",
|
|
block_size,
|
|
" Input Moment size: ",
|
|
Input(MOMENT_1).numel());
|
|
|
|
std::vector<T> temp_grad(block_size);
|
|
int dataIndex = 0;
|
|
for (const auto rangeIndex : c10::irange(numSegments)) {
|
|
auto offsetI = rangeIndex * block_size;
|
|
const float* g = gradIn + offsetI;
|
|
|
|
float g_sq_avg;
|
|
if (block_size > 1 && !HAS_WEIGHT_DECAY) {
|
|
g_sq_avg = internal::compute_square_average_inlined_(g, block_size);
|
|
}
|
|
|
|
for (auto start = dataIndex; dataIndex < start + lengths[rangeIndex];
|
|
++dataIndex) {
|
|
std::size_t idx = indices[dataIndex];
|
|
auto offsetIdx = idx * block_size;
|
|
auto localOffset = dataIndex - start;
|
|
|
|
// Enforce:
|
|
// access within range
|
|
// gradient access within range
|
|
CAFFE_ENFORCE_GE(
|
|
Input(PARAM).numel(),
|
|
block_size + offsetIdx,
|
|
this->debug_def().input(PARAM),
|
|
", out of bound, idx:",
|
|
idx,
|
|
" for input dataIndex:",
|
|
dataIndex,
|
|
" and block size:",
|
|
block_size,
|
|
" max size:",
|
|
Input(PARAM).numel());
|
|
|
|
int i = 0;
|
|
float acc = 0.0f;
|
|
|
|
#ifdef __AVX__
|
|
constexpr int VLEN = 8;
|
|
__m256 acc_v = _mm256_setzero_ps();
|
|
__m256 scalar_v = _mm256_set1_ps(auxParamIn[localOffset]);
|
|
|
|
if (std::is_same<Tdata, float>::value) {
|
|
for (; i < block_size / VLEN * VLEN; i += VLEN) {
|
|
__m256 a_v = _mm256_loadu_ps(g + i);
|
|
__m256 b_v = _mm256_loadu_ps(
|
|
reinterpret_cast<const float*>(paramIn + offsetIdx + i));
|
|
__m256 c_v = _mm256_mul_ps(a_v, b_v);
|
|
acc_v = _mm256_add_ps(acc_v, c_v);
|
|
_mm256_storeu_ps(&temp_grad[i], _mm256_mul_ps(a_v, scalar_v));
|
|
}
|
|
} else if (std::is_same<Tdata, at::Half>::value) {
|
|
for (; i < block_size / VLEN * VLEN; i += VLEN) {
|
|
__m256 a_v = _mm256_loadu_ps(g + i);
|
|
__m256 b_v = _mm256_cvtph_ps(
|
|
_mm_load_si128((__m128i*)(paramIn + offsetIdx + i)));
|
|
__m256 c_v = _mm256_mul_ps(a_v, b_v);
|
|
acc_v = _mm256_add_ps(acc_v, c_v);
|
|
_mm256_storeu_ps(&temp_grad[i], _mm256_mul_ps(a_v, scalar_v));
|
|
}
|
|
} else {
|
|
CAFFE_THROW("Unsupported type for Embedding");
|
|
}
|
|
|
|
alignas(64) float temp[VLEN];
|
|
_mm256_store_ps(temp, acc_v);
|
|
for (const auto j : c10::irange(VLEN)) {
|
|
acc += temp[j];
|
|
}
|
|
#endif
|
|
|
|
for (; i < block_size; ++i) {
|
|
float a = g[i];
|
|
acc += a * paramIn[offsetIdx + i];
|
|
temp_grad[i] = a * auxParamIn[localOffset];
|
|
}
|
|
auxGrad[dataIndex] = acc;
|
|
|
|
float freq = (counter_halflife_ > 0 && count[idx] > 0)
|
|
? counter_halflife_ / count[idx]
|
|
: 1.0;
|
|
|
|
if (block_size == 1) {
|
|
float gi = std::fma(weight_decay_ * freq, paramIn[idx], temp_grad[0]);
|
|
float hi = momentOut[idx] = momentIn[idx] + gi * gi;
|
|
paramOut[idx] =
|
|
paramIn[idx] + lr[0] / (std::sqrt(hi) + epsilon_) * gi;
|
|
} else {
|
|
// prefetching
|
|
const int prefdist_T0 = 16;
|
|
int i_pref = (dataIndex < n - prefdist_T0) ? dataIndex + prefdist_T0
|
|
: dataIndex;
|
|
std::size_t idx_pref = indices[i_pref];
|
|
|
|
if (HAS_WEIGHT_DECAY) {
|
|
g_sq_avg =
|
|
internal::compute_square_average_with_weight_decay_inlined_(
|
|
temp_grad.data(),
|
|
paramOut + offsetIdx,
|
|
block_size,
|
|
weight_decay_ * freq);
|
|
}
|
|
|
|
kernel_(
|
|
block_size,
|
|
|
|
paramOut + offsetIdx,
|
|
¶mOut[idx_pref * block_size],
|
|
|
|
temp_grad.data(),
|
|
g_sq_avg *
|
|
(HAS_WEIGHT_DECAY
|
|
? 1
|
|
: auxParamIn[localOffset] * auxParamIn[localOffset]),
|
|
|
|
momentOut + idx,
|
|
momentOut + idx_pref,
|
|
|
|
epsilon_,
|
|
lr[0],
|
|
HAS_WEIGHT_DECAY ? weight_decay_ * freq : 0.0f);
|
|
}
|
|
}
|
|
}
|
|
CAFFE_ENFORCE_EQ(dataIndex, n);
|
|
|
|
return true;
|
|
}
|
|
|
|
protected:
|
|
T epsilon_;
|
|
T weight_decay_;
|
|
T counter_halflife_;
|
|
rowWiseAdagradT kernel_;
|
|
|
|
INPUT_TAGS(PARAM, MOMENT_1, AUX_PARAM, INDICES, GRAD, LR, LENGTHS, COUNTER);
|
|
OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, AUX_GRAD);
|
|
};
|
|
|
|
// Row-wise Adagrad step over one embedding row of N floats:
//   h    += g_sq_avg                         (single moment per row)
//   step  = lr / (sqrt(h) + epsilon)
//   w[i] += step * (g[i] + weight_decay * w[i])
// w_n / h_n point at a row expected to be processed soon and are only used
// as prefetch hints on the AVX path.
struct rowwise_adagrad_update_inlined {
  void operator()(
      int N,
      float* w,
      float* w_n, // prefetch ptr
      const float* g,
      float g_sq_avg,
      float* h,
      float* h_n, // prefetch ptr
      float epsilon,
      float lr,
      float weight_decay) {
#ifdef __AVX__
    constexpr int kVecLen = 8;
    _mm_prefetch(reinterpret_cast<const char*>(h_n), _MM_HINT_T0);
#endif
    // Bump the row's single moment value and derive the step size shared
    // by every element of the row.
    const float new_moment = *h + g_sq_avg;
    *h = new_moment;
    const float scalar_step = lr / (std::sqrt(new_moment) + epsilon);

    int idx = 0;

#ifdef __AVX__
    const __m256 step_v = _mm256_set1_ps(scalar_step);
    const __m256 decay_v = _mm256_set1_ps(weight_decay);

    for (idx = 0; idx + kVecLen <= N; idx += kVecLen) {
      _mm_prefetch(reinterpret_cast<const char*>(&w_n[idx]), _MM_HINT_T0);

      __m256 grad_v = _mm256_loadu_ps(g + idx);
      const __m256 weight_v = _mm256_loadu_ps(w + idx);
      if (weight_decay != 0.0f) {
#ifdef __FMA__
        grad_v = _mm256_fmadd_ps(decay_v, weight_v, grad_v);
#else
        grad_v = _mm256_add_ps(_mm256_mul_ps(decay_v, weight_v), grad_v);
#endif
      }

      _mm256_storeu_ps(
          w + idx, _mm256_add_ps(weight_v, _mm256_mul_ps(grad_v, step_v)));
    }
#endif

    // Scalar tail (the whole row when AVX is unavailable).
    for (; idx < N; ++idx) {
      const float adjusted_grad = weight_decay != 0.0f
          ? std::fma(weight_decay, w[idx], g[idx])
          : g[idx];
      w[idx] = w[idx] + adjusted_grad * scalar_step;
    }
  }
};
|
|
|
|
} // namespace caffe2
|