// Source: pytorch/caffe2/transforms/common_subexpression_elimination.cc
// (161 lines, 4.9 KiB, C++)

#include "caffe2/transforms/common_subexpression_elimination.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/proto/caffe2_pb.h"
#include <c10/util/irange.h>
namespace caffe2 {
using transform::Graph;
using transform::Node;
// Checks if the node at model_idx and the node at candidate_idx are
// "common subexpressions". That is, do they have the same function, and
// take in the exact same input. If so, then their function is duplicated.
bool are_nodes_common(const Graph& g, int model_idx, int candidate_idx) {
// We need the candidate operator to match this model_op.
const Node& model_node = g.node(model_idx);
const Node& candidate_node = g.node(candidate_idx);
// Types need to match.
if (model_node.op.type() != candidate_node.op.type()) {
return false;
}
// Arguments need to match.
if (!MatchArguments(model_node.op, candidate_node.op)) {
return false;
}
// Inputs need to match.
if (model_node.op.input_size() != candidate_node.op.input_size()) {
return false;
}
// If any input_blob name is different, this is not okay.
for (int i = 0; i < model_node.op.input_size(); i++) {
if (candidate_node.op.input(i) != model_node.op.input(i)) {
return false;
}
}
// Now, we also need to check that each blob comes from the same parent, or
// if they are external (isn't in parents). This is equivalent to a
// map equality (since parent edges can only contain up to one blob).
if (model_node.parents.size() != candidate_node.parents.size() ||
!std::equal(
model_node.parents.begin(),
model_node.parents.end(),
candidate_node.parents.begin())) {
return false;
}
// Output size have to match too.
if (model_node.op.output_size() != candidate_node.op.output_size()) {
return false;
}
return true;
}
// Decides whether node `idx` may extend the partial match `subgraph`.
//
// With an empty subgraph, `idx` would become the model node of a new match,
// so it is accepted exactly when its operator type passes IsAllowed().
// Otherwise `idx` is accepted only if it is a common subexpression of the
// model node (the first node of the match).
bool CommonSubexpressionEliminationTransform::PatternRule(
    const Graph& g,
    const std::vector<int>& subgraph,
    int idx) {
  if (!subgraph.empty()) {
    const int model_idx = subgraph.at(0);
    return are_nodes_common(g, model_idx, idx);
  }
  const auto& op_type = g.node(idx).op.type();
  return IsAllowed(op_type) ? true : false;
}
// A match is worth eliminating once it contains at least 2 ops. A single op
// has no duplicate to remove. The graph itself is not consulted.
bool CommonSubexpressionEliminationTransform::ValidatorRule(
    const Graph& /*g*/,
    const std::vector<int>& subgraph) {
  return subgraph.size() >= 2;
}
bool CommonSubexpressionEliminationTransform::ReplaceRule(
const std::vector<int>& subgraph,
Graph* g_ptr) {
CHECK(g_ptr);
auto& g = *g_ptr;
// We're gonna make a new node, with the same input as all of the ones in
// subgraph, but with their combined children.
int new_idx = g.size();
OperatorDef new_op = g.node(subgraph[0]).op;
// We will need to rename the output blobs.
new_op.clear_output();
for (const auto& blob : g.node(subgraph[0]).op.output()) {
new_op.add_output("transform/" + blob);
}
// Need to set up the parents.
const auto& new_op_parents = g.node(subgraph[0]).parents;
for (auto& parent : new_op_parents) {
int parent_idx = parent.first;
// Make the parents acknowledge us as its new child.
g.node(parent_idx).children[new_idx] = new_op_parents.at(parent_idx);
// Make the parents disown all our outdated siblings.
for (const auto i : c10::irange(subgraph.size())) {
g.node(parent_idx).children.erase(subgraph[i]);
}
}
// Add the node now.
g.push_node(
Node(new_op, true, new_op_parents, std::map<int, std::vector<string>>()));
// Now, we need to populate the child edges.
for (const int x : subgraph) {
// Figure out what the subgraph's node's blobs correspond to in new_op
// This is easy, since their indices match.
std::map<string, string> output_renamings;
for (int i = 0; i < new_op.output_size(); i++) {
output_renamings[g.node(x).op.output(i)] = g.node(new_idx).op.output(i);
}
// Now, time to add the old node's children to new_op
for (auto& child : g.node(x).children) {
int child_idx = child.first;
std::vector<string> blobs = child.second;
// rename the old blobs, and use them for our new edge.
for (string& blob : blobs) {
blob = output_renamings.at(blob);
}
// create this new edge
g.node(new_idx).children[child_idx] = blobs;
g.node(child_idx).parents[new_idx] = blobs;
// delete the old edge
g.node(child_idx).parents.erase(x);
// need to rename the inputs of the children too.
for (int i = 0; i < g.node(child_idx).op.input_size(); i++) {
string blob = g.node(child_idx).op.input(i);
if (output_renamings.count(blob) > 0) {
g.node(child_idx).op.set_input(i, output_renamings.at(blob));
}
}
}
}
g.DeactivateSubgraph(subgraph);
return true;
}
// Registers CommonSubexpressionEliminationTransform under the key
// CommonSubexpressionElimination. Presumably this is the name callers use to
// look the transform up; confirm against the REGISTER_TRANSFORM definition.
REGISTER_TRANSFORM(
CommonSubexpressionElimination,
CommonSubexpressionEliminationTransform);
} // namespace caffe2