#ifndef CAFFE2_OPERATORS_LAYER_NORM_OP_H_
#define CAFFE2_OPERATORS_LAYER_NORM_OP_H_

#include <array>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"

C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(LayerNorm)

namespace caffe2 {
template <class Context>
|
|
class LayerNormOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
|
|
template <class... Args>
|
|
explicit LayerNormOp(Args&&... args)
|
|
: Operator<Context>(std::forward<Args>(args)...),
|
|
OP_SINGLE_ARG(int, "axis", axis_, 1),
|
|
OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f),
|
|
OP_SINGLE_ARG(bool, "elementwise_affine", elementwise_affine_, false) {}
|
|
|
|
bool RunOnDevice() override {
|
|
return DispatchHelper<TensorTypes<float, double>>::call(this, Input(0));
|
|
}
|
|
|
|
template <typename T>
|
|
bool DoRunWithType() {
|
|
const auto& X = Input(0);
|
|
auto* Y = Output(0);
|
|
CAFFE_ENFORCE_GE(X.dim(), 2, "LayerNorm requires input dim >= 2.");
|
|
const int canonical_axis = X.canonical_axis_index(axis_);
|
|
std::vector<int64_t> moments_dims(
|
|
X.sizes().cbegin(), X.sizes().cbegin() + canonical_axis);
|
|
moments_dims.push_back(1);
|
|
auto* mean = Output(1, moments_dims, at::dtype<T>());
|
|
auto* sigma = Output(2, moments_dims, at::dtype<T>());
|
|
const int M = X.size_to_dim(canonical_axis);
|
|
const int N = X.size_from_dim(canonical_axis);
|
|
Y->ResizeLike(X);
|
|
scale_.Resize(M);
|
|
bias_.Resize(M);
|
|
const T* X_data = X.template data<T>();
|
|
T* Y_data = Y->template mutable_data<T>();
|
|
T* mean_data = mean->template mutable_data<T>();
|
|
T* sigma_data = sigma->template mutable_data<T>();
|
|
T* scale_data = scale_.template mutable_data<T>();
|
|
T* bias_data = bias_.template mutable_data<T>();
|
|
|
|
if (M == 0) {
|
|
return true;
|
|
}
|
|
|
|
const std::array<int, 2> X_dims = {M, N};
|
|
const std::array<int, 2> Y_dims = {M, 1};
|
|
math::Moments<T, Context>(
|
|
2,
|
|
X_dims.data(),
|
|
Y_dims.data(),
|
|
X_data,
|
|
mean_data,
|
|
sigma_data,
|
|
&context_);
|
|
ComputeSigmaAndFusedParams<T>(
|
|
M, epsilon_, mean_data, sigma_data, sigma_data, scale_data, bias_data);
|
|
const T* gamma_data = nullptr;
|
|
const T* beta_data = nullptr;
|
|
if (elementwise_affine_) {
|
|
CAFFE_ENFORCE_EQ(InputSize(), 3);
|
|
const auto& gamma = Input(1);
|
|
const auto& beta = Input(2);
|
|
CAFFE_ENFORCE_EQ(gamma.numel(), N);
|
|
CAFFE_ENFORCE_EQ(beta.numel(), N);
|
|
gamma_data = gamma.template data<T>();
|
|
beta_data = beta.template data<T>();
|
|
}
|
|
LayerNormForward<T>(
|
|
M, N, X_data, scale_data, bias_data, gamma_data, beta_data, Y_data);
|
|
return true;
|
|
}
|
|
|
|
private:
|
|
template <typename T>
|
|
void ComputeSigmaAndFusedParams(
|
|
const int N,
|
|
const float eps,
|
|
const T* mean,
|
|
const T* var,
|
|
T* stddev,
|
|
T* scale,
|
|
T* bias);
|
|
|
|
template <typename T>
|
|
void LayerNormForward(
|
|
const int M,
|
|
const int N,
|
|
const T* X,
|
|
const T* scale,
|
|
const T* bias,
|
|
const T* gamma,
|
|
const T* beta,
|
|
T* Y);
|
|
|
|
const int axis_;
|
|
const float epsilon_;
|
|
const bool elementwise_affine_;
|
|
|
|
Tensor scale_{Context::GetDeviceType()};
|
|
Tensor bias_{Context::GetDeviceType()};
|
|
};
template <class Context>
|
|
class LayerNormGradientOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
template <class... Args>
|
|
explicit LayerNormGradientOp(Args&&... args)
|
|
: Operator<Context>(std::forward<Args>(args)...),
|
|
OP_SINGLE_ARG(int, "axis", axis_, 1),
|
|
OP_SINGLE_ARG(bool, "elementwise_affine", elementwise_affine_, false) {}
|
|
|
|
bool RunOnDevice() override {
|
|
return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
|
|
}
|
|
|
|
template <typename T>
|
|
bool DoRunWithType() {
|
|
const auto& dY = Input(0);
|
|
const auto& mean = Input(2);
|
|
const auto& sigma = Input(3);
|
|
const auto& X = Input(4);
|
|
|
|
const int canonical_axis = X.canonical_axis_index(axis_);
|
|
const int M = X.size_to_dim(canonical_axis);
|
|
const int N = X.size_from_dim(canonical_axis);
|
|
|
|
auto* dX = Output(0, X.sizes(), at::dtype<T>());
|
|
ReinitializeTensor(
|
|
&ds_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
|
|
ReinitializeTensor(
|
|
&db_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
|
|
ReinitializeTensor(
|
|
&rstd_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
|
|
ReinitializeTensor(
|
|
&X_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
|
|
ReinitializeTensor(
|
|
&bias_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
|
|
const T* dY_data = dY.template data<T>();
|
|
const T* X_data = X.template data<T>();
|
|
const T* mean_data = mean.template data<T>();
|
|
const T* sigma_data = sigma.template data<T>();
|
|
T* dX_data = dX->template mutable_data<T>();
|
|
T* ds_data = ds_.template mutable_data<T>();
|
|
T* db_data = db_.template mutable_data<T>();
|
|
T* rstd_data = rstd_.template mutable_data<T>();
|
|
T* X_scale_data = X_scale_.template mutable_data<T>();
|
|
T* bias_data = bias_.template mutable_data<T>();
|
|
|
|
const T* gamma_data = nullptr;
|
|
T* dgamma_data = nullptr;
|
|
T* dbeta_data = nullptr;
|
|
T* g_scale_data = nullptr;
|
|
if (elementwise_affine_) {
|
|
const auto& gamma = Input(5);
|
|
auto* dgamma = Output(1, gamma.sizes(), at::dtype<T>());
|
|
auto* dbeta = Output(2, gamma.sizes(), at::dtype<T>());
|
|
ReinitializeTensor(
|
|
&g_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
|
|
gamma_data = gamma.template data<T>();
|
|
dgamma_data = dgamma->template mutable_data<T>();
|
|
dbeta_data = dbeta->template mutable_data<T>();
|
|
g_scale_data = g_scale_.template mutable_data<T>();
|
|
}
|
|
|
|
if (M == 0) {
|
|
if (N > 0 && dgamma_data != nullptr) {
|
|
math::Set<T, Context>(N, T(0), dgamma_data, &context_);
|
|
}
|
|
if (N > 0 && dbeta_data != nullptr) {
|
|
math::Set<T, Context>(N, T(0), dbeta_data, &context_);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
ComputeInternalGradients<T>(
|
|
M, N, dY_data, X_data, gamma_data, dX_data, ds_data, db_data);
|
|
ComputeFusedParams<T>(
|
|
M,
|
|
N,
|
|
mean_data,
|
|
sigma_data,
|
|
ds_data,
|
|
db_data,
|
|
rstd_data,
|
|
X_scale_data,
|
|
bias_data,
|
|
g_scale_data);
|
|
if (elementwise_affine_) {
|
|
GammaBetaBackward<T>(
|
|
M,
|
|
N,
|
|
dX_data,
|
|
dY_data,
|
|
rstd_data,
|
|
g_scale_data,
|
|
dgamma_data,
|
|
dbeta_data);
|
|
}
|
|
LayerNormBackward<T>(
|
|
M,
|
|
N,
|
|
dY_data,
|
|
X_data,
|
|
gamma_data,
|
|
rstd_data,
|
|
X_scale_data,
|
|
bias_data,
|
|
dX_data);
|
|
|
|
return true;
|
|
}
|
|
|
|
private:
|
|
template <typename T>
|
|
void ComputeInternalGradients(
|
|
const int M,
|
|
const int N,
|
|
const T* dY,
|
|
const T* X,
|
|
const T* gamma,
|
|
T* dYxX,
|
|
T* ds,
|
|
T* db);
|
|
|
|
template <typename T>
|
|
void ComputeFusedParams(
|
|
const int M,
|
|
const int N,
|
|
const T* mean,
|
|
const T* sigma,
|
|
const T* ds,
|
|
const T* db,
|
|
T* rstd,
|
|
T* X_scale,
|
|
T* bias,
|
|
T* g_scale);
|
|
|
|
template <typename T>
|
|
void LayerNormBackward(
|
|
const int M,
|
|
const int N,
|
|
const T* dY,
|
|
const T* X,
|
|
const T* gamma,
|
|
const T* dY_scale,
|
|
const T* X_scale,
|
|
const T* bias,
|
|
T* dX);
|
|
|
|
template <typename T>
|
|
void GammaBetaBackward(
|
|
const int M,
|
|
const int N,
|
|
const T* dYxX,
|
|
const T* dY,
|
|
const T* rstd,
|
|
const T* g_scale,
|
|
T* dgamma,
|
|
T* dbeta);
|
|
|
|
const int axis_;
|
|
const bool elementwise_affine_;
|
|
|
|
Tensor ds_;
|
|
Tensor db_;
|
|
Tensor rstd_;
|
|
Tensor X_scale_;
|
|
Tensor bias_;
|
|
Tensor g_scale_;
|
|
Tensor ones_;
|
|
};
} // namespace caffe2

#endif // CAFFE2_OPERATORS_LAYER_NORM_OP_H_