30 lines
666 B
Plaintext
30 lines
666 B
Plaintext
#include "caffe2/operators/batch_matmul_op.h"
|
|
|
|
#include "caffe2/core/context_gpu.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
template <>
|
|
bool BatchMatMulOp<CUDAContext, DefaultEngine>::RunOnDevice() {
|
|
return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
|
|
}
|
|
|
|
REGISTER_CUDA_OPERATOR(BatchMatMul, BatchMatMulOp<CUDAContext>);
|
|
|
|
|
|
#if !defined(USE_ROCM)
|
|
|
|
template <>
|
|
bool BatchMatMulOp<CUDAContext, TensorCoreEngine>::RunOnDevice() {
|
|
return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
|
|
}
|
|
|
|
REGISTER_CUDA_OPERATOR_WITH_ENGINE(
|
|
BatchMatMul,
|
|
TENSORCORE,
|
|
BatchMatMulOp<CUDAContext, TensorCoreEngine>);
|
|
|
|
#endif
|
|
|
|
} // namespace caffe2
|