#ifndef LSTM_OP_H_
#define LSTM_OP_H_

#include <algorithm>
#include <sstream>
#include <unordered_map>
#include <vector>
#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include <c10/util/irange.h>
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
#include "lstm_utils.h"

C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(LSTMOp);

namespace caffe2 {
namespace {

using t_tuple = std::tuple<Tensor, Tensor>;

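// Weights and (optional) biases for a single LSTM cell, plus the CPUContext
// used to run its linear transforms. Parameters are copied through the
// copy_ctor helper (from lstm_utils.h) on construction and assignment, so
// each CellParams holds its own tensor handles.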
struct CellParams {
  CellParams(
      const Tensor& _w_ih,
      const Tensor& _w_hh,
      const Tensor& _b_ih,
      const Tensor& _b_hh,
      CPUContext* _context) {
    initParams(_w_ih, _w_hh, _b_ih, _b_hh, _context);
  }

  CellParams(const CellParams& rhs) {
    initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context);
  }

  CellParams& operator=(const CellParams& rhs) {
    initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context);
    return *this;
  }

  void initParams(
      const Tensor& _w_ih,
      const Tensor& _w_hh,
      const Tensor& _b_ih,
      const Tensor& _b_hh,
      CPUContext* _context) {
    w_ih = copy_ctor(_w_ih);
    w_hh = copy_ctor(_w_hh);
    b_ih = copy_ctor(_b_ih);
    b_hh = copy_ctor(_b_hh);
    context = _context;
  }

  Tensor w_ih;
  Tensor w_hh;
  Tensor b_ih; /* optional */
  Tensor b_hh; /* optional */
  CPUContext* context;

  Tensor linear_ih(const Tensor& input) const {
    return linear(input, w_ih, b_ih, context);
  }
  Tensor linear_hh(const Tensor& h) const {
    return linear(h, w_hh, b_hh, context);
  }
};

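// One LSTM time step. The input and recurrent projections are summed, split
// into four gates along dim 1 in the order [input, forget, cell, output],
// and combined as:
//   cy = forgetgate * cx + ingate * cellgate
//   hy = outgate * tanh(cy)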
struct LSTMCell {
  explicit LSTMCell(CPUContext* context) : context_(context) {}
  t_tuple operator()(
      const Tensor& input,
      const t_tuple& hidden,
      const CellParams& params) const {
    const auto& hx = std::get<0>(hidden);
    const auto& cx = std::get<1>(hidden);
    auto linear_ih = params.linear_ih(input);
    auto linear_hh = params.linear_hh(hx);
    auto gates = add(linear_ih, linear_hh, context_);
    auto chunked_gates = chunk(gates, 4, 1, context_);
    auto ingate = sigmoid(chunked_gates[0]);
    auto forgetgate = sigmoid(chunked_gates[1]);
    auto cellgate = tanh(chunked_gates[2], context_);
    auto outgate = sigmoid(chunked_gates[3]);

    auto cy =
        add(mul(forgetgate, cx, context_),
            mul(ingate, cellgate, context_),
            context_);
    auto hy = mul(outgate, tanh(cy, context_), context_);
    return std::make_tuple(std::move(hy), std::move(cy));
  }
  CPUContext* context_;
};

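// Couples a layer's per-step outputs with its final hidden state.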
template <typename output_type, typename hidden_type>
struct LayerOutput {
  output_type outputs;
  hidden_type final_hidden;

  LayerOutput(const output_type& _outputs, const hidden_type& _hidden) {
    outputs = copy_ctor(_outputs);
    final_hidden = copy_ctor(_hidden);
  }
};

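// Interface for a single RNN layer: maps (input, initial hidden state,
// parameters) to a LayerOutput.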
template <typename hidden_type, typename param_type>
struct Layer {
  using output_type = LayerOutput<Tensor, hidden_type>;
  virtual ~Layer() {}
  virtual output_type operator()(
      const Tensor& input,
      const hidden_type& input_hidden,
      const param_type& params) const = 0;
};

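// Unidirectional LSTM layer: unrolls LSTMCell over the time steps. The
// vector overload consumes pre-unbound steps; the Tensor override unbinds
// the input along dim 0 and stacks the per-step outputs back into a tensor.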
struct FullLSTMLayer : Layer<t_tuple, CellParams> {
  FullLSTMLayer(LSTMCell& cell, CPUContext* context)
      : cell_(cell), context_(context) {}

  LayerOutput<std::vector<Tensor>, t_tuple> operator()(
      const std::vector<Tensor>& step_inputs,
      const std::tuple<Tensor, Tensor>& input_hidden,
      const CellParams& params) const {
    std::vector<Tensor> step_outputs;
    auto hidden = copy_ctor(input_hidden);

    for (const auto i : c10::irange(step_inputs.size())) {
      hidden = cell_(step_inputs[i], hidden, params);
      step_outputs.push_back(copy_ctor(std::get<0>(hidden)));
    }

    return {step_outputs, hidden};
  }

  LayerOutput<Tensor, t_tuple> operator()(
      const Tensor& inputs,
      const std::tuple<Tensor, Tensor>& input_hidden,
      const CellParams& params) const override {
    auto unstacked_output =
        (*this)(unbind(inputs, 0, context_), input_hidden, params);
    return {stack(unstacked_output.outputs, 0, context_),
            unstacked_output.final_hidden};
  }
  LSTMCell cell_;
  CPUContext* context_;
};

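// Bidirectional LSTM layer: runs FullLSTMLayer forward over the input and
// again over the time-reversed input, re-reverses the backward outputs, and
// concatenates the two directions along the last dimension.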
struct FullBidirectionalLSTMLayer
    : Layer<std::pair<t_tuple, t_tuple>, std::pair<CellParams, CellParams>> {
  using bidir_hidden_type = std::pair<t_tuple, t_tuple>;
  using param_type = std::pair<CellParams, CellParams>;
  using output_type = LayerOutput<Tensor, bidir_hidden_type>;

  FullBidirectionalLSTMLayer(LSTMCell& cell, CPUContext* context)
      : layer_(cell, context), context_(context) {}

  output_type operator()(
      const Tensor& input,
      const bidir_hidden_type& input_hidden,
      const param_type& params) const override {
    std::vector<Tensor> outputs;
    auto step_inputs = unbind(input, 0, context_);
    auto fw_result = layer_(step_inputs, input_hidden.first, params.first);
    auto fw_output = stack(fw_result.outputs, 0, context_);
    outputs.push_back(copy_ctor(fw_output));
    auto rev_step_inputs = reverse(std::move(step_inputs));
    auto rev_result =
        layer_(rev_step_inputs, input_hidden.second, params.second);
    std::reverse(rev_result.outputs.begin(), rev_result.outputs.end());
    auto rev_output = stack(rev_result.outputs, 0, context_);
    outputs.push_back(copy_ctor(rev_output));
    return {cat(outputs, fw_output.dim() - 1, context_),
            std::make_pair(
                std::move(fw_result.final_hidden),
                std::move(rev_result.final_hidden))};
  }

  inline std::vector<Tensor> reverse(std::vector<Tensor>&& x) const {
    std::reverse(x.begin(), x.end());
    return std::move(x);
  }

 private:
  FullLSTMLayer layer_;
  CPUContext* context_;
};

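// Threads the input through num_layers stacked layers, feeding each layer's
// outputs into the next and collecting every layer's final hidden state.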
template <typename hidden_type, typename weight_type>
LayerOutput<Tensor, std::vector<hidden_type>> apply_layer_stack(
    const Layer<hidden_type, weight_type>& layer,
    const Tensor& input,
    const std::vector<hidden_type>& hiddens,
    const std::vector<weight_type>& weights,
    int64_t num_layers) {
  CAFFE_ENFORCE(
      num_layers == hiddens.size(),
      "Expected more hidden states in stacked_rnn");
  CAFFE_ENFORCE(
      num_layers == weights.size(), "Expected more weights in stacked_rnn");

  auto layer_input = input.UnsafeSharedInstance();
  auto hidden_it = hiddens.begin();
  auto weight_it = weights.begin();
  std::vector<hidden_type> final_hiddens(num_layers);
  for (const auto l : c10::irange(num_layers)) {
    auto layer_output = layer(layer_input, *(hidden_it++), *(weight_it++));
    final_hiddens.at(l) = std::move(layer_output.final_hidden);
    layer_input = std::move(layer_output.outputs);
  }
  return {layer_input, final_hiddens};
}

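// Multi-layer (optionally bidirectional) LSTM inference. hx and cx hold one
// slice per layer (per direction, if bidirectional) along dim 0. pair_vec /
// unpair_vec (presumably provided by lstm_utils.h) regroup the hidden states
// and parameters per direction for the bidirectional path. Returns
// (outputs, hy, cy).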
std::tuple<Tensor, Tensor, Tensor> _lstm_impl(
    const Tensor& input,
    const std::vector<CellParams>& params,
    const Tensor& hx,
    const Tensor& cx,
    int64_t num_layers,
    bool bidirectional,
    CPUContext* context) {
  using stack_output = LayerOutput<Tensor, std::vector<t_tuple>>;
  auto layer_hx = unbind(hx, 0, context);
  auto layer_cx = unbind(cx, 0, context);
  int64_t total_layers = layer_hx.size();
  std::vector<std::tuple<Tensor, Tensor>> hiddens;
  hiddens.reserve(total_layers);
  for (const auto i : c10::irange(total_layers)) {
    hiddens.emplace_back(std::move(layer_hx[i]), std::move(layer_cx[i]));
  }
  LSTMCell cell(context);
  std::shared_ptr<stack_output> stack_output_ptr;
  if (bidirectional) {
    auto bidir_result = apply_layer_stack(
        FullBidirectionalLSTMLayer{cell, context},
        input,
        pair_vec(hiddens),
        pair_vec(params),
        num_layers);
    stack_output_ptr.reset(new stack_output(
        bidir_result.outputs,
        unpair_vec(std::move(bidir_result.final_hidden))));
  } else {
    auto result = apply_layer_stack(
        FullLSTMLayer{cell, context}, input, hiddens, params, num_layers);
    stack_output_ptr = std::make_shared<stack_output>(std::move(result));
  }

  std::vector<Tensor> hy, cy;
  hy.reserve(total_layers);
  cy.reserve(total_layers);
  for (auto& hidden : stack_output_ptr->final_hidden) {
    hy.push_back(std::move(std::get<0>(hidden)));
    cy.push_back(std::move(std::get<1>(hidden)));
  }
  return std::make_tuple(
      std::move(stack_output_ptr->outputs),
      stack(hy, 0, context),
      stack(cy, 0, context));
}

// Parses a flat list of parameter tensors into a list of CellParams
std::vector<CellParams> gather_params(
    const std::vector<Tensor>& params,
    bool has_biases,
    CPUContext* context) {
  Tensor undefined;
  std::vector<CellParams> result;
  if (has_biases) {
    CAFFE_ENFORCE_EQ(
        params.size() % 4, 0, "got an incorrect number of LSTM parameters");
    for (size_t i = 0; i < params.size(); i += 4) {
      result.emplace_back(
          params[i], params[i + 1], params[i + 2], params[i + 3], context);
    }
  } else {
    CAFFE_ENFORCE_EQ(
        params.size() % 2, 0, "got an incorrect number of LSTM parameters");
    for (size_t i = 0; i < params.size(); i += 2) {
      result.emplace_back(
          params[i], params[i + 1], undefined, undefined, context);
    }
  }
  return result;
}

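// CPU inference operator. The arguments mirror the PyTorch nn.LSTM options
// (num_layers, bidirectional, has_biases, batch_first); RunOnDevice() is
// declared here and defined out of line, presumably in the matching .cc
// file.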
class InferenceLSTMOp : public Operator<CPUContext> {
 public:
  template <class... Args>
  explicit InferenceLSTMOp(Args&&... args)
      : Operator(std::forward<Args>(args)...),
        num_layers_(this->template GetSingleArgument<int64_t>("num_layers", 1)),
        bidirectional_(
            this->template GetSingleArgument<bool>("bidirectional", false)),
        has_biases_(this->template GetSingleArgument<bool>("has_biases", true)),
        batch_first_(
            this->template GetSingleArgument<bool>("batch_first", false)) {}

  bool RunOnDevice() override;

 protected:
  int64_t num_layers_;
  bool bidirectional_;
  bool has_biases_;
  bool batch_first_;
};

} // namespace
} // namespace caffe2

#endif // LSTM_OP_H_