pytorch/caffe2/operators/gather_op.cu

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/gather_op.h"
#include "caffe2/operators/gather_op.cuh"

namespace caffe2 {

template <>
bool GatherOp<CUDAContext>::RunOnDevice() {
  // Dispatch to DoRunWithType<Index> based on the element type of the
  // INDICES input tensor (int32_t or int64_t).
  return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
      this, OperatorBase::Input<Tensor>(INDICES, CUDA));
}

template <>
template <typename Index>
bool GatherOp<CUDAContext>::DoRunWithType() {
  // Use shared implementation with BatchGather
  return gather_helper::gather_impl_cuda<Index>(
      this, DATA, INDICES, 0, axis_, wrap_indices_, match_outer_);
}

REGISTER_CUDA_OPERATOR(Gather, GatherOp<CUDAContext>);
} // namespace caffe2
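
For reference, the sketch below shows what an axis-0 gather computes: each output row i is a copy of data[indices[i], :], with negative indices optionally wrapped around (mirroring the wrap_indices_ flag above). This is a minimal standalone CUDA illustration under those assumptions, not the shared gather_helper::gather_impl_cuda implementation used by the operator; the kernel name gatherAxis0Kernel and its argument layout are made up for this example.

// Illustrative sketch only; not the Caffe2 implementation.
#include <cuda_runtime.h>
#include <cstdio>

// out[i, :] = data[indices[i], :] for a 2-D data tensor of shape [rows, cols].
__global__ void gatherAxis0Kernel(
    const float* data,    // [rows, cols] source
    const int* indices,   // [num_indices] rows to pick
    float* out,           // [num_indices, cols] destination
    int num_indices,
    int rows,
    int cols,
    bool wrap_indices) {  // if true, a negative index idx maps to rows + idx
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // output row
  int j = blockIdx.y * blockDim.y + threadIdx.y;  // column
  if (i >= num_indices || j >= cols) {
    return;
  }
  int src = indices[i];
  if (wrap_indices && src < 0) {
    src += rows;
  }
  out[i * cols + j] = data[src * cols + j];
}

int main() {
  const int rows = 4, cols = 3, num_indices = 2;
  float h_data[rows * cols] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  int h_indices[num_indices] = {2, -1};  // -1 wraps to the last row
  float h_out[num_indices * cols];

  float *d_data, *d_out;
  int* d_indices;
  cudaMalloc((void**)&d_data, sizeof(h_data));
  cudaMalloc((void**)&d_indices, sizeof(h_indices));
  cudaMalloc((void**)&d_out, sizeof(h_out));
  cudaMemcpy(d_data, h_data, sizeof(h_data), cudaMemcpyHostToDevice);
  cudaMemcpy(d_indices, h_indices, sizeof(h_indices), cudaMemcpyHostToDevice);

  dim3 block(16, 16);
  dim3 grid((num_indices + 15) / 16, (cols + 15) / 16);
  gatherAxis0Kernel<<<grid, block>>>(
      d_data, d_indices, d_out, num_indices, rows, cols, /*wrap_indices=*/true);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

  // Expected output: row 0 = {6, 7, 8}, row 1 = {9, 10, 11}.
  for (int i = 0; i < num_indices; ++i) {
    printf("row %d: %g %g %g\n",
           i, h_out[i * cols], h_out[i * cols + 1], h_out[i * cols + 2]);
  }
  cudaFree(d_data);
  cudaFree(d_indices);
  cudaFree(d_out);
  return 0;
}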