116 lines
3.4 KiB
C++
116 lines
3.4 KiB
C++
#include "caffe2/core/test_utils.h"
|
|
#include "caffe2/opt/converter.h"
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
// Smoke test: builds a net of ten ops, each randomly chosen to be a Conv or a
// Relu that reads and writes blob "X", then converts it to an NNModule and
// back to a NetDef. The test contains no assertions; it only exercises the
// two conversion calls.
TEST(Converter, Basic) {
  using namespace caffe2::testing;

  constexpr int kNumOps = 10;
  caffe2::NetDef net;
  for (int idx = 0; idx < kNumOps; ++idx) {
    // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
    const bool pick_conv = (rand() % 2) != 0;
    if (!pick_conv) {
      NetMutator relu_builder(&net);
      relu_builder.newOp("Relu", {"X"}, {"X"});
      relu_builder.setDeviceOptionName("relu_runner");
      continue;
    }

    // Each Conv gets its own weight blob, named after the loop index.
    const std::string weight_blob = "W" + c10::to_string(idx);
    NetMutator conv_builder(&net);
    conv_builder.newOp("Conv", {"X", weight_blob}, {"X"});
    conv_builder.addArgument("kernel", 3);
    conv_builder.addArgument("stride", 1);
    conv_builder.addArgument("pad", 0);
    conv_builder.addArgument("order", std::string("NCHW"));
    conv_builder.setDeviceOptionName("conv_runner");
  }

  // Round trip: NetDef -> NNModule -> NetDef.
  auto module = caffe2::convertToNNModule(net);
  auto round_tripped = caffe2::convertToCaffe2Proto(module);
}
|
|
|
|
// Smoke test: round-trips a net whose only op has the type "NeverSeen"
// through the NNModule representation. The test contains no assertions; it
// only exercises the two conversion calls.
TEST(Converter, UnknownType) {
  using namespace caffe2::testing;

  caffe2::NetDef net;
  NetMutator mutator(&net);
  mutator.newOp("NeverSeen", {"X"}, {"X"});

  // The device option name is randomly either "device_0" or "device_1".
  const std::string device_name = "device_" + c10::to_string(rand() % 2);
  mutator.setDeviceOptionName(device_name);

  // Round trip: NetDef -> NNModule -> NetDef.
  auto module = caffe2::convertToNNModule(net);
  auto round_tripped = caffe2::convertToCaffe2Proto(module);
}
|
|
|
|
// Checks that the device option node name ("abc") set on a "Slice" op is
// still present on the first op after a NetDef -> NNModule -> NetDef round
// trip.
TEST(Converter, SpecializeConverter) {
  using namespace caffe2::testing;

  caffe2::NetDef net;
  NetMutator(&net).newOp("Slice", {"X"}, {"X"}).setDeviceOptionName("abc");

  // Fatal check: op(0) below must not be read from an empty op list.
  ASSERT_EQ(net.op_size(), 1);
  EXPECT_EQ(net.op(0).device_option().node_name(), "abc");

  auto nn = caffe2::convertToNNModule(net);
  auto new_netdef = caffe2::convertToCaffe2Proto(nn);

  // Fatal check: if the round trip produced no ops, fail here instead of
  // indexing out of range on the next line.
  ASSERT_GT(new_netdef.op_size(), 0);
  EXPECT_EQ(new_netdef.op(0).device_option().node_name(), "abc");
}
|
|
|
|
// Builds the small fixture net shared by the tests below. It contains three
// ops of type "Fake":
//   X      -> Y
//   Y      -> Z
//   Z, X   -> W
// with "X" registered as the external input and "Y" and "W" registered as
// the external outputs.
caffe2::NetDef fakeNet() {
  using namespace caffe2::testing;

  caffe2::NetDef net;
  NetMutator builder(&net);
  builder.newOp("Fake", {"X"}, {"Y"});
  builder.newOp("Fake", {"Y"}, {"Z"});
  builder.newOp("Fake", {"Z", "X"}, {"W"});
  builder.externalInputs({"X"});
  builder.externalOutputs({"Y", "W"});
  return net;
}
|
|
|
|
// Checks that the external input list of fakeNet() is unchanged, in both
// length and order, by a NetDef -> NNModule -> NetDef round trip.
TEST(Converter, ExternalInputs) {
  auto net = fakeNet();

  auto nn = caffe2::convertToNNModule(net);
  auto new_netdef = caffe2::convertToCaffe2Proto(nn);

  // Fatal check: the loop below indexes both lists with the same index, so a
  // size mismatch must stop the test rather than read out of range.
  ASSERT_EQ(new_netdef.external_input_size(), net.external_input_size());
  for (int i = 0; i < net.external_input_size(); ++i) {
    EXPECT_EQ(new_netdef.external_input(i), net.external_input(i));
  }
}
|
|
|
|
// Checks that the external output list of fakeNet() is unchanged, in both
// length and order, by a NetDef -> NNModule -> NetDef round trip.
TEST(Converter, ExternalOutputs) {
  auto net = fakeNet();

  auto nn = caffe2::convertToNNModule(net);
  auto new_netdef = caffe2::convertToCaffe2Proto(nn);

  // Fatal check: the loop below indexes both lists with the same index, so a
  // size mismatch must stop the test rather than read out of range.
  ASSERT_EQ(new_netdef.external_output_size(), net.external_output_size());
  for (int i = 0; i < net.external_output_size(); ++i) {
    EXPECT_EQ(new_netdef.external_output(i), net.external_output(i));
  }
}
|
|
|
|
// Exercises injectDataEdgeIndicators / removeDataEdgeIndicators on fakeNet():
//   1. After injection the net is expected to hold the 3 original ops plus
//      1 "Declare" and 2 "Export" ops, and its external input/output lists
//      are expected to be empty.
//   2. The injected net is round-tripped through NNModule.
//   3. After removal no "Declare"/"Export" ops may remain, and the external
//      input/output lists are expected to have 1 and 2 entries again.
TEST(Converter, InjectDataEdgeIndicators) {
  constexpr int kOriginalOps = 3;
  constexpr int kDeclareOps = 1;
  constexpr int kExportOps = 2;

  auto net = fakeNet();
  caffe2::injectDataEdgeIndicators(&net);

  // Inserted 1 Declare and 2 Export on top of the original ops.
  EXPECT_EQ(net.op_size(), kOriginalOps + kDeclareOps + kExportOps);

  int num_declare = 0;
  int num_export = 0;
  for (const auto& op : net.op()) {
    if (op.type() == "Declare") {
      ++num_declare;
    }
    if (op.type() == "Export") {
      ++num_export;
    }
  }
  EXPECT_EQ(num_declare, kDeclareOps);
  EXPECT_EQ(num_export, kExportOps);

  // After injection the net itself should no longer list any external
  // inputs or outputs.
  EXPECT_EQ(net.external_input_size(), 0);
  EXPECT_EQ(net.external_output_size(), 0);

  // Ensure nomnigraph can handle this change.
  auto module = caffe2::convertToNNModule(net);
  auto restored = caffe2::convertToCaffe2Proto(module);

  caffe2::removeDataEdgeIndicators(&restored);

  // No indicator ops may survive removal.
  for (const auto& op : restored.op()) {
    EXPECT_NE(op.type(), "Declare");
    EXPECT_NE(op.type(), "Export");
  }

  // The external input/output lists are populated again.
  EXPECT_EQ(restored.external_input_size(), 1);
  EXPECT_EQ(restored.external_output_size(), 2);
}
|