#ifndef CAFFE2_FB_OPERATORS_CC_BMM_BG_H_
#define CAFFE2_FB_OPERATORS_CC_BMM_BG_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"
#include "c10/util/irange.h"

namespace caffe2 {

using T = float;
using TInd = int;
using Engine = DefaultEngine;
template <class Context>
|
|
class ConcatBatchMatMulBatchGatherOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
|
|
ConcatBatchMatMulBatchGatherOp(const OperatorDef& operator_def, Workspace* ws)
|
|
: Operator<Context>(operator_def, ws) {}
|
|
|
|
bool RunOnDevice() override;
|
|
|
|
protected:
|
|
int axis_ = 1;
|
|
int add_axis_ = 1;
|
|
|
|
bool trans_a_ = 0;
|
|
bool trans_b_ = 1;
|
|
bool broadcast_ = 0;
|
|
};
template <class Context>
|
|
bool ConcatBatchMatMulBatchGatherOp<Context>::RunOnDevice() {
|
|
auto& indices = Input(0);
|
|
auto& input_zero = Input(1);
|
|
int adj_size = input_zero.dim() + 1;
|
|
int canonical_axis = 1;
|
|
CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
|
|
for (const auto i : c10::irange(2, InputSize())) {
|
|
CAFFE_ENFORCE(
|
|
Input(i).dtype() == input_zero.dtype(),
|
|
"All inputs must have the same type, expected: ",
|
|
input_zero.dtype().name(),
|
|
" but got: ",
|
|
Input(i).dtype().name(),
|
|
" for input: ",
|
|
i);
|
|
}
|
|
|
|
int before = 1, after = 1;
|
|
for (const auto i : c10::irange(input_zero.dim())) {
|
|
int dim = input_zero.dim32(i);
|
|
if (i < canonical_axis) {
|
|
before *= dim;
|
|
} else { // i > canonical_axis || i == canonical_axis && add_axis_
|
|
after *= dim;
|
|
}
|
|
// check the input dims are compatible.
|
|
for (const auto j : c10::irange(2, InputSize())) {
|
|
int dim_j = Input(j).dim32(i);
|
|
CAFFE_ENFORCE(
|
|
dim == dim_j,
|
|
"Expect dimension = ",
|
|
dim,
|
|
" got ",
|
|
dim_j,
|
|
" at axis = ",
|
|
i,
|
|
" for input: ",
|
|
j,
|
|
". The input tensors can only have different dimensions "
|
|
"when arg 'add_axis' = 0 and along the axis = ",
|
|
canonical_axis,
|
|
" <",
|
|
input_zero.sizes(),
|
|
"> vs <",
|
|
Input(j).sizes(),
|
|
">.");
|
|
}
|
|
}
|
|
|
|
auto ndata = InputSize() - 1;
|
|
auto batch_size = before;
|
|
auto embed_size = after;
|
|
auto gather_size = indices.sizes()[0];
|
|
|
|
vector<int64_t> output_dims;
|
|
output_dims.push_back(batch_size);
|
|
output_dims.insert(
|
|
output_dims.begin() + 1, indices.sizes().begin(), indices.sizes().end());
|
|
auto* output = Output(0, output_dims, at::dtype<T>());
|
|
// std::stringstream ss;
|
|
// ss << "[";
|
|
// for (const auto i : c10::irange(output_dims.size()))ss << output_dims[i];
|
|
// ss << "]";
|
|
// LOG(INFO) << "output size: " << ss.str();
|
|
|
|
auto* output_data = output->template mutable_data<T>();
|
|
auto* indices_data = indices.template data<TInd>();
|
|
#pragma omp parallel
|
|
{
|
|
std::vector<T> scratch_input(ndata * embed_size);
|
|
std::vector<T> scratch_output(ndata * ndata);
|
|
|
|
#pragma omp for
|
|
for (int b = 0; b < batch_size; ++b) {
|
|
// concat input to scratch
|
|
for (const auto i : c10::irange(1, InputSize())) {
|
|
auto* input_data = Input(i).template data<T>();
|
|
memcpy(
|
|
&scratch_input[(i - 1) * embed_size],
|
|
input_data + b * embed_size,
|
|
embed_size * Input(i).itemsize());
|
|
}
|
|
// call mkl gemm
|
|
math::Gemm<T, Context, Engine>(
|
|
CblasNoTrans,
|
|
CblasTrans,
|
|
ndata,
|
|
ndata,
|
|
embed_size,
|
|
1,
|
|
&scratch_input[0],
|
|
&scratch_input[0],
|
|
0,
|
|
&scratch_output[0],
|
|
&context_);
|
|
// do gather
|
|
|
|
int64_t output_offset = b * gather_size;
|
|
for (const auto i : c10::irange(gather_size)) {
|
|
output_data[output_offset + i] = scratch_output[indices_data[i]];
|
|
}
|
|
}
|
|
}
|
|
return true;
|
|
}

} // namespace caffe2

#endif // CAFFE2_FB_OPERATORS_CC_BMM_BG_H_