// pytorch/caffe2/video/video_input_op.h
#ifndef CAFFE2_VIDEO_VIDEO_INPUT_OP_H_
#define CAFFE2_VIDEO_VIDEO_INPUT_OP_H_
#include <exception>
#include <istream>
#include <ostream>
#include <random>
#include <string>
#include <c10/core/thread_pool.h>
#include <c10/util/irange.h>
#include <caffe2/core/db.h>
#include <caffe2/core/logging.h>
#include <caffe2/operators/prefetch_op.h>
#include <caffe2/utils/math.h>
#include <caffe2/video/video_decoder.h>
#include <caffe2/video/video_io.h>
namespace caffe2 {
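// VideoInputOp reads (key, value) pairs from a DBReader, decodes each value
// into one or more video clips (or a single image when image_as_input is set),
// applies scaling, cropping, optional mirroring and per-channel normalization,
// and exposes the results as prefetched output tensors. Depending on the
// get_rgb / get_optical_flow / get_video_id / get_start_frame arguments, the
// outputs are (in order): an RGB clip tensor, an optical flow clip tensor, a
// label tensor, a video id tensor, and a start frame tensor.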
template <class Context>
class VideoInputOp final : public PrefetchOperator<Context> {
public:
using OperatorBase::OutputSize;
using PrefetchOperator<Context>::context_;
using PrefetchOperator<Context>::prefetch_thread_;
explicit VideoInputOp(const OperatorDef& operator_def, Workspace* ws);
~VideoInputOp() {
PrefetchOperator<Context>::Finalize();
}
// override methods
bool Prefetch() override;
bool CopyPrefetched() override;
private:
void CheckParamsAndPrint();
bool GetClipsAndLabelsFromDBValue(
const std::string& value,
int& height,
int& width,
std::vector<unsigned char*>& buffer_rgb,
int* label_data,
int64_t* video_id_data,
int* start_frame_data,
std::mt19937* randgen);
void DecodeAndTransform(
const std::string& value,
float* clip_rgb_data,
float* clip_of_data,
int* label_data,
int64_t* video_id_data,
int* start_frame_data,
std::mt19937* randgen,
std::bernoulli_distribution* mirror_this_clip);
void GetLabelsFromProto(const TensorProto& label_proto, int* label_data);
bool GetImageAndLabelsFromDBValue(
const std::string& value,
int& height,
int& width,
std::vector<unsigned char*>& buffer_rgb,
int* label_data);
const db::DBReader* reader_;
CPUContext cpu_context_;
Tensor prefetched_clip_rgb_;
Tensor prefetched_clip_of_;
Tensor prefetched_label_;
Tensor prefetched_video_id_;
Tensor prefetched_start_frame_;
Tensor prefetched_clip_rgb_on_device_{Context::GetDeviceType()};
Tensor prefetched_clip_of_on_device_{Context::GetDeviceType()};
Tensor prefetched_label_on_device_{Context::GetDeviceType()};
Tensor prefetched_video_id_on_device_{Context::GetDeviceType()};
Tensor prefetched_start_frame_on_device_{Context::GetDeviceType()};
int batch_size_;
int clip_per_video_;
std::vector<int> clip_start_positions_;
std::vector<float> mean_rgb_;
std::vector<float> inv_std_rgb_;
std::vector<float> mean_of_;
std::vector<float> inv_std_of_;
int channels_rgb_;
int channels_of_;
int crop_size_;
int scale_h_;
int scale_w_;
int short_edge_;
std::vector<int> jitter_scales_;
int length_rgb_;
int sampling_rate_rgb_;
int random_sampling_rate_;
int num_of_required_frame_;
int length_of_;
int sampling_rate_of_;
int frame_gap_of_;
bool random_mirror_;
int num_of_class_;
bool use_local_file_;
bool random_crop_;
int crop_per_clip_;
int flow_data_type_;
int flow_alg_type_;
int decode_type_;
int video_res_type_;
bool do_flow_aggregation_;
bool image_as_input_;
bool get_rgb_;
bool get_optical_flow_;
bool get_video_id_;
bool get_start_frame_;
bool do_multi_label_;
// thread pool for parse + decode
int num_decode_threads_;
std::shared_ptr<TaskThreadPool> thread_pool_;
};
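// For CPU execution this operator is registered as "VideoInput" (see the
// accompanying video_input_op.cc). The commented-out snippet below is a rough
// construction sketch using caffe2's proto utilities; the argument values and
// blob names ("db_reader", "rgb_clips", "labels") are illustrative
// placeholders, not part of this header:
//
//   OperatorDef def = CreateOperatorDef(
//       "VideoInput",
//       "",
//       std::vector<string>{"db_reader"},  // input 0: a db::DBReader blob
//       std::vector<string>{"rgb_clips", "labels"},
//       std::vector<Argument>{
//           MakeArgument<int>("batch_size", 4),
//           MakeArgument<int>("crop_size", 112),
//           MakeArgument<int>("length_rgb", 8),
//           MakeArgument<int>("scale_h", 128),
//           MakeArgument<int>("scale_w", 171)});
//   auto op = CreateOperator(def, &workspace);  // workspace: caffe2::Workspace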
template <class Context>
void VideoInputOp<Context>::CheckParamsAndPrint() {
// check whether the input parameters are valid or not
CAFFE_ENFORCE_GT(batch_size_, 0, "Batch size should be positive.");
CAFFE_ENFORCE_GT(
clip_per_video_, 0, "Number of clips per video should be positive.");
CAFFE_ENFORCE_GT(crop_size_, 0, "Must provide the cropping value.");
if (!image_as_input_) {
CAFFE_ENFORCE_GT(
num_of_required_frame_,
0,
"Required number of frames must be positive.");
}
if (image_as_input_) {
CAFFE_ENFORCE_EQ(
video_res_type_,
VideoResType::USE_WIDTH_HEIGHT,
"Currently only USE_WIDTH_HEIGHT option is supported with images");
}
if (video_res_type_ == VideoResType::USE_SHORT_EDGE) {
CAFFE_ENFORCE_GT(short_edge_, 0, "Must provide the short edge value.");
CAFFE_ENFORCE_GE(
short_edge_,
crop_size_,
"The short edge must be no smaller than the crop value.");
} else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
CAFFE_ENFORCE_GT(scale_h_, 0, "Must provide the scale height value.");
CAFFE_ENFORCE_GT(scale_w_, 0, "Must provide the scale width value.");
CAFFE_ENFORCE_GE(
scale_h_,
crop_size_,
"The scaled height must be no smaller than the crop value.");
CAFFE_ENFORCE_GE(
scale_w_,
crop_size_,
"The scaled width must be no smaller than the crop value.");
}
if (jitter_scales_.size() > 0) {
CAFFE_ENFORCE_GE(
video_res_type_,
VideoResType::USE_SHORT_EDGE,
"Scale jittering is used with short_edge scaling only");
}
if (get_rgb_) {
CAFFE_ENFORCE_GT(length_rgb_, 0, "Must provide rgb clip length.");
CAFFE_ENFORCE_GT(
sampling_rate_rgb_, 0, "RGB sampling rate must be positive (e.g., 4 for mc2; 2 for res3d).");
CAFFE_ENFORCE_EQ(
channels_rgb_, mean_rgb_.size(), "Number of rgb channels is wrong!");
CAFFE_ENFORCE_EQ(
channels_rgb_, inv_std_rgb_.size(), "Number of rgb channels is wrong!");
}
if (get_optical_flow_) {
CAFFE_ENFORCE_GT(length_of_, 0, "Must provide optical flow clip length.");
CAFFE_ENFORCE_GT(
sampling_rate_of_, 0, "Optical flow sampling rate must be positive (e.g., 4 for mc2; 2 for res3d).");
CAFFE_ENFORCE_EQ(
channels_of_,
mean_of_.size(),
"Number of optical flow channels is wrong!");
CAFFE_ENFORCE_EQ(
channels_of_,
inv_std_of_.size(),
"Number of optical flow channels is wrong!");
}
if (clip_per_video_ > 1) {
CAFFE_ENFORCE_EQ(
decode_type_,
DecodeType::DO_UNIFORM_SMP,
"Only uniformly sampling is supported when sampling multiple clips!");
}
if (do_multi_label_) {
CAFFE_ENFORCE_GT(
num_of_class_,
0,
"Number of classes must be set when using multiple labels.");
}
// print out the parameter settings
LOG(INFO) << "Creating a clip input op with the following setting: ";
LOG(INFO) << " Input Type: " << (image_as_input_ ? "Image" : "Video");
LOG(INFO) << " Using " << num_decode_threads_ << " CPU threads;";
LOG(INFO) << " Outputting in batches of " << batch_size_ << " videos;";
LOG(INFO) << " Each video has " << clip_per_video_ << " clips;";
LOG(INFO) << " Scaling image to " << scale_h_ << "x" << scale_w_;
LOG(INFO) << " Cropping video frame to " << crop_size_
<< (random_mirror_ ? " with " : " without ") << "random mirroring;";
LOG(INFO) << " Using " << (random_crop_ ? "random" : "center") << " crop";
LOG(INFO) << " Using " << crop_per_clip_ << " spatial crop(s)";
if (get_rgb_) {
LOG(INFO) << " Using a clip of " << length_rgb_ << " rgb frames "
<< "with " << channels_rgb_ << " channels "
<< "and a sampling rate of 1:" << sampling_rate_rgb_;
if (random_sampling_rate_) {
LOG(INFO) << "random sampling with max:" << random_sampling_rate_;
}
for (const auto i : c10::irange(channels_rgb_)) {
LOG(INFO) << " RGB " << i << "-th channel mean: " << mean_rgb_[i]
<< " std: " << 1.f / inv_std_rgb_[i];
}
}
if (get_optical_flow_) {
LOG(INFO) << " Using a clip of " << length_of_ << " optical flow frames "
<< "with " << channels_of_ << " channels "
<< "and a sampling rate of 1:" << sampling_rate_of_
<< " flow_data_type_: " << flow_data_type_
<< " flow_alg_type_: " << flow_alg_type_;
for (const auto i : c10::irange(channels_of_)) {
LOG(INFO) << " Optical flow" << i
<< "-th channel mean: " << mean_of_[i]
<< " std: " << 1.f / inv_std_of_[i];
}
}
if (video_res_type_ == VideoResType::ORIGINAL_RES) {
LOG(INFO) << " Use original resolution";
} else if (video_res_type_ == VideoResType::USE_SHORT_EDGE) {
LOG(INFO) << " Resize and keep aspect ratio";
} else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
LOG(INFO) << " Resize and ignore aspect ratio";
} else {
LOG(ERROR) << " Unknown video resolution type";
}
if (video_res_type_ == VideoResType::USE_SHORT_EDGE) {
if (jitter_scales_.size() > 0) {
LOG(INFO) << "Using scale jittering:";
for (const auto idx : c10::irange(jitter_scales_.size())) {
LOG(INFO) << "scale " << idx << ": " << jitter_scales_[idx];
}
} else {
LOG(INFO) << "No scale jittering is used.";
}
}
if (decode_type_ == DecodeType::DO_TMP_JITTER) {
LOG(INFO) << " Do temporal jittering";
} else if (decode_type_ == DecodeType::USE_START_FRM) {
LOG(INFO) << " Use start_frm for decoding";
} else if (decode_type_ == DecodeType::DO_UNIFORM_SMP) {
LOG(INFO) << " Do uniformly sampling";
} else {
LOG(ERROR) << " Unknown video decoding type";
}
if (get_start_frame_) {
CAFFE_ENFORCE_EQ(
decode_type_,
DecodeType::USE_START_FRM,
"Only decoding with starting frame is supported w/ get start_frame!");
CAFFE_ENFORCE_EQ(
clip_per_video_, 1, "get start frame support only clip per video = 1");
}
}
template <class Context>
VideoInputOp<Context>::VideoInputOp(
const OperatorDef& operator_def,
Workspace* ws)
: PrefetchOperator<Context>(operator_def, ws),
reader_(nullptr),
batch_size_(
OperatorBase::template GetSingleArgument<int>("batch_size", 0)),
clip_per_video_(
OperatorBase::template GetSingleArgument<int>("clip_per_video", 1)),
clip_start_positions_(OperatorBase::template GetRepeatedArgument<int>(
"clip_start_positions",
{})),
channels_rgb_(
OperatorBase::template GetSingleArgument<int>("channels_rgb", 3)),
channels_of_(
OperatorBase::template GetSingleArgument<int>("channels_of", 2)),
crop_size_(OperatorBase::template GetSingleArgument<int>("crop_size", 0)),
scale_h_(OperatorBase::template GetSingleArgument<int>("scale_h", 0)),
scale_w_(OperatorBase::template GetSingleArgument<int>("scale_w", 0)),
short_edge_(
OperatorBase::template GetSingleArgument<int>("short_edge", 0)),
jitter_scales_(
OperatorBase::template GetRepeatedArgument<int>("jitter_scales", {})),
length_rgb_(
OperatorBase::template GetSingleArgument<int>("length_rgb", 0)),
sampling_rate_rgb_(OperatorBase::template GetSingleArgument<int>(
"sampling_rate_rgb",
1)),
random_sampling_rate_(OperatorBase::template GetSingleArgument<int>(
"random_sampling_rate",
0)),
length_of_(OperatorBase::template GetSingleArgument<int>("length_of", 0)),
sampling_rate_of_(
OperatorBase::template GetSingleArgument<int>("sampling_rate_of", 1)),
frame_gap_of_(
OperatorBase::template GetSingleArgument<int>("frame_gap_of", 1)),
random_mirror_(OperatorBase::template GetSingleArgument<bool>(
"random_mirror",
true)),
num_of_class_(
OperatorBase::template GetSingleArgument<int>("num_of_class", 0)),
use_local_file_(OperatorBase::template GetSingleArgument<bool>(
"use_local_file",
false)),
random_crop_(
OperatorBase::template GetSingleArgument<bool>("random_crop", true)),
crop_per_clip_(
OperatorBase::template GetSingleArgument<int>("crop_per_clip", 1)),
flow_data_type_(
OperatorBase::template GetSingleArgument<int>("flow_data_type", 0)),
flow_alg_type_(
OperatorBase::template GetSingleArgument<int>("flow_alg_type", 0)),
decode_type_(
OperatorBase::template GetSingleArgument<int>("decode_type", 0)),
video_res_type_(
OperatorBase::template GetSingleArgument<int>("video_res_type", 0)),
do_flow_aggregation_(OperatorBase::template GetSingleArgument<bool>(
"do_flow_aggregation",
true)),
image_as_input_(OperatorBase::template GetSingleArgument<bool>(
"image_as_input",
false)),
get_rgb_(OperatorBase::template GetSingleArgument<bool>("get_rgb", true)),
get_optical_flow_(OperatorBase::template GetSingleArgument<bool>(
"get_optical_flow",
false)),
get_video_id_(OperatorBase::template GetSingleArgument<bool>(
"get_video_id",
false)),
get_start_frame_(OperatorBase::template GetSingleArgument<bool>(
"get_start_frame",
false)),
do_multi_label_(OperatorBase::template GetSingleArgument<bool>(
"do_multi_label",
false)),
num_decode_threads_(OperatorBase::template GetSingleArgument<int>(
"num_decode_threads",
4)),
thread_pool_(std::make_shared<TaskThreadPool>(num_decode_threads_)) {
try {
num_of_required_frame_ = 0;
// mean and std for normalizing different optical flow data types;
// Example statistics generated from SOA are shown below, and you may
// want to change them if you are running on a different dataset;
// 7 channels: (flow_x, flow_y, flow_magnitude, gray, Red, Green, Blue)
const std::vector<float> InputDataMean = {
0.0046635, 0.0046261, 0.963986, 102.976, 110.201, 100.64, 95.9966};
const std::vector<float> InputDataStd = {
0.972347, 0.755146, 1.43588, 55.3691, 58.1489, 56.4701, 55.3324};
// if we need RGB as an input
if (get_rgb_ && !image_as_input_) {
// how many frames we need for RGB
num_of_required_frame_ = std::max(
num_of_required_frame_, (length_rgb_ - 1) * sampling_rate_rgb_ + 1);
if (random_sampling_rate_) {
num_of_required_frame_ = std::max(
num_of_required_frame_,
(length_rgb_ - 1) * random_sampling_rate_ + 1);
}
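// Example (illustrative numbers): with length_rgb_ = 8 and
// sampling_rate_rgb_ = 2, decoding needs (8 - 1) * 2 + 1 = 15 frames;
// with random_sampling_rate_ = 4 the requirement grows to
// (8 - 1) * 4 + 1 = 29 frames, since the rate is drawn up to that maximum.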
channels_rgb_ = 3;
for (const auto i : c10::irange(4, 7)) {
mean_rgb_.push_back(InputDataMean[i]);
inv_std_rgb_.push_back(1.f / InputDataStd[i]);
}
}
if (image_as_input_) {
channels_rgb_ = 3;
length_rgb_ = 1;
clip_per_video_ = 1;
get_optical_flow_ = false;
get_rgb_ = true;
sampling_rate_rgb_ = 1;
for (const auto i : c10::irange(4, 7)) {
mean_rgb_.push_back(InputDataMean[i]);
inv_std_rgb_.push_back(1.f / InputDataStd[i]);
}
}
// if we need optical flow as an input
if (get_optical_flow_) {
// how many frames we need for optical flow
num_of_required_frame_ = std::max(
num_of_required_frame_,
(length_of_ - 1) * sampling_rate_of_ + frame_gap_of_ + 1);
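// Example (illustrative numbers): with length_of_ = 8, sampling_rate_of_ = 1
// and frame_gap_of_ = 2, optical flow needs (8 - 1) * 1 + 2 + 1 = 10 frames;
// the extra frame_gap_of_ frames account for the gap between the two frames
// each flow field is computed from.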
// set the parameters for different input data types
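// Resulting channel layouts (indices refer to InputDataMean/InputDataStd):
//   Flow2C       -> 2 channels: (flow_x, flow_y)
//   Flow3C       -> 3 channels: (flow_x, flow_y, flow_magnitude)
//   FlowWithGray -> 3 channels: (flow_x, flow_y, gray)
//   FlowWithRGB  -> 5 channels: (flow_x, flow_y, Red, Green, Blue)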
switch (flow_data_type_) {
case FlowDataType::Flow2C:
channels_of_ = 2;
for (const auto i : c10::irange(channels_of_)) {
mean_of_.push_back(InputDataMean[i]);
inv_std_of_.push_back(1.f / InputDataStd[i]);
}
break;
case FlowDataType::Flow3C:
channels_of_ = 3;
for (const auto i : c10::irange(channels_of_)) {
mean_of_.push_back(InputDataMean[i]);
inv_std_of_.push_back(1.f / InputDataStd[i]);
}
break;
// early fusion with gray
case FlowDataType::FlowWithGray:
channels_of_ = 3;
for (const auto i : c10::irange(2)) {
mean_of_.push_back(InputDataMean[i]);
inv_std_of_.push_back(1.f / InputDataStd[i]);
}
mean_of_.push_back(InputDataMean[3]);
inv_std_of_.push_back(1.f / InputDataStd[3]);
break;
// early fusion with RGB
case FlowDataType::FlowWithRGB:
channels_of_ = 5;
for (const auto i : c10::irange(2)) {
mean_of_.push_back(InputDataMean[i]);
inv_std_of_.push_back(1.f / InputDataStd[i]);
}
for (const auto i : c10::irange(4, 7)) {
mean_of_.push_back(InputDataMean[i]);
inv_std_of_.push_back(1.f / InputDataStd[i]);
}
break;
default:
LOG(ERROR) << "Unknown optical flow type " << flow_data_type_;
break;
}
}
CheckParamsAndPrint();
// Always need a DBReader, even when using local video files
CAFFE_ENFORCE_GT(
operator_def.input_size(), 0, "Need to have a DBReader blob input");
vector<int64_t> data_shape(5);
vector<int64_t> label_shape(2);
// If clip_start_positions are given, they override the clip_per_video arg
if (clip_start_positions_.size() > 0) {
clip_per_video_ = clip_start_positions_.size();
}
// for RGB data
data_shape[0] = batch_size_ * clip_per_video_ * crop_per_clip_;
data_shape[1] = channels_rgb_;
data_shape[2] = length_rgb_;
data_shape[3] = crop_size_;
data_shape[4] = crop_size_;
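// Example (illustrative numbers): batch_size_ = 4, clip_per_video_ = 1,
// crop_per_clip_ = 1, channels_rgb_ = 3, length_rgb_ = 8 and crop_size_ = 112
// give an RGB clip tensor of shape [4, 3, 8, 112, 112].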
ReinitializeTensor(
&prefetched_clip_rgb_, data_shape, at::dtype<float>().device(CPU));
// for optical flow data
data_shape[1] = channels_of_;
data_shape[2] = length_of_;
ReinitializeTensor(
&prefetched_clip_of_, data_shape, at::dtype<float>().device(CPU));
// If do_multi_label is used, the output label is a binary vector
// of length num_of_class indicating which labels are present
if (do_multi_label_) {
label_shape[0] = batch_size_ * clip_per_video_ * crop_per_clip_;
label_shape[1] = num_of_class_;
ReinitializeTensor(
&prefetched_label_, label_shape, at::dtype<int>().device(CPU));
} else {
ReinitializeTensor(
&prefetched_label_,
vector<int64_t>(1, batch_size_ * clip_per_video_ * crop_per_clip_),
at::dtype<int>().device(CPU));
}
ReinitializeTensor(
&prefetched_video_id_,
vector<int64_t>(1, batch_size_ * clip_per_video_ * crop_per_clip_),
at::dtype<int>().device(CPU));
ReinitializeTensor(
&prefetched_start_frame_,
vector<int64_t>(1, batch_size_ * clip_per_video_ * crop_per_clip_),
at::dtype<int>().device(CPU));
} catch (const std::exception& exc) {
std::cerr << "While calling VideoInputOp initialization\n";
std::cerr << exc.what();
}
}
template <class Context>
void VideoInputOp<Context>::GetLabelsFromProto(
const TensorProto& label_proto,
int* label_data) {
int num_clips = clip_per_video_ * crop_per_clip_;
if (!do_multi_label_) {
for (const auto i : c10::irange(num_clips)) {
label_data[i] = label_proto.int32_data(0);
}
} else {
// For the multi-label case, the output label is a binary vector
// where present concepts are marked 1
memset(label_data, 0, sizeof(int) * num_of_class_ * num_clips);
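// Example (illustrative numbers): with num_of_class_ = 5 and labels {1, 3}
// in the proto, every clip gets the row [0, 1, 0, 1, 0].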
for (const auto i : c10::irange(num_clips)) {
for (const auto j : c10::irange(label_proto.int32_data_size())) {
CAFFE_ENFORCE_LT(
label_proto.int32_data(j),
num_of_class_,
"Label should be less than the number of classes.");
label_data[i * num_of_class_ + label_proto.int32_data(j)] = 1;
}
}
}
}
template <class Context>
bool VideoInputOp<Context>::GetImageAndLabelsFromDBValue(
const std::string& value,
int& height,
int& width,
std::vector<unsigned char*>& buffer_rgb,
int* label_data) {
TensorProtos protos;
CAFFE_ENFORCE(protos.ParseFromString(value));
const TensorProto& image_proto = protos.protos(0);
const TensorProto& label_proto = protos.protos(1);
GetLabelsFromProto(label_proto, label_data);
cv::Mat src;
if (image_proto.data_type() == TensorProto::STRING) {
// encoded image string.
TORCH_DCHECK_EQ(image_proto.string_data_size(), 1);
const string& encoded_image_str = image_proto.string_data(0);
int encoded_size = encoded_image_str.size();
// We use a cv::Mat to wrap the encoded str so we do not need a copy.
src = cv::imdecode(
cv::Mat(
1,
&encoded_size,
CV_8UC1,
const_cast<char*>(encoded_image_str.data())),
cv::IMREAD_COLOR);
if (src.rows == 0 || src.cols == 0) {
throw std::runtime_error("Both rows and cols are 0 for image");
}
} else if (image_proto.data_type() == TensorProto::BYTE) {
// raw image content.
int src_c = (image_proto.dims_size() == 3) ? image_proto.dims(2) : 1;
CAFFE_ENFORCE(src_c == 3 || src_c == 1);
src.create(
image_proto.dims(0),
image_proto.dims(1),
(src_c == 3) ? CV_8UC3 : CV_8UC1);
memcpy(
src.ptr<uchar>(0),
image_proto.byte_data().data(),
image_proto.byte_data().size());
} else {
throw std::runtime_error(
"Unknown image data type: " +
caffe2::to_string(image_proto.data_type()));
}
CAFFE_ENFORCE(src.isContinuous());
cv::Mat scaled_img;
cv::resize(
src, scaled_img, cv::Size(scale_w_, scale_h_), 0, 0, cv::INTER_AREA);
cv::Mat img;
if (channels_rgb_ == src.channels()) {
img = scaled_img;
} else {
cv::cvtColor(
scaled_img,
img,
(channels_rgb_ == 1) ? cv::COLOR_BGR2GRAY : cv::COLOR_GRAY2BGR);
}
cv::Mat rgb_img;
if (channels_rgb_ == 1) {
cv::cvtColor(img, rgb_img, cv::COLOR_BGR2RGB);
} else {
rgb_img = img;
}
CAFFE_ENFORCE(rgb_img.isContinuous());
unsigned char* data = new unsigned char[scale_h_ * scale_w_ * channels_rgb_];
memcpy(
data,
rgb_img.data,
scale_h_ * scale_w_ * channels_rgb_ * sizeof(unsigned char));
buffer_rgb.push_back(data);
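// Note: the buffer pushed here is heap-allocated and owned by the caller;
// DecodeAndTransform frees every entry of buffer_rgb once the transforms
// are done.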
width = scale_w_;
height = scale_h_;
return true;
}
template <class Context>
bool VideoInputOp<Context>::GetClipsAndLabelsFromDBValue(
const std::string& value,
int& height,
int& width,
std::vector<unsigned char*>& buffer_rgb,
int* label_data,
int64_t* video_id_data,
int* start_frame_data,
std::mt19937* randgen) {
TensorProtos protos;
int curr_proto_idx = 0;
CAFFE_ENFORCE(protos.ParseFromString(value));
const TensorProto& video_proto = protos.protos(curr_proto_idx++);
const TensorProto& label_proto = protos.protos(curr_proto_idx++);
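// Expected proto layout: protos(0) holds the encoded video (or a file name
// when use_local_file_ is set), protos(1) holds the label(s); an optional
// start-frame proto follows when decode_type_ == USE_START_FRM, and an
// optional video id proto comes last when get_video_id_ is set.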
int start_frm = 0;
int num_clips = clip_per_video_ * crop_per_clip_;
// start_frm is only valid when sampling 1 clip per video without
// temporal jittering
if (decode_type_ == DecodeType::USE_START_FRM) {
CAFFE_ENFORCE_GE(
protos.protos_size(),
curr_proto_idx + 1,
"Start frm proto not provided");
const TensorProto& start_frm_proto = protos.protos(curr_proto_idx++);
start_frm = start_frm_proto.int32_data(0);
if (get_start_frame_) {
for (const auto i : c10::irange(num_clips)) {
start_frame_data[i] = start_frm;
}
}
}
if (get_video_id_) {
CAFFE_ENFORCE_GE(
protos.protos_size(), curr_proto_idx + 1, "Video Id not provided");
const TensorProto& video_id_proto = protos.protos(curr_proto_idx);
for (const auto i : c10::irange(num_clips)) {
video_id_data[i] = video_id_proto.int64_data(0);
}
}
// assign labels
GetLabelsFromProto(label_proto, label_data);
if (use_local_file_) {
CAFFE_ENFORCE_EQ(
video_proto.data_type(),
TensorProto::STRING,
"Database with a file_list is expected to be string data");
}
// initializing the decoding params
Params params;
params.maximumOutputFrames_ = MAX_DECODING_FRAMES;
params.video_res_type_ = video_res_type_;
params.crop_size_ = crop_size_;
params.short_edge_ = short_edge_;
params.outputWidth_ = scale_w_;
params.outputHeight_ = scale_h_;
params.decode_type_ = decode_type_;
params.num_of_required_frame_ = num_of_required_frame_;
if (jitter_scales_.size() > 0) {
int select_idx =
std::uniform_int_distribution<>(0, jitter_scales_.size() - 1)(*randgen);
params.short_edge_ = jitter_scales_[select_idx];
}
char* video_buffer = nullptr; // for decoding from buffer
std::string video_filename; // for decoding from file
int encoded_size = 0;
if (video_proto.data_type() == TensorProto::STRING) {
const string& encoded_video_str = video_proto.string_data(0);
if (!use_local_file_) {
encoded_size = encoded_video_str.size();
video_buffer = const_cast<char*>(encoded_video_str.data());
} else {
video_filename = encoded_video_str;
}
} else if (video_proto.data_type() == TensorProto::BYTE) {
if (!use_local_file_) {
encoded_size = video_proto.byte_data().size();
video_buffer = const_cast<char*>(video_proto.byte_data().data());
} else {
// TODO: does this work?
video_filename = video_proto.string_data(0);
}
} else {
CAFFE_ENFORCE(false, "Unknown video data type.");
}
DecodeMultipleClipsFromVideo(
video_buffer,
video_filename,
encoded_size,
params,
start_frm,
clip_per_video_,
clip_start_positions_,
use_local_file_,
height,
width,
buffer_rgb);
return true;
}
template <class Context>
void VideoInputOp<Context>::DecodeAndTransform(
const std::string& value,
float* clip_rgb_data,
float* clip_of_data,
int* label_data,
int64_t* video_id_data,
int* start_frame_data,
std::mt19937* randgen,
std::bernoulli_distribution* mirror_this_clip) {
try {
std::vector<unsigned char*> buffer_rgb;
// get the video resolution after decoding
int height = 0;
int width = 0;
if (image_as_input_) {
CHECK(GetImageAndLabelsFromDBValue(
value, height, width, buffer_rgb, label_data));
} else {
// Decode the video from memory or read from a local file
CHECK(GetClipsAndLabelsFromDBValue(
value,
height,
width,
buffer_rgb,
label_data,
video_id_data,
start_frame_data,
randgen));
}
int clip_offset_rgb = channels_rgb_ * length_rgb_ * crop_size_ * crop_size_;
int clip_offset_of = channels_of_ * length_of_ * crop_size_ * crop_size_;
for (int i = 0; i < std::min(clip_per_video_, int(buffer_rgb.size()));
i++) {
for (const auto j : c10::irange(crop_per_clip_)) {
// get the rectangle for cropping
int h_off = 0;
int w_off = 0;
if (crop_per_clip_ > 1) {
CAFFE_ENFORCE(
random_crop_ == false,
"Only using multiple crops w/o random cropping");
}
if (random_crop_) {
// using random crop for training
h_off =
std::uniform_int_distribution<>(0, height - crop_size_)(*randgen);
w_off =
std::uniform_int_distribution<>(0, width - crop_size_)(*randgen);
} else {
// using multiple spatial crops
if (crop_per_clip_ > 1) { // normally 3 crops
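// The crops are spread evenly along the longer side: e.g. for a landscape
// frame (height < width) with crop_per_clip_ = 3, w_off takes 0,
// (width - crop_size_) / 2, and width - crop_size_ for j = 0, 1, 2
// (left, center, right) while h_off stays vertically centered; portrait
// frames are handled symmetrically.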
if (height < width) {
h_off = (height - crop_size_) / 2;
w_off = j * (width - crop_size_) / (crop_per_clip_ - 1);
} else {
h_off = j * (height - crop_size_) / (crop_per_clip_ - 1);
w_off = (width - crop_size_) / 2;
}
// LOG(INFO) << "crop " << j << "-th " << h_off << " & " << w_off;
} else { // using center crop for testing
h_off = (height - crop_size_) / 2;
w_off = (width - crop_size_) / 2;
}
}
cv::Rect rect(w_off, h_off, crop_size_, crop_size_);
int this_clip_sampling_rate = sampling_rate_rgb_; // overridden below when random sampling is enabled
if (random_sampling_rate_) {
this_clip_sampling_rate = std::uniform_int_distribution<>(
1, random_sampling_rate_)(*randgen);
}
// randomly mirror the image or not
bool mirror_me = random_mirror_ && (*mirror_this_clip)(*randgen);
if (get_rgb_ && clip_rgb_data) {
ClipTransformRGB(
buffer_rgb[i],
crop_size_,
length_rgb_,
channels_rgb_,
(random_sampling_rate_ == 0 ? sampling_rate_rgb_
: this_clip_sampling_rate),
height,
width,
h_off,
w_off,
mirror_me,
mean_rgb_,
inv_std_rgb_,
clip_rgb_data + ((i * crop_per_clip_ + j) * clip_offset_rgb));
}
if (get_optical_flow_ && clip_of_data) {
ClipTransformOpticalFlow(
buffer_rgb[i],
crop_size_,
length_of_,
channels_of_,
sampling_rate_of_,
height,
width,
rect,
channels_rgb_,
mirror_me,
flow_alg_type_,
flow_data_type_,
frame_gap_of_,
do_flow_aggregation_,
mean_of_,
inv_std_of_,
clip_of_data + ((i * crop_per_clip_ + j) * clip_offset_of));
}
}
}
if (buffer_rgb.size() > 0) {
for (const auto i : c10::irange(buffer_rgb.size())) {
unsigned char* buff = buffer_rgb[i];
delete[] buff;
}
}
buffer_rgb.clear();
} catch (const std::exception& exc) {
std::cerr << "While calling DecodeAndTransform()\n";
std::cerr << exc.what();
}
}
template <class Context>
bool VideoInputOp<Context>::Prefetch() {
try {
// We will get the reader pointer from input.
// If we use local clips, db will store the list
reader_ = &OperatorBase::Input<db::DBReader>(0);
// Call mutable_data() once to allocate the underlying memory.
prefetched_clip_rgb_.mutable_data<float>();
prefetched_clip_of_.mutable_data<float>();
prefetched_label_.mutable_data<int>();
prefetched_video_id_.mutable_data<int64_t>();
prefetched_start_frame_.mutable_data<int>();
// Prefetching is handled by a thread pool of "num_decode_threads" threads.
std::mt19937 meta_randgen(time(nullptr));
std::vector<std::mt19937> randgen_per_thread;
for (const auto i : c10::irange(num_decode_threads_)) {
randgen_per_thread.emplace_back(meta_randgen());
}
std::bernoulli_distribution mirror_this_clip(0.5);
for (const auto item_id : c10::irange(batch_size_)) {
std::mt19937* randgen =
&randgen_per_thread[item_id % num_decode_threads_];
int frame_size = crop_size_ * crop_size_;
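// prefetched_clip_rgb_ is laid out as
// [batch_size_ * clip_per_video_ * crop_per_clip_, channels, length, crop, crop],
// so each item_id advances the pointer below by clip_per_video_ * crop_per_clip_
// whole clips; DecodeAndTransform then fills the per-clip/per-crop slots.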
// get the clip data pointer for the item_id -th example
float* clip_rgb_data = prefetched_clip_rgb_.mutable_data<float>() +
frame_size * length_rgb_ * channels_rgb_ * item_id * clip_per_video_ *
crop_per_clip_;
// get the optical flow data for the current clip
float* clip_of_data = prefetched_clip_of_.mutable_data<float>() +
frame_size * length_of_ * channels_of_ * item_id * clip_per_video_ *
crop_per_clip_;
// get the label data pointer for the item_id -th example
int* label_data = prefetched_label_.mutable_data<int>() +
(do_multi_label_ ? num_of_class_ : 1) * item_id * clip_per_video_ *
crop_per_clip_;
// get the video id data pointer for the item_id -th example
int64_t* video_id_data = prefetched_video_id_.mutable_data<int64_t>() +
item_id * clip_per_video_ * crop_per_clip_;
int* start_frame_data = prefetched_start_frame_.mutable_data<int>() +
item_id * clip_per_video_ * crop_per_clip_;
std::string key, value;
// read data
reader_->Read(&key, &value);
thread_pool_->run(std::bind(
&VideoInputOp<Context>::DecodeAndTransform,
this,
std::string(value),
clip_rgb_data,
clip_of_data,
label_data,
video_id_data,
start_frame_data,
randgen,
&mirror_this_clip));
} // for over the batch
thread_pool_->waitWorkComplete();
// If the context is not CPUContext, we will need to do a copy in the
// prefetch function as well.
if (!std::is_same<Context, CPUContext>::value) {
if (get_rgb_) {
prefetched_clip_rgb_on_device_.CopyFrom(
prefetched_clip_rgb_, &context_);
}
if (get_optical_flow_) {
prefetched_clip_of_on_device_.CopyFrom(prefetched_clip_of_, &context_);
}
prefetched_label_on_device_.CopyFrom(prefetched_label_, &context_);
if (get_video_id_) {
prefetched_video_id_on_device_.CopyFrom(
prefetched_video_id_, &context_);
}
if (get_start_frame_) {
prefetched_start_frame_on_device_.CopyFrom(
prefetched_start_frame_, &context_);
}
}
} catch (const std::exception& exc) {
std::cerr << "While calling Prefetch()\n";
std::cerr << exc.what();
}
return true;
}
template <class Context>
bool VideoInputOp<Context>::CopyPrefetched() {
try {
int index = 0;
auto type = Context::GetDeviceType();
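// Outputs are emitted in a fixed order, skipping the ones that are turned
// off: clip_rgb (get_rgb_), clip_of (get_optical_flow_), label (always),
// video_id (get_video_id_), start_frame (get_start_frame_).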
if (get_rgb_) {
auto* clip_rgb_output = OperatorBase::Output<Tensor>(index++, type);
if (std::is_same<Context, CPUContext>::value) {
clip_rgb_output->CopyFrom(prefetched_clip_rgb_, &context_);
} else {
clip_rgb_output->CopyFrom(prefetched_clip_rgb_on_device_, &context_);
}
}
if (get_optical_flow_) {
auto* clip_of_output = OperatorBase::Output<Tensor>(index++, type);
if (std::is_same<Context, CPUContext>::value) {
clip_of_output->CopyFrom(prefetched_clip_of_, &context_);
} else {
clip_of_output->CopyFrom(prefetched_clip_of_on_device_, &context_);
}
}
auto* label_output = OperatorBase::Output<Tensor>(index++, type);
if (std::is_same<Context, CPUContext>::value) {
label_output->CopyFrom(prefetched_label_, &context_);
} else {
label_output->CopyFrom(prefetched_label_on_device_, &context_);
}
if (get_video_id_) {
auto* video_id_output = OperatorBase::Output<Tensor>(index++, type);
if (std::is_same<Context, CPUContext>::value) {
video_id_output->CopyFrom(prefetched_video_id_, &context_);
} else {
video_id_output->CopyFrom(prefetched_video_id_on_device_, &context_);
}
}
if (get_start_frame_) {
auto* start_frame_output = OperatorBase::Output<Tensor>(index, type);
if (std::is_same<Context, CPUContext>::value) {
start_frame_output->CopyFrom(prefetched_start_frame_, &context_);
} else {
start_frame_output->CopyFrom(
prefetched_start_frame_on_device_, &context_);
}
}
} catch (const std::exception& exc) {
std::cerr << "While calling CopyPrefetched()\n";
std::cerr << exc.what();
}
return true;
}
} // namespace caffe2
#endif // CAFFE2_VIDEO_VIDEO_INPUT_OP_H_