pytorch/caffe2/sgd/ftrl_op.cc

#include "ftrl_op.h"
namespace caffe2 {
template <class T>
inline T sgn(const T x) {
  return (x == 0 ? 0 : (x < 0 ? -1 : 1));
}
template <typename T>
inline void ftrl_compute(
    const T w,
    const T n,
    const T z,
    const T g,
    T& nw,
    T& nn,
    T& nz,
    const FtrlParams<T>& params) {
  auto new_n = n + g * g;
  auto sigma = (sqrt(new_n) - sqrt(n)) * params.alphaInv;
  nn = new_n;
  nz = z + g - sigma * w;
  // update the weight
  if (std::abs(nz) > params.lambda1) {
    nw = (params.lambda1 * sgn(nz) - nz) /
        ((params.beta + sqrt(new_n)) * params.alphaInv + params.lambda2);
  } else {
    nw = 0.0;
  }
}

// TODO(dzhulgakov): implement SIMD-based version
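// Applies ftrl_compute to N independent weights. The accumulator buffers
// store the (n, z) pairs interleaved per weight: nz[2 * i] holds n_i and
// nz[2 * i + 1] holds z_i; new_nz uses the same layout.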
template <typename Context, typename T>
void ftrl_update(
    int N,
    const T* w,
    const T* nz,
    const T* g,
    T* new_w,
    T* new_nz,
    const FtrlParams<T>& params,
    Context* /*context*/) {
  // TODO(cxj): use OMP when it is reliable
  // #pragma omp parallel for
  for (auto i = 0; i < N; ++i) {
    ftrl_compute(
        w[i],
        nz[i * 2],
        nz[i * 2 + 1],
        g[i],
        new_w[i],
        new_nz[i * 2],
        new_nz[i * 2 + 1],
        params);
  }
}

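// Dense FTRL step over the whole parameter blob. An optional ALPHA input
// (a single scalar) overrides params_.alphaInv at run time; N_Z must hold
// twice as many elements as VAR and GRAD.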
template <typename T, typename Context>
bool FtrlOp<T, Context>::RunOnDevice() {
  // run time learning rate override
  if (ALPHA < InputSize()) {
    CAFFE_ENFORCE_EQ(Input(ALPHA).numel(), 1, "alpha should be real-valued");
    params_.alphaInv = 1.0 / *(Input(ALPHA).template data<T>());
  }
  CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(VAR).numel());
  CAFFE_ENFORCE_EQ(Input(GRAD).numel() * 2, Input(N_Z).numel());
  Output(OUTPUT_VAR)->ResizeLike(Input(VAR));
  Output(OUTPUT_N_Z)->ResizeLike(Input(N_Z));
  ftrl_update<Context>(
      Input(GRAD).numel(),
      Input(VAR).template data<T>(),
      Input(N_Z).template data<T>(),
      Input(GRAD).template data<T>(),
      Output(OUTPUT_VAR)->template mutable_data<T>(),
      Output(OUTPUT_N_Z)->template mutable_data<T>(),
      params_,
      &context_);
  return true;
}

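// Sparse FTRL step: only the rows of VAR listed in INDICES are updated, and
// the op is required to run in place on VAR and N_Z. Each index selects a
// contiguous block of block_size = VAR.numel() / VAR.size(0) weights.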
template <typename T>
template <typename SIndex>
void SparseFtrlOp<T>::DoRun() {
  auto* var = Output(OUTPUT_VAR);
  auto* n_z = Output(OUTPUT_N_Z);
  auto& indices = Input(INDICES);
  auto& grad = Input(GRAD);
  CAFFE_ENFORCE_EQ(&Input(VAR), var, "In place operation is required");
  CAFFE_ENFORCE_EQ(&Input(N_Z), n_z, "In place operation is required");
  int64_t M = var->numel();
  int64_t N = var->size(0);
  int64_t block_size = M / N;
  int64_t K = indices.numel();
  TORCH_DCHECK_EQ(M * 2, n_z->numel());
  TORCH_DCHECK_EQ(grad.numel(), K * block_size);
  T* w = var->template mutable_data<T>();
  T* nz = n_z->template mutable_data<T>();
  const SIndex* idxs = indices.template data<SIndex>();
  const T* g = grad.template data<T>();
  // TODO(cxj): use OMP when it is reliable
  // #pragma omp parallel for
  for (int64_t i = 0; i < K; ++i) {
    SIndex idx = idxs[i];
    DCHECK(0 <= idx && idx < N)
        << "Index out of bounds: " << idx << ", range 0 to " << N;
    if (block_size == 1) {
      ftrl_compute(
          w[idx],
          nz[idx * 2],
          nz[idx * 2 + 1],
          g[i],
          w[idx],
          nz[idx * 2],
          nz[idx * 2 + 1],
          params_);
    } else {
      int64_t x = block_size * idx;
      ftrl_update(
          block_size,
          w + x,
          nz + x * 2,
          g + i * block_size,
          w + x,
          nz + x * 2,
          params_,
          &context_);
    }
  }
}

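// CPU registrations. Ftrl takes three required inputs (VAR, N_Z, GRAD) plus
// one optional input and may run in place on VAR and N_Z; SparseFtrl
// additionally takes INDICES and enforces the in-place update.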
namespace {

REGISTER_CPU_OPERATOR(Ftrl, FtrlOp<float, CPUContext>);
OPERATOR_SCHEMA(Ftrl).NumInputs(3, 4).NumOutputs(2).AllowInplace(
    {{0, 0}, {1, 1}});
SHOULD_NOT_DO_GRADIENT(Ftrl);

REGISTER_CPU_OPERATOR(SparseFtrl, SparseFtrlOp<float>);
OPERATOR_SCHEMA(SparseFtrl)
    .NumInputs(4, 5)
    .NumOutputs(2)
    .EnforceInplace({{0, 0}, {1, 1}});
SHOULD_NOT_DO_GRADIENT(SparseFtrl);

} // namespace

} // namespace caffe2