// pytorch/caffe2/opt/backend_transformer_base.h
//
// (source listing metadata: 100 lines, 2.7 KiB, C++)

#pragma once
#include "caffe2/core/common.h"
#include "caffe2/core/workspace.h"
#include "caffe2/opt/bound_shape_inferencer.h"
#include "caffe2/proto/caffe2_pb.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace caffe2 {
// String constants shared by backend transformers. NOTE(review): presumably
// used as proto argument names by annotateOpIndex()/getModelId() — confirm in
// the .cc file.
//
// These are intentionally NOT wrapped in an anonymous namespace: unnamed
// namespaces in headers are discouraged (google-build-namespaces,
// cert-dcl59-cpp). A namespace-scope constexpr array of const char already has
// internal linkage, so every translation unit still gets its own private copy,
// exactly as before. (Under C++17 these could become `inline constexpr`.)
constexpr char kNetPos[] = "net_pos";
constexpr char kModelId[] = "model_id";
// Tunable options for a backend transformation pass.
struct BackendTransformOptions {
  // Not `explicit`: a zero-argument constructor performs no conversion, and
  // marking it explicit only breaks copy-list-initialization such as
  // `BackendTransformOptions opts = {};` or `return {};`.
  BackendTransformOptions() : bound_shape_spec(0, 0) {}

  // Enable debugging by dumping more intermediate graphs.
  bool debug{false};
  // Minimum number of ops to create a backend op. If the subgraph is too
  // small, it doesn't make sense to lower it to the backend.
  size_t min_ops{1};
  // Bound shape spec; constructed with (0, 0) by the constructor above.
  // NOTE(review): the meaning of the two zero arguments is defined by
  // BoundShapeSpec — confirm there.
  BoundShapeSpec bound_shape_spec;
};
// Wraps the shape information in `shape_info` into a TensorProto carrying the
// given `name` (non-quantized variant).
TensorProto wrapShapeInfoIntoTensorProto(const std::string& name,
                                         const ShapeInfo& shape_info);

// Quantized variant of wrapShapeInfoIntoTensorProto: wraps `shape_info` into a
// QTensorProto carrying the given `name`.
QTensorProto wrapShapeInfoIntoQTensorProto(const std::string& name,
                                           const ShapeInfo& shape_info);
// Base class holding the functionality shared by backend lowering and graph
// cutting, together with the input-name mappings produced while rewriting a
// net. Subclasses implement transform().
class BackendTransformerBase {
 public:
  BackendTransformerBase() = default;
  // Virtual because the class is meant to be used polymorphically (transform()
  // is pure virtual); defaulted since the class owns no raw resources.
  virtual ~BackendTransformerBase() = default;

  // Mapping of input name -> original input name.
  const std::unordered_map<std::string, std::string>& input_mapping() const {
    return input_mapping_;
  }

  // Mapping of original input name -> input name (reverse of input_mapping()).
  const std::unordered_map<std::string, std::string>& reverse_input_mapping()
      const {
    return reverse_input_mapping_;
  }

  // Backend-specific transformation of `pred_net`, implemented by subclasses.
  // NOTE(review): `pred_net` is taken by non-const pointer, so it is presumably
  // rewritten in place; `blocklisted_ops` presumably holds op indices to keep
  // off the backend — confirm in the implementations.
  virtual void transform(
      Workspace* ws,
      NetDef* pred_net,
      const std::vector<std::string>& weight_names,
      const ShapeInfoMap& shape_hints,
      const std::unordered_set<int>& blocklisted_ops) = 0;

  // Annotates the ops of `net` with their index.
  // NOTE(review): storage location (presumably the kNetPos argument) — confirm.
  static void annotateOpIndex(NetDef* net);

  // Get model ID from the NetDef.
  static std::string getModelId(const NetDef& net);

 protected:
  // Add shape info from `shape_hints` to `shape_net`.
  void addShapeToNet(NetDef& shape_net, const ShapeInfoMap& shape_hints) const;

  // Dump `pred_net` together with the shape info in `map` to file `fname`.
  void dumpNet(
      const NetDef& pred_net,
      const ShapeInfoMap& map,
      const std::string& fname) const;

  // SSA rewrite the net and return the shape hints with names mapped.
  // NOTE(review): presumably populates input_mapping_ / reverse_input_mapping_
  // (this method is non-const) — confirm in the .cc file.
  ShapeInfoMap ssaRewriteAndMapNames(
      Workspace* ws,
      NetDef* pred_net,
      const ShapeInfoMap& input_shape_hints);

  // Do bound shape inference with `spec` and collect the resulting shape infos.
  ShapeInfoMap inferShapes(
      Workspace* ws,
      NetDef* pred_net,
      const ShapeInfoMap& shape_hints_mapped,
      const BoundShapeSpec& spec);

  // Input mapping of input name -> original input name.
  std::unordered_map<std::string, std::string> input_mapping_;
  // Input mapping of original input name -> input name.
  std::unordered_map<std::string, std::string> reverse_input_mapping_;
};
} // namespace caffe2