// pytorch/caffe2/sgd/learning_rate_op.h
#ifndef CAFFE2_SGD_LEARNING_RATE_OP_H_
#define CAFFE2_SGD_LEARNING_RATE_OP_H_

#include <cfloat>
#include <cmath>
#include <list>
#include <memory>
#include <sstream>

#include <c10/util/irange.h>

#include "caffe2/core/context.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/operator.h"
#include "caffe2/sgd/learning_rate_functors.h"

C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(LearningRate);
namespace caffe2 {
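
// Computes a scalar learning rate for the current iteration:
//   lr = base_lr * multiplier(iter)
// Input 0 is a CPU tensor holding the iteration count as a single int64;
// output 0 is the resulting scalar rate. The multiplier functor is chosen
// by the required "policy" argument and parameterized by the per-policy
// arguments parsed in createLearningRateFunctor() below.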
template <typename T, class Context>
class LearningRateOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit LearningRateOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        functor_(nullptr),
        base_lr_(this->template GetSingleArgument<float>("base_lr", FLT_MAX)) {
    CAFFE_ENFORCE_NE(base_lr_, FLT_MAX, "Base learning rate must be set.");
    const string policy =
        this->template GetSingleArgument<string>("policy", "");
    CAFFE_ENFORCE(policy.size(), "Must specify a learning rate policy.");
    functor_.reset(createLearningRateFunctor(policy));
  }

  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    int64_t iter =
        OperatorBase::Input<Tensor>(0, CPU).template data<int64_t>()[0];
    T learning_rate = base_lr_ * (*functor_)(iter);
    // Write the scalar learning rate to output 0.
    auto* output = Output(0);
    output->Resize(vector<int64_t>());
    context_.template CopyFromCPU<T>(
        1, &learning_rate, output->template mutable_data<T>());
    return true;
  }
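
  // A sketch of typical use from Caffe2 Python (values are illustrative;
  // the iteration blob is usually produced by an Iter op):
  //   LR = model.net.LearningRate(
  //       [ITER], "LR", base_lr=0.1, policy="step", stepsize=1000, gamma=0.9)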

 private:
  unique_ptr<LearningRateFunctor<T>> functor_;
  T base_lr_;

  LearningRateFunctor<T>* createLearningRateFunctor(
      const string& policy,
      const string& arg_prefix = "") {
    if (policy == "fixed") {
      return new FixedLearningRate<T>();
    } else if (policy == "alter") {
      bool active_first = this->template GetSingleArgument<bool>(
          arg_prefix + "active_first", true);
      int64_t active_period = this->template GetSingleArgument<int64_t>(
          arg_prefix + "active_period", -1);
      int64_t inactive_period = this->template GetSingleArgument<int64_t>(
          arg_prefix + "inactive_period", -1);
      TORCH_DCHECK_GE(active_period, 0);
      TORCH_DCHECK_GE(inactive_period, 0);
      return new AlternateLearningRate<T>(
          active_period, inactive_period, active_first);
    } else if (policy == "hill") {
      int64_t num_iter =
          this->template GetSingleArgument<int64_t>(arg_prefix + "num_iter", 0);
      TORCH_DCHECK_GT(num_iter, 0);
      T start_multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "start_multiplier", 0.);
      TORCH_DCHECK_GE(start_multiplier, 0); // start_multiplier in range [0, 1]
      TORCH_DCHECK_LE(start_multiplier, 1);
      T gamma =
          this->template GetSingleArgument<float>(arg_prefix + "gamma", 0);
      TORCH_DCHECK_GT(gamma, 0);
      T power =
          this->template GetSingleArgument<float>(arg_prefix + "power", 0);
      TORCH_DCHECK_GT(power, 0);
      T end_multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "end_multiplier", 0);
      TORCH_DCHECK_GE(end_multiplier, 0); // end_multiplier in range [0, 1]
      TORCH_DCHECK_LE(end_multiplier, 1);
      return new HillLearningRate<T>(
          num_iter, start_multiplier, gamma, power, end_multiplier);
    } else if (policy == "slope") {
      int64_t num_iter_1 = this->template GetSingleArgument<int64_t>(
          arg_prefix + "num_iter_1", 0);
      TORCH_DCHECK_GT(num_iter_1, 0);
      T multiplier_1 = this->template GetSingleArgument<float>(
          arg_prefix + "multiplier_1", 0.);
      int64_t num_iter_2 = this->template GetSingleArgument<int64_t>(
          arg_prefix + "num_iter_2", 0);
      TORCH_DCHECK_GT(num_iter_2, 0);
      T multiplier_2 = this->template GetSingleArgument<float>(
          arg_prefix + "multiplier_2", 0.);
      TORCH_DCHECK_GT(num_iter_2, num_iter_1);
      return new SlopeLearningRate<T>(
          num_iter_1, multiplier_1, num_iter_2, multiplier_2);
    } else if (policy == "step") {
      int stepsize =
          this->template GetSingleArgument<int>(arg_prefix + "stepsize", 0);
      T gamma =
          this->template GetSingleArgument<float>(arg_prefix + "gamma", 0);
      TORCH_DCHECK_GT(stepsize, 0);
      TORCH_DCHECK_GT(gamma, 0);
      return new StepLearningRate<T>(stepsize, gamma);
    } else if (policy == "exp") {
      T gamma =
          this->template GetSingleArgument<float>(arg_prefix + "gamma", 0);
      TORCH_DCHECK_GT(gamma, 0);
      return new ExpLearningRate<T>(gamma);
    } else if (policy == "gate") {
      T multiplier_1 = this->template GetSingleArgument<float>(
          arg_prefix + "multiplier_1", 1);
      T multiplier_2 = this->template GetSingleArgument<float>(
          arg_prefix + "multiplier_2", 1);
      int num_iter =
          this->template GetSingleArgument<int>(arg_prefix + "num_iter", 0);
      // No constraint on the range of multiplier_1 and multiplier_2.
      return new GateLearningRate<T>(multiplier_1, multiplier_2, num_iter);
    } else if (policy == "inv") {
      T gamma =
          this->template GetSingleArgument<float>(arg_prefix + "gamma", 0);
      T power =
          this->template GetSingleArgument<float>(arg_prefix + "power", 0);
      TORCH_DCHECK_GT(gamma, 0);
      TORCH_DCHECK_GT(power, 0);
      return new InvLearningRate<T>(gamma, power);
    } else if (policy == "poly") {
      int max_iter =
          this->template GetSingleArgument<int>(arg_prefix + "max_iter", -1);
      T power =
          this->template GetSingleArgument<float>(arg_prefix + "power", 0);
      TORCH_DCHECK_GT(power, 0);
      return new PolyLearningRate<T>(power, max_iter);
    } else if (policy == "linearWarmup") {
      T start_multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "start_multiplier", 0.);
      int num_iter =
          this->template GetSingleArgument<int>(arg_prefix + "num_iter", 0);
      TORCH_DCHECK_GE(start_multiplier, 0);
      return new LinearWarmupLearningRate<T>(start_multiplier, num_iter);
    } else if (policy == "constantWarmup") {
      T multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "multiplier", 0.5);
      int num_iter =
          this->template GetSingleArgument<int>(arg_prefix + "num_iter", 0);
      TORCH_DCHECK_GT(multiplier, 0);
      return new ConstantWarmupLearningRate<T>(multiplier, num_iter);
    } else if (policy == "pieceWarmup") {
      T m1 = this->template GetSingleArgument<float>(arg_prefix + "m1", 0.5);
      int64_t n1 =
          this->template GetSingleArgument<int64_t>(arg_prefix + "n1", 0);
      T m2 = this->template GetSingleArgument<float>(arg_prefix + "m2", 0.5);
      int64_t n2 =
          this->template GetSingleArgument<int64_t>(arg_prefix + "n2", 0);
      T m3 = this->template GetSingleArgument<float>(arg_prefix + "m3", 0.5);
      return new PieceWarmupLearningRate<T>(m1, n1, m2, n2, m3);
    } else if (policy == "composite") {
      std::vector<int> sub_policy_num_iters =
          this->template GetRepeatedArgument<int>("sub_policy_num_iters");
      std::list<CompositeLearningRateItem<T>> sub_policies;
      CAFFE_ENFORCE_GT(
          sub_policy_num_iters.size(),
          0,
          "Must specify at least one sub learning rate policy.");
      for (const auto i : c10::irange(sub_policy_num_iters.size())) {
        CAFFE_ENFORCE_GT(
            sub_policy_num_iters[i],
            0,
            "The number of iterations for a sub learning rate policy must be positive.");
        std::stringstream sub_policy_arg_prefix;
        sub_policy_arg_prefix << "sub_policy_" << i << "_";
        const string sub_policy_arg_prefix_str = sub_policy_arg_prefix.str();
        const string sub_policy = this->template GetSingleArgument<string>(
            sub_policy_arg_prefix_str + "policy", "");
        if (sub_policy == "composite") {
          CAFFE_THROW(
              "Defining composite LR policy as a subpolicy of composite LR "
              "policy is not allowed.");
        }
        const float scale_lr = this->template GetSingleArgument<float>(
            sub_policy_arg_prefix_str + "lr_scale", 1.0);
        sub_policies.push_back(CompositeLearningRateItem<T>(
            sub_policy_num_iters[i],
            scale_lr,
            createLearningRateFunctor(sub_policy, sub_policy_arg_prefix_str)));
      }
      return new CompositeLearningRate<T>(sub_policies);
    } else if (policy == "cyclical") {
      T max_lr =
          this->template GetSingleArgument<float>(arg_prefix + "max_lr", 0.005);
      int stepsize =
          this->template GetSingleArgument<int>(arg_prefix + "stepsize", 0);
      T decay =
          this->template GetSingleArgument<float>(arg_prefix + "decay", 1.0);
      TORCH_DCHECK_GT(stepsize, 0);
      TORCH_DCHECK_GE(max_lr, base_lr_);
      return new CyclicalLearningRate<T>(base_lr_, max_lr, stepsize, decay);
    } else if (policy == "constantThenLinearWarmup") {
      T start_warmup_multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "start_warmup_multiplier", 0.1);
      int64_t constant_warmup_num_iter =
          this->template GetSingleArgument<int64_t>(
              arg_prefix + "constant_warmup_num_iter", 10000000);
      int64_t linear_warmup_num_iter =
          this->template GetSingleArgument<int64_t>(
              arg_prefix + "linear_warmup_num_iter", 10000000);
      return new ConstantThenLinearWarmupLearningRate<T>(
          start_warmup_multiplier,
          constant_warmup_num_iter,
          linear_warmup_num_iter);
    } else if (policy == "compositeCyclical") {
      T start_warmup_multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "start_warmup_multiplier", 0.1);
      int64_t constant_warmup_num_iter =
          this->template GetSingleArgument<int64_t>(
              arg_prefix + "constant_warmup_num_iter", 10000000);
      int64_t linear_warmup_num_iter =
          this->template GetSingleArgument<int64_t>(
              arg_prefix + "linear_warmup_num_iter", 10000000);
      T cyclical_max_lr = this->template GetSingleArgument<float>(
          arg_prefix + "cyclical_max_lr", 0.05);
      int cyclical_step_size = this->template GetSingleArgument<int>(
          arg_prefix + "cyclical_step_size", 1000000);
      T cyclical_decay = this->template GetSingleArgument<float>(
          arg_prefix + "cyclical_decay", 1.0);
      TORCH_DCHECK_GE(cyclical_max_lr, base_lr_);
      return new CompositeCyclicalLearningRate<T>(
          base_lr_,
          start_warmup_multiplier,
          constant_warmup_num_iter,
          linear_warmup_num_iter,
          cyclical_max_lr,
          cyclical_step_size,
          cyclical_decay);
    } else if (policy == "cosine") {
      T max_lr =
          this->template GetSingleArgument<float>(arg_prefix + "max_lr", 0.5);
      T min_lr =
          this->template GetSingleArgument<float>(arg_prefix + "min_lr", 0.1);
      int64_t period =
          this->template GetSingleArgument<int>(arg_prefix + "period", 50);
      T t_mult =
          this->template GetSingleArgument<float>(arg_prefix + "t_mult", 1.0);
      T lr_shrink = this->template GetSingleArgument<float>(
          arg_prefix + "lr_shrink", 0.99);
      TORCH_DCHECK_GE(max_lr, min_lr);
      return new CosineLearningRate<T>(
          min_lr, max_lr, period, t_mult, lr_shrink);
    } else if (policy == "compositeCosine") {
      T start_warmup_multiplier = this->template GetSingleArgument<float>(
          arg_prefix + "start_warmup_multiplier", 0.1);
      int64_t constant_warmup_num_iter =
          this->template GetSingleArgument<int64_t>(
              arg_prefix + "constant_warmup_num_iter", 10000000);
      int64_t linear_warmup_num_iter =
          this->template GetSingleArgument<int64_t>(
              arg_prefix + "linear_warmup_num_iter", 10000000);
      T cosine_max_lr = this->template GetSingleArgument<float>(
          arg_prefix + "cosine_max_lr", 0.5);
      T cosine_min_lr = this->template GetSingleArgument<float>(
          arg_prefix + "cosine_min_lr", 0.1);
      int64_t cosine_period = this->template GetSingleArgument<int>(
          arg_prefix + "cosine_period", 50);
      T cosine_t_mult = this->template GetSingleArgument<float>(
          arg_prefix + "cosine_t_mult", 1.0);
      T cosine_lr_shrink = this->template GetSingleArgument<float>(
          arg_prefix + "cosine_lr_shrink", 0.99);
      TORCH_DCHECK_GE(cosine_max_lr, cosine_min_lr);
      return new CompositeCosineLearningRate<T>(
          start_warmup_multiplier,
          constant_warmup_num_iter,
          linear_warmup_num_iter,
          cosine_min_lr,
          cosine_max_lr,
          cosine_period,
          cosine_t_mult,
          cosine_lr_shrink);
    } else {
      CAFFE_THROW("Unknown learning rate policy: ", policy);
      return nullptr;
    }
  }
};

} // namespace caffe2
#endif // CAFFE2_SGD_LEARNING_RATE_OP_H_