120 lines
3.0 KiB
C++
120 lines
3.0 KiB
C++
#include <unordered_set>
|
|
|
|
#include "caffe2/core/db.h"
|
|
#include "caffe2/core/logging.h"
|
|
#include "caffe2/utils/proto_utils.h"
|
|
|
|
namespace caffe2 {
|
|
namespace db {
|
|
|
|
class ProtoDBCursor : public Cursor {
|
|
public:
|
|
explicit ProtoDBCursor(const TensorProtos* proto) : proto_(proto), iter_(0) {}
|
|
// NOLINTNEXTLINE(modernize-use-equals-default)
|
|
~ProtoDBCursor() override {}
|
|
|
|
void Seek(const string& /*str*/) override {
|
|
CAFFE_THROW("ProtoDB is not designed to support seeking.");
|
|
}
|
|
|
|
void SeekToFirst() override {
|
|
iter_ = 0;
|
|
}
|
|
void Next() override {
|
|
++iter_;
|
|
}
|
|
string key() override {
|
|
return proto_->protos(iter_).name();
|
|
}
|
|
string value() override {
|
|
return SerializeAsString_EnforceCheck(
|
|
proto_->protos(iter_), "ProtoDBCursor");
|
|
}
|
|
bool Valid() override {
|
|
return iter_ < proto_->protos_size();
|
|
}
|
|
|
|
private:
|
|
const TensorProtos* proto_;
|
|
int iter_;
|
|
};
|
|
|
|
class ProtoDBTransaction : public Transaction {
|
|
public:
|
|
explicit ProtoDBTransaction(TensorProtos* proto)
|
|
: proto_(proto), existing_names_() {
|
|
for (const auto& tensor : proto_->protos()) {
|
|
existing_names_.insert(tensor.name());
|
|
}
|
|
}
|
|
~ProtoDBTransaction() override {
|
|
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
|
|
Commit();
|
|
}
|
|
void Put(const string& key, string&& value) override {
|
|
if (existing_names_.count(key)) {
|
|
CAFFE_THROW("An item with key ", key, " already exists.");
|
|
}
|
|
auto* tensor = proto_->add_protos();
|
|
CAFFE_ENFORCE(
|
|
tensor->ParseFromString(value),
|
|
"Cannot parse content from the value string.");
|
|
CAFFE_ENFORCE(
|
|
tensor->name() == key,
|
|
"Passed in key ",
|
|
key,
|
|
" does not equal to the tensor name ",
|
|
tensor->name());
|
|
}
|
|
// Commit does nothing. The protocol buffer will be written at destruction
|
|
// of ProtoDB.
|
|
void Commit() override {}
|
|
|
|
private:
|
|
TensorProtos* proto_;
|
|
std::unordered_set<string> existing_names_;
|
|
|
|
C10_DISABLE_COPY_AND_ASSIGN(ProtoDBTransaction);
|
|
};
|
|
|
|
class ProtoDB : public DB {
|
|
public:
|
|
ProtoDB(const string& source, Mode mode)
|
|
: DB(source, mode), proto_(), source_(source) {
|
|
if (mode == READ || mode == WRITE) {
|
|
// Read the current protobuffer.
|
|
CAFFE_ENFORCE(
|
|
ReadProtoFromFile(source, &proto_), "Cannot read protobuffer.");
|
|
}
|
|
LOG(INFO) << "Opened protodb " << source;
|
|
}
|
|
~ProtoDB() override {
|
|
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
|
|
Close();
|
|
}
|
|
|
|
void Close() override {
|
|
if (mode_ == NEW || mode_ == WRITE) {
|
|
WriteProtoToBinaryFile(proto_, source_);
|
|
}
|
|
}
|
|
|
|
unique_ptr<Cursor> NewCursor() override {
|
|
return make_unique<ProtoDBCursor>(&proto_);
|
|
}
|
|
unique_ptr<Transaction> NewTransaction() override {
|
|
return make_unique<ProtoDBTransaction>(&proto_);
|
|
}
|
|
|
|
private:
|
|
TensorProtos proto_;
|
|
string source_;
|
|
};
|
|
|
|
// Register the ProtoDB implementation under the name "ProtoDB".
REGISTER_CAFFE2_DB(ProtoDB, ProtoDB);
// Also register it under the lower-case alias "protodb" so either spelling
// can be used to select this DB type.
REGISTER_CAFFE2_DB(protodb, ProtoDB);
|
|
|
|
} // namespace db
|
|
} // namespace caffe2
|