#include "caffe2/operators/tile_op.h"

#include <array>
#include <cstdint>
#include <limits>

#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

namespace {

// Copies X into Y, repeating each inner row `tiles` times.
//
// Viewing X as a contiguous [outer, inner_size] buffer and Y as a contiguous
// [outer, tiles, inner_size] buffer, this writes Y[o][t][i] = X[o][i]. Each
// thread handles exactly one flat output index x < total_size.
//
// Launch: 1-D grid with exactly CAFFE_CUDA_NUM_THREADS threads per block. The
// global index is built from CAFFE_CUDA_NUM_THREADS rather than blockDim.x, so
// any other block size would skip or duplicate elements. No shared memory.
//
// IndexT is the integer type used for all index arithmetic. The host wrapper
// picks `int` only when every computed index (including the padding threads
// of the last block) fits in 32 bits, and std::int64_t otherwise.
template <typename T, typename IndexT>
__global__ void TileCopyCUDAKernel(
    const IndexT total_size,
    const IndexT inner_size,
    const IndexT tiles,
    const T* X,
    T* Y) {
  const IndexT x = static_cast<IndexT>(blockIdx.x) * CAFFE_CUDA_NUM_THREADS +
      static_cast<IndexT>(threadIdx.x);
  if (x < total_size) {
    const IndexT r = x / inner_size / tiles; // outer row shared by X and Y
    const IndexT c = x % inner_size; // position within the inner row
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
    // X is only read here, so load it through the read-only data cache.
    Y[x] = __ldg(X + r * inner_size + c);
#else
    Y[x] = X[r * inner_size + c];
#endif
  }
}

} // namespace

// Tiles X along the tiling axis on the GPU.
//
// X is read as a contiguous [outer_size, inner_size] buffer. Y is written as a
// contiguous [outer_size, tiles_, inner_size] buffer, so Y must already hold
// outer_size * tiles_ * inner_size elements. The copy is enqueued on
// context_'s CUDA stream; this function does not synchronize with it.
//
// Returns true. Launch failures are raised by C10_CUDA_KERNEL_LAUNCH_CHECK, and
// an output too large for a single 1-D launch is rejected by CAFFE_ENFORCE_LE.
template <>
template <typename T>
bool TileOp<CUDAContext>::DoTile(
    const int outer_size,
    const int inner_size,
    const T* X,
    T* Y) {
  // Computed in 64 bits: the output can exceed INT_MAX elements even though
  // each individual factor fits in an int.
  const std::int64_t total_size = static_cast<std::int64_t>(outer_size) *
      static_cast<std::int64_t>(tiles_) * static_cast<std::int64_t>(inner_size);
  if (total_size == 0) {
    // Nothing to copy. Launching with a zero-block grid would be an invalid
    // configuration error.
    return true;
  }
  const std::int64_t num_blocks =
      math::DivUp<std::int64_t>(total_size, CAFFE_CUDA_NUM_THREADS);
  // Reject grids whose block count would be truncated when narrowed to the
  // launch's unsigned grid dimension (a silent partial copy otherwise).
  CAFFE_ENFORCE_LE(
      num_blocks,
      static_cast<std::int64_t>(std::numeric_limits<int>::max()),
      "Tile output of ",
      total_size,
      " elements is too large for a single kernel launch.");
  const dim3 grid(static_cast<unsigned int>(num_blocks));
  // Use 32-bit index math only when even the last block's padding threads
  // cannot overflow an int. 64-bit integer division is more expensive on the
  // GPU, so the common case keeps the cheaper path.
  if (total_size <=
      static_cast<std::int64_t>(std::numeric_limits<int>::max()) -
          CAFFE_CUDA_NUM_THREADS) {
    TileCopyCUDAKernel<T, int>
        <<<grid, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
            static_cast<int>(total_size), inner_size, tiles_, X, Y);
  } else {
    TileCopyCUDAKernel<T, std::int64_t>
        <<<grid, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
            total_size, inner_size, tiles_, X, Y);
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return true;
}

// Generic gradient of Tile: sums the tiles_ copies back into dX.
//
// dY is viewed as [outer_size, tiles_, inner_size] and reduced over its middle
// axis into dX, viewed as [outer_size, 1, inner_size]. That is,
// dX[o][i] = sum over t of dY[o][t][i] (scale factor T(1)).
template <>
template <typename T>
bool TileGradientOp<CUDAContext>::DoTileGradient(
    const int outer_size,
    const int inner_size,
    const T* dY,
    T* dX) {
  const std::array<int, 3> dY_dims = {outer_size, tiles_, inner_size};
  const std::array<int, 3> dX_dims = {outer_size, 1, inner_size};
  math::ReduceSum<T, CUDAContext>(
      3, dY_dims.data(), dX_dims.data(), T(1), dY, dX, &context_);
  return true;
}

// float specialization of the Tile gradient.
//
// When inner_size == 1, dY is [outer_size, tiles_] and is reduced directly into
// dX as [outer_size, 1].
//
// Otherwise the sum over tiles is expressed as a strided-batched GEMM with one
// batch per outer index:
//   - each [tiles_, inner_size] slice of dY (batch stride tiles_ * inner_size)
//     is transposed;
//   - it is multiplied by a length-tiles_ vector of ones (batch stride 0, so
//     every batch shares the same vector);
//   - the result is an inner_size-long row of dX (batch stride inner_size).
//
// NOTE(review): this reading assumes math::GemmStridedBatched takes
// (trans_a, trans_b, batch, M, N, K, alpha, A, stride_a, B, stride_b, beta, C,
// stride_c, context) — confirm against caffe2/utils/math.h.
template <>
template <>
bool TileGradientOp<CUDAContext>::DoTileGradient<float>(
    const int outer_size,
    const int inner_size,
    const float* dY,
    float* dX) {
  if (inner_size == 1) {
    const std::array<int, 2> dY_dims = {outer_size, tiles_};
    const std::array<int, 2> dX_dims = {outer_size, 1};
    math::ReduceSum<float, CUDAContext>(
        2, dY_dims.data(), dX_dims.data(), 1.0f, dY, dX, &context_);
  } else {
    // ones_ is a member scratch tensor, resized to tiles_ elements and refilled
    // with 1.0f on every call.
    ReinitializeTensor(&ones_, tiles_, at::dtype<float>().device(CUDA));
    math::Set<float, CUDAContext>(
        tiles_, 1.0f, ones_.template mutable_data<float>(), &context_);
    math::GemmStridedBatched<float, CUDAContext>(
        CblasTrans,
        CblasNoTrans,
        outer_size,
        inner_size,
        1,
        tiles_,
        1.0f,
        dY,
        tiles_ * inner_size,
        ones_.template data<float>(),
        0,
        0.0f,
        dX,
        inner_size,
        &context_);
  }
  return true;
}

REGISTER_CUDA_OPERATOR(Tile, TileOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(TileGradient, TileGradientOp<CUDAContext>);

} // namespace caffe2