// FTRL optimizer operators (dense and sparse) for Caffe2.
#include "ftrl_op.h"

#include <cmath>
|
|
|
|
namespace caffe2 {
|
|
|
|
// Sign of x: -1 for negative, +1 for positive, 0 for (exact) zero.
template <class T>
inline T sgn(const T x) {
  if (x == 0) {
    return 0;
  }
  return x < 0 ? -1 : 1;
}
|
|
|
|
template <typename T>
|
|
inline void ftrl_compute(
|
|
const T w,
|
|
const T n,
|
|
const T z,
|
|
const T g,
|
|
T& nw,
|
|
T& nn,
|
|
T& nz,
|
|
const FtrlParams<T>& params) {
|
|
auto new_n = n + g * g;
|
|
auto sigma = (sqrt(new_n) - sqrt(n)) * params.alphaInv;
|
|
nn = new_n;
|
|
nz = z + g - sigma * w;
|
|
// update the weight
|
|
if (std::abs(nz) > params.lambda1) {
|
|
nw = (params.lambda1 * sgn(nz) - nz) /
|
|
((params.beta + sqrt(new_n)) * params.alphaInv + params.lambda2);
|
|
} else {
|
|
nw = 0.0;
|
|
}
|
|
}
|
|
|
|
// TODO(dzhulgakov): implement SIMD-based version
|
|
template <typename Context, typename T>
|
|
void ftrl_update(
|
|
int N,
|
|
const T* w,
|
|
const T* nz,
|
|
const T* g,
|
|
T* new_w,
|
|
T* new_nz,
|
|
const FtrlParams<T>& params,
|
|
Context* /*context*/) {
|
|
// TODO(cxj): use OMP when it is reliable
|
|
// #pragma omp parallel for
|
|
for (auto i = 0; i < N; ++i) {
|
|
ftrl_compute(
|
|
w[i],
|
|
nz[i * 2],
|
|
nz[i * 2 + 1],
|
|
g[i],
|
|
new_w[i],
|
|
new_nz[i * 2],
|
|
new_nz[i * 2 + 1],
|
|
params);
|
|
}
|
|
}
|
|
|
|
template <typename T, typename Context>
|
|
bool FtrlOp<T, Context>::RunOnDevice() {
|
|
// run time learning rate override
|
|
if (ALPHA < InputSize()) {
|
|
CAFFE_ENFORCE_EQ(Input(ALPHA).numel(), 1, "alpha should be real-valued");
|
|
params_.alphaInv = 1.0 / *(Input(ALPHA).template data<T>());
|
|
}
|
|
CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(VAR).numel());
|
|
CAFFE_ENFORCE_EQ(Input(GRAD).numel() * 2, Input(N_Z).numel());
|
|
Output(OUTPUT_VAR)->ResizeLike(Input(VAR));
|
|
Output(OUTPUT_N_Z)->ResizeLike(Input(N_Z));
|
|
ftrl_update<Context>(
|
|
Input(GRAD).numel(),
|
|
Input(VAR).template data<T>(),
|
|
Input(N_Z).template data<T>(),
|
|
Input(GRAD).template data<T>(),
|
|
Output(OUTPUT_VAR)->template mutable_data<T>(),
|
|
Output(OUTPUT_N_Z)->template mutable_data<T>(),
|
|
params_,
|
|
&context_);
|
|
return true;
|
|
}
|
|
|
|
// Sparse FTRL step over the rows selected by INDICES.  Requires the op
// to run strictly in place (VAR must alias OUTPUT_VAR and N_Z must alias
// OUTPUT_N_Z), so the selected rows are updated directly in the blobs.
template <typename T>
template <typename SIndex>
void SparseFtrlOp<T>::DoRun() {
  auto* var = Output(OUTPUT_VAR);
  auto* n_z = Output(OUTPUT_N_Z);
  auto& indices = Input(INDICES);
  auto& grad = Input(GRAD);
  CAFFE_ENFORCE_EQ(&Input(VAR), var, "In place operation is required");
  CAFFE_ENFORCE_EQ(&Input(N_Z), n_z, "In place operation is required");
  int64_t M = var->numel(); // total number of parameters
  int64_t N = var->size(0); // number of rows
  int64_t block_size = M / N; // parameters per row
  int64_t K = indices.numel(); // number of rows to update
  // n_z interleaves (n, z) per parameter, so it is twice the size of var.
  TORCH_DCHECK_EQ(M * 2, n_z->numel());
  TORCH_DCHECK_EQ(grad.numel(), K * block_size);
  T* w = var->template mutable_data<T>();
  T* nz = n_z->template mutable_data<T>();
  const SIndex* idxs = indices.template data<SIndex>();
  const T* g = grad.template data<T>();

  // TODO(cxj): use OMP when it is reliable
  // #pragma omp parallel for
  for (int64_t i = 0; i < K; ++i) {
    SIndex idx = idxs[i];
    DCHECK(0 <= idx && idx < N) << "Index out of bounds: " << idx
                                << ", range 0 to " << N;
    if (block_size == 1) {
      // Scalar row: pass the same slots as both inputs and outputs.
      // Safe because ftrl_compute takes all its inputs by value.
      ftrl_compute(
          w[idx],
          nz[idx * 2],
          nz[idx * 2 + 1],
          g[i],
          w[idx],
          nz[idx * 2],
          nz[idx * 2 + 1],
          params_);
    } else {
      // Whole-row update; input and output pointers alias (in place).
      int64_t x = block_size * idx;
      ftrl_update(
          block_size,
          w + x,
          nz + x * 2,
          g + i * block_size,
          w + x,
          nz + x * 2,
          params_,
          &context_);
    }
  }
}
|
|
|
|
namespace {
// Dense FTRL: 3 required inputs plus one optional (presumably the
// run-time ALPHA override — see FtrlOp::RunOnDevice); the two outputs
// may be computed in place with inputs 0 and 1.
REGISTER_CPU_OPERATOR(Ftrl, FtrlOp<float, CPUContext>);
OPERATOR_SCHEMA(Ftrl).NumInputs(3, 4).NumOutputs(2).AllowInplace({{0, 0},
                                                                  {1, 1}});
// FTRL is itself an optimizer step; it has no gradient.
SHOULD_NOT_DO_GRADIENT(Ftrl);

// Sparse FTRL: updates only indexed rows; in-place execution on inputs
// 0 and 1 is mandatory (also enforced inside SparseFtrlOp<T>::DoRun).
REGISTER_CPU_OPERATOR(SparseFtrl, SparseFtrlOp<float>);
OPERATOR_SCHEMA(SparseFtrl)
    .NumInputs(4, 5)
    .NumOutputs(2)
    .EnforceInplace({{0, 0}, {1, 1}});
SHOULD_NOT_DO_GRADIENT(SparseFtrl);
}
|
|
|
|
}
|