// pytorch/caffe2/opt/converter_nomigraph_test.cc
// (116 lines, 3.4 KiB, C++)

#include "caffe2/core/test_utils.h"
#include "caffe2/opt/converter.h"
#include <gtest/gtest.h>
// Builds a chain of ten in-place ops on "X", each randomly chosen to be either
// a Conv (with its own weight blob "W<i>") or a Relu. It then round-trips the
// net NetDef -> NNModule -> NetDef. The test only requires that both
// conversions complete.
TEST(Converter, Basic) {
  using namespace caffe2::testing;
  caffe2::NetDef net;
  for (int opIdx = 0; opIdx < 10; ++opIdx) {
    // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
    const bool emitConv = (rand() % 2) != 0;
    if (!emitConv) {
      NetMutator(&net)
          .newOp("Relu", {"X"}, {"X"})
          .setDeviceOptionName("relu_runner");
      continue;
    }
    const std::string weightName = "W" + c10::to_string(opIdx);
    NetMutator(&net)
        .newOp("Conv", {"X", weightName}, {"X"})
        .addArgument("kernel", 3)
        .addArgument("stride", 1)
        .addArgument("pad", 0)
        .addArgument("order", std::string("NCHW"))
        .setDeviceOptionName("conv_runner");
  }
  auto nnModule = caffe2::convertToNNModule(net);
  auto roundTripped = caffe2::convertToCaffe2Proto(nnModule);
}
// Round-trips a net containing a single op of a type the converter has never
// seen ("NeverSeen"). The op's device node name is randomly "device_0" or
// "device_1". The test only requires that both conversions complete.
TEST(Converter, UnknownType) {
  using namespace caffe2::testing;
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
  const int deviceIdx = rand() % 2;
  const std::string deviceName = "device_" + c10::to_string(deviceIdx);

  caffe2::NetDef net;
  NetMutator(&net)
      .newOp("NeverSeen", {"X"}, {"X"})
      .setDeviceOptionName(deviceName);

  auto nnModule = caffe2::convertToNNModule(net);
  auto roundTripped = caffe2::convertToCaffe2Proto(nnModule);
}
// Checks that the device-option node name set on a "Slice" op is still
// present on the first op after a NetDef -> NNModule -> NetDef round trip.
TEST(Converter, SpecializeConverter) {
  using namespace caffe2::testing;
  constexpr char kNodeName[] = "abc";

  caffe2::NetDef net;
  NetMutator(&net).newOp("Slice", {"X"}, {"X"}).setDeviceOptionName(kNodeName);
  // Sanity check: the mutator itself recorded the node name.
  EXPECT_EQ(net.op(0).device_option().node_name(), kNodeName);

  auto nnModule = caffe2::convertToNNModule(net);
  auto roundTripped = caffe2::convertToCaffe2Proto(nnModule);
  EXPECT_EQ(roundTripped.op(0).device_option().node_name(), kNodeName);
}
// Builds a small three-op net of "Fake" ops shared by the tests below:
//   X -> Y,  Y -> Z,  (Z, X) -> W
// X is the only external input; Y and W are the external outputs.
// The helper is `static` (internal linkage) because it is local to this test
// file. This keeps it from colliding with a same-named helper in another test
// translation unit linked into the same binary.
static caffe2::NetDef fakeNet() {
  using namespace caffe2::testing;
  caffe2::NetDef net;
  NetMutator(&net)
      .newOp("Fake", {"X"}, {"Y"})
      .newOp("Fake", {"Y"}, {"Z"})
      .newOp("Fake", {"Z", "X"}, {"W"})
      .externalInputs({"X"})
      .externalOutputs({"Y", "W"});
  return net;
}
// Verifies that a NetDef -> NNModule -> NetDef round trip preserves the
// net's external inputs, element by element and in order.
TEST(Converter, ExternalInputs) {
  auto net = fakeNet();
  auto nn = caffe2::convertToNNModule(net);
  auto new_netdef = caffe2::convertToCaffe2Proto(nn);
  // Use ASSERT, not EXPECT: the loop below indexes new_netdef with indices
  // that are valid for `net`. On a size mismatch the test must stop here
  // instead of reading out of range.
  ASSERT_EQ(new_netdef.external_input().size(), net.external_input().size());
  for (auto i = 0; i < net.external_input().size(); ++i) {
    EXPECT_EQ(new_netdef.external_input(i), net.external_input(i));
  }
}
// Verifies that a NetDef -> NNModule -> NetDef round trip preserves the
// net's external outputs, element by element and in order.
TEST(Converter, ExternalOutputs) {
  auto net = fakeNet();
  auto nn = caffe2::convertToNNModule(net);
  auto new_netdef = caffe2::convertToCaffe2Proto(nn);
  // Use ASSERT, not EXPECT: the loop below indexes new_netdef with indices
  // that are valid for `net`. On a size mismatch the test must stop here
  // instead of reading out of range.
  ASSERT_EQ(new_netdef.external_output().size(), net.external_output().size());
  for (auto i = 0; i < net.external_output().size(); ++i) {
    EXPECT_EQ(new_netdef.external_output(i), net.external_output(i));
  }
}
// Checks that injectDataEdgeIndicators adds "Declare"/"Export" ops and empties
// the external input/output lists. Then checks that, after a round trip
// through nomnigraph, removeDataEdgeIndicators strips those ops again and
// restores the external input/output counts.
TEST(Converter, InjectDataEdgeIndicators) {
  auto net = fakeNet();
  caffe2::injectDataEdgeIndicators(&net);

  // fakeNet() has 3 ops, 1 external input and 2 external outputs. Expect 1
  // Declare and 2 Export ops to have been inserted.
  EXPECT_EQ(net.op_size(), 3 + 1 + 2);
  int numDeclare = 0;
  int numExport = 0;
  for (const auto& op : net.op()) {
    if (op.type() == "Declare") {
      ++numDeclare;
    } else if (op.type() == "Export") {
      ++numExport;
    }
  }
  EXPECT_EQ(numDeclare, 1);
  EXPECT_EQ(numExport, 2);

  // After injection, the external input/output lists are expected to be
  // empty.
  EXPECT_EQ(net.external_input_size(), 0);
  EXPECT_EQ(net.external_output_size(), 0);

  // Ensure nomnigraph can handle the injected ops, then strip them again.
  auto nnModule = caffe2::convertToNNModule(net);
  auto restored = caffe2::convertToCaffe2Proto(nnModule);
  caffe2::removeDataEdgeIndicators(&restored);
  for (const auto& op : restored.op()) {
    EXPECT_NE(op.type(), "Declare");
    EXPECT_NE(op.type(), "Export");
  }
  EXPECT_EQ(restored.external_input_size(), 1);
  EXPECT_EQ(restored.external_output_size(), 2);
}