// Listing metadata: 284 lines, 13 KiB, C++
#ifndef CAFFE2_SGD_LEARNING_RATE_OP_H_
|
|
#define CAFFE2_SGD_LEARNING_RATE_OP_H_
|
|
|
|
#include <cfloat>
|
|
#include <cmath>
|
|
#include "caffe2/core/context.h"
|
|
#include "caffe2/core/export_caffe2_op_to_c10.h"
|
|
#include <c10/util/irange.h>
|
|
#include "caffe2/core/operator.h"
|
|
#include "caffe2/sgd/learning_rate_functors.h"
|
|
|
|
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(LearningRate);
|
|
|
|
namespace caffe2 {
|
|
|
|
template <typename T, class Context>
|
|
class LearningRateOp final : public Operator<Context> {
|
|
public:
|
|
template <class... Args>
|
|
LearningRateOp(Args&&... args)
|
|
: Operator<Context>(std::forward<Args>(args)...),
|
|
functor_(nullptr),
|
|
base_lr_(this->template GetSingleArgument<float>("base_lr", FLT_MAX)) {
|
|
CAFFE_ENFORCE_NE(base_lr_, FLT_MAX, "Base learning rate must be set.");
|
|
const string policy =
|
|
this->template GetSingleArgument<string>("policy", "");
|
|
CAFFE_ENFORCE(policy.size(), "Must specify a learning rate policy.");
|
|
functor_.reset(createLearningRateFunctor(policy));
|
|
}
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
|
|
bool RunOnDevice() override {
|
|
int64_t iter =
|
|
OperatorBase::Input<Tensor>(0, CPU).template data<int64_t>()[0];
|
|
T learning_rate = base_lr_ * (*functor_)(iter);
|
|
// Write to output.
|
|
auto* output = Output(0);
|
|
output->Resize(vector<int64_t>());
|
|
context_.template CopyFromCPU<T>(
|
|
1, &learning_rate, Output(0)->template mutable_data<T>());
|
|
return true;
|
|
}
|
|
|
|
private:
|
|
unique_ptr<LearningRateFunctor<T>> functor_;
|
|
T base_lr_;
|
|
|
|
LearningRateFunctor<T>* createLearningRateFunctor(
|
|
const string& policy,
|
|
const string& arg_prefix = "") {
|
|
if (policy == "fixed") {
|
|
return new FixedLearningRate<T>();
|
|
} else if (policy == "alter") {
|
|
bool active_first = this->template GetSingleArgument<bool>(
|
|
arg_prefix + "active_first", true);
|
|
int64_t active_period = this->template GetSingleArgument<int64_t>(
|
|
arg_prefix + "active_period", -1);
|
|
int64_t inactive_period = this->template GetSingleArgument<int64_t>(
|
|
arg_prefix + "inactive_period", -1);
|
|
TORCH_DCHECK_GE(active_period, 0);
|
|
TORCH_DCHECK_GE(inactive_period, 0);
|
|
return new AlternateLearningRate<T>(
|
|
active_period, inactive_period, active_first);
|
|
} else if (policy == "hill") {
|
|
int64_t num_iter =
|
|
this->template GetSingleArgument<int64_t>(arg_prefix + "num_iter", 0);
|
|
TORCH_DCHECK_GT(num_iter, 0);
|
|
T start_multiplier = this->template GetSingleArgument<float>(
|
|
arg_prefix + "start_multiplier", 0.);
|
|
TORCH_DCHECK_GE(start_multiplier, 0); // start_multiplier in range [0, 1]
|
|
TORCH_DCHECK_LE(start_multiplier, 1);
|
|
T gamma =
|
|
this->template GetSingleArgument<float>(arg_prefix + "gamma", 0);
|
|
TORCH_DCHECK_GT(gamma, 0);
|
|
T power =
|
|
this->template GetSingleArgument<float>(arg_prefix + "power", 0);
|
|
TORCH_DCHECK_GT(power, 0);
|
|
T end_multiplier = this->template GetSingleArgument<float>(
|
|
arg_prefix + "end_multiplier", 0);
|
|
TORCH_DCHECK_GE(end_multiplier, 0); // end_multiplier in range [0, 1]
|
|
TORCH_DCHECK_LE(end_multiplier, 1);
|
|
return new HillLearningRate<T>(
|
|
num_iter, start_multiplier, gamma, power, end_multiplier);
|
|
} else if (policy == "slope") {
|
|
int64_t num_iter_1 = this->template GetSingleArgument<int64_t>(
|
|
arg_prefix + "num_iter_1", 0);
|
|
TORCH_DCHECK_GT(num_iter_1, 0);
|
|
T multiplier_1 = this->template GetSingleArgument<float>(
|
|
arg_prefix + "multiplier_1", 0.);
|
|
int64_t num_iter_2 = this->template GetSingleArgument<int64_t>(
|
|
arg_prefix + "num_iter_2", 0);
|
|
TORCH_DCHECK_GT(num_iter_1, 0);
|
|
T multiplier_2 = this->template GetSingleArgument<float>(
|
|
arg_prefix + "multiplier_2", 0.);
|
|
TORCH_DCHECK_GT(num_iter_2, num_iter_1);
|
|
return new SlopeLearningRate<T>(
|
|
num_iter_1, multiplier_1, num_iter_2, multiplier_2);
|
|
} else if (policy == "step") {
|
|
int stepsize =
|
|
this->template GetSingleArgument<int>(arg_prefix + "stepsize", 0);
|
|
T gamma =
|
|
this->template GetSingleArgument<float>(arg_prefix + "gamma", 0);
|
|
TORCH_DCHECK_GT(stepsize, 0);
|
|
TORCH_DCHECK_GT(gamma, 0);
|
|
return new StepLearningRate<T>(stepsize, gamma);
|
|
} else if (policy == "exp") {
|
|
T gamma =
|
|
this->template GetSingleArgument<float>(arg_prefix + "gamma", 0);
|
|
TORCH_DCHECK_GT(gamma, 0);
|
|
return new ExpLearningRate<T>(gamma);
|
|
} else if (policy == "gate") {
|
|
T multiplier_1 = this->template GetSingleArgument<float>(
|
|
arg_prefix + "multiplier_1", 1);
|
|
T multiplier_2 = this->template GetSingleArgument<float>(
|
|
arg_prefix + "multiplier_2", 1);
|
|
int num_iter =
|
|
this->template GetSingleArgument<int>(arg_prefix + "num_iter", 0);
|
|
// no constraint on the range of multiplier_1 and multiplier_2
|
|
return new GateLearningRate<T>(multiplier_1, multiplier_2, num_iter);
|
|
} else if (policy == "inv") {
|
|
T gamma =
|
|
this->template GetSingleArgument<float>(arg_prefix + "gamma", 0);
|
|
T power =
|
|
this->template GetSingleArgument<float>(arg_prefix + "power", 0);
|
|
TORCH_DCHECK_GT(gamma, 0);
|
|
TORCH_DCHECK_GT(power, 0);
|
|
return new InvLearningRate<T>(gamma, power);
|
|
} else if (policy == "poly") {
|
|
int max_iter =
|
|
this->template GetSingleArgument<int>(arg_prefix + "max_iter", -1);
|
|
T power =
|
|
this->template GetSingleArgument<float>(arg_prefix + "power", 0);
|
|
TORCH_DCHECK_GT(power, 0);
|
|
return new PolyLearningRate<T>(power, max_iter);
|
|
} else if (policy == "linearWarmup") {
|
|
T start_multiplier = this->template GetSingleArgument<float>(
|
|
arg_prefix + "start_multiplier", 0.);
|
|
int num_iter =
|
|
this->template GetSingleArgument<int>(arg_prefix + "num_iter", 0);
|
|
TORCH_DCHECK_GE(start_multiplier, 0);
|
|
return new LinearWarmupLearningRate<T>(start_multiplier, num_iter);
|
|
} else if (policy == "constantWarmup") {
|
|
T multiplier = this->template GetSingleArgument<float>(
|
|
arg_prefix + "multiplier", 0.5);
|
|
int num_iter =
|
|
this->template GetSingleArgument<int>(arg_prefix + "num_iter", 0);
|
|
TORCH_DCHECK_GT(multiplier, 0);
|
|
return new ConstantWarmupLearningRate<T>(multiplier, num_iter);
|
|
} else if (policy == "pieceWarmup") {
|
|
T m1 = this->template GetSingleArgument<float>(arg_prefix + "m1", 0.5);
|
|
int64_t n1 =
|
|
this->template GetSingleArgument<int64_t>(arg_prefix + "n1", 0);
|
|
T m2 = this->template GetSingleArgument<float>(arg_prefix + "m2", 0.5);
|
|
int64_t n2 =
|
|
this->template GetSingleArgument<int64_t>(arg_prefix + "n2", 0);
|
|
T m3 = this->template GetSingleArgument<float>(arg_prefix + "m3", 0.5);
|
|
return new PieceWarmupLearningRate<T>(m1, n1, m2, n2, m3);
|
|
} else if (policy == "composite") {
|
|
std::vector<int> sub_policy_num_iters =
|
|
this->template GetRepeatedArgument<int>("sub_policy_num_iters");
|
|
std::list<CompositeLearningRateItem<T>> sub_policies;
|
|
CAFFE_ENFORCE_GT(
|
|
sub_policy_num_iters.size(),
|
|
0,
|
|
"Must specify at least one sub learning rate policy.");
|
|
for (const auto i : c10::irange(sub_policy_num_iters.size())) {
|
|
CAFFE_ENFORCE_GT(
|
|
sub_policy_num_iters[i],
|
|
0,
|
|
"The number of iterations for sub learning rate policy should be positive.");
|
|
std::stringstream sub_policy_arg_prefix;
|
|
sub_policy_arg_prefix << "sub_policy_" << i << "_";
|
|
const string sub_policy_arg_prefix_str = sub_policy_arg_prefix.str();
|
|
const string sub_policy = this->template GetSingleArgument<string>(
|
|
sub_policy_arg_prefix_str + "policy", "");
|
|
if (sub_policy == "composite") {
|
|
CAFFE_THROW(
|
|
"Defining composite LR policy as a subpolicy of composite LR "
|
|
"policy is not allowed.");
|
|
}
|
|
const float scale_lr = this->template GetSingleArgument<float>(
|
|
sub_policy_arg_prefix_str + "lr_scale", 1.0);
|
|
sub_policies.push_back(CompositeLearningRateItem<T>(
|
|
sub_policy_num_iters[i],
|
|
scale_lr,
|
|
createLearningRateFunctor(sub_policy, sub_policy_arg_prefix_str)));
|
|
}
|
|
return new CompositeLearningRate<T>(sub_policies);
|
|
} else if (policy == "cyclical") {
|
|
T max_lr =
|
|
this->template GetSingleArgument<float>(arg_prefix + "max_lr", 0.005);
|
|
int stepsize =
|
|
this->template GetSingleArgument<int>(arg_prefix + "stepsize", 0);
|
|
T decay =
|
|
this->template GetSingleArgument<float>(arg_prefix + "decay", 1.0);
|
|
TORCH_DCHECK_GT(stepsize, 0);
|
|
TORCH_DCHECK_GE(max_lr, base_lr_);
|
|
return new CyclicalLearningRate<T>(base_lr_, max_lr, stepsize, decay);
|
|
} else if (policy == "constantThenLinearWarmup") {
|
|
T start_warmup_multiplier = this->template GetSingleArgument<float>(
|
|
arg_prefix + "start_warmup_multiplier", 0.1);
|
|
int64_t constant_warmup_num_iter = this->template GetSingleArgument<int64_t>(
|
|
arg_prefix + "constant_warmup_num_iter", 10000000);
|
|
int64_t linear_warmup_num_iter = this->template GetSingleArgument<int64_t>(
|
|
arg_prefix + "linear_warmup_num_iter", 10000000);
|
|
return new ConstantThenLinearWarmupLearningRate<T>(
|
|
start_warmup_multiplier,
|
|
constant_warmup_num_iter,
|
|
linear_warmup_num_iter);
|
|
} else if (policy == "compositeCyclical") {
|
|
T start_warmup_multiplier = this->template GetSingleArgument<float>(
|
|
arg_prefix + "start_warmup_multiplier", 0.1);
|
|
int64_t constant_warmup_num_iter = this->template GetSingleArgument<int64_t>(
|
|
arg_prefix + "constant_warmup_num_iter", 10000000);
|
|
int64_t linear_warmup_num_iter = this->template GetSingleArgument<int64_t>(
|
|
arg_prefix + "linear_warmup_num_iter", 10000000);
|
|
T cyclical_max_lr = this->template GetSingleArgument<float>(
|
|
arg_prefix + "cyclical_max_lr", 0.05);
|
|
int cyclical_step_size = this->template GetSingleArgument<int>(
|
|
arg_prefix + "cyclical_step_size", 1000000);
|
|
T cyclical_decay = this->template GetSingleArgument<float>(
|
|
arg_prefix + "cyclical_decay", 1.0);
|
|
TORCH_DCHECK_GE(cyclical_max_lr, base_lr_);
|
|
return new CompositeCyclicalLearningRate<T>(
|
|
base_lr_,
|
|
start_warmup_multiplier,
|
|
constant_warmup_num_iter,
|
|
linear_warmup_num_iter,
|
|
cyclical_max_lr,
|
|
cyclical_step_size,
|
|
cyclical_decay);
|
|
} else if (policy == "cosine") {
|
|
T max_lr =
|
|
this->template GetSingleArgument<float>(arg_prefix + "max_lr", 0.5);
|
|
T min_lr =
|
|
this->template GetSingleArgument<float>(arg_prefix + "min_lr", 0.1);
|
|
int64_t period =
|
|
this->template GetSingleArgument<int>(arg_prefix + "period", 50);
|
|
T t_mult =
|
|
this->template GetSingleArgument<float>(arg_prefix + "t_mult", 1.0);
|
|
T lr_shrink = this->template GetSingleArgument<float>(
|
|
arg_prefix + "lr_shrink", 0.99);
|
|
TORCH_DCHECK_GE(max_lr, min_lr);
|
|
return new CosineLearningRate<T>(
|
|
min_lr, max_lr, period, t_mult, lr_shrink);
|
|
} else if (policy == "compositeCosine") {
|
|
T start_warmup_multiplier = this->template GetSingleArgument<float>(
|
|
arg_prefix + "start_warmup_multiplier", 0.1);
|
|
int64_t constant_warmup_num_iter = this->template GetSingleArgument<int64_t>(
|
|
arg_prefix + "constant_warmup_num_iter", 10000000);
|
|
int64_t linear_warmup_num_iter = this->template GetSingleArgument<int64_t>(
|
|
arg_prefix + "linear_warmup_num_iter", 10000000);
|
|
T cosine_max_lr = this->template GetSingleArgument<float>(
|
|
arg_prefix + "cosine_max_lr", 0.5);
|
|
T cosine_min_lr = this->template GetSingleArgument<float>(
|
|
arg_prefix + "cosine_min_lr", 0.1);
|
|
int64_t cosine_period = this->template GetSingleArgument<int>(
|
|
arg_prefix + "cosine_period", 50);
|
|
T cosine_t_mult = this->template GetSingleArgument<float>(
|
|
arg_prefix + "cosine_t_mult", 1.0);
|
|
T cosine_lr_shrink = this->template GetSingleArgument<float>(
|
|
arg_prefix + "cosine_lr_shrink", 0.99);
|
|
|
|
TORCH_DCHECK_GE(cosine_max_lr, cosine_min_lr);
|
|
return new CompositeCosineLearningRate<T>(
|
|
start_warmup_multiplier,
|
|
constant_warmup_num_iter,
|
|
linear_warmup_num_iter,
|
|
cosine_min_lr,
|
|
cosine_max_lr,
|
|
cosine_period,
|
|
cosine_t_mult,
|
|
cosine_lr_shrink);
|
|
} else {
|
|
CAFFE_THROW("Unknown learning rate policy: ", policy);
|
|
return NULL;
|
|
}
|
|
}
|
|
};
|
|
|
|
} // namespace caffe2
|
|
|
|
#endif // CAFFE2_SGD_LEARNING_RATE_OP_H_
|