#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/StridedRandomAccessor.h>
#include <ATen/native/CompositeRandomAccessor.h>
#include <c10/util/irange.h>

#include <algorithm> // for std::sort
#include <unordered_map>
#include <vector>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_sparse_matmul_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#endif
namespace at::native {

using namespace at::sparse;

/*
  This is an implementation of the SMMP algorithm:
    "Sparse Matrix Multiplication Package (SMMP)"

    Randolph E. Bank and Craig C. Douglas
    https://doi.org/10.1007/BF02070824
*/
namespace {
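// The code below follows the classic two-pass CSR scheme: a symbolic pass
// (_csr_matmult_maxnnz) sizes the output of C = A @ B so that its buffers can
// be preallocated, and a numeric pass (_csr_matmult) fills them in.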
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
void csr_to_coo(const int64_t n_row, const int64_t Ap[], int64_t Bi[]) {
  /*
    Expands a compressed row pointer into a row indices array.

    Inputs:
      `n_row` is the number of rows (`Ap` has `n_row + 1` entries)
      `Ap` is the row pointer

    Output:
      `Bi` is the row indices
  */
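  // Worked example (hypothetical input): Ap = {0, 2, 2, 3} with n_row = 3
  // expands to Bi = {0, 0, 2}: row 0 owns entries [0, 2), row 1 is empty,
  // and row 2 owns entry [2, 3).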
  for (const auto i : c10::irange(n_row)) {
    for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) {
      Bi[jj] = i;
    }
  }
}

template<typename index_t_ptr = int64_t*>
int64_t _csr_matmult_maxnnz(
    const int64_t n_row,
    const int64_t n_col,
    const index_t_ptr Ap,
    const index_t_ptr Aj,
    const index_t_ptr Bp,
    const index_t_ptr Bj) {
  /*
    Compute needed buffer size for matrix `C` in `C = A @ B` operation.

    The matrices should be in proper CSR structure, and their dimensions
    should be compatible.
  */
  std::vector<int64_t> mask(n_col, -1);
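  // mask[k] == i records that column k has already been counted for output
  // row i, so each distinct column contributes to row_nnz exactly once.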
  int64_t nnz = 0;
  for (const auto i : c10::irange(n_row)) {
    int64_t row_nnz = 0;

    for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) {
      int64_t j = Aj[jj];
      for (int64_t kk = Bp[j]; kk < Bp[j + 1]; kk++) {
        int64_t k = Bj[kk];
        if (mask[k] != i) {
          mask[k] = i;
          row_nnz++;
        }
      }
    }
    nnz += row_nnz;
  }
  return nnz;
}
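
// Both the symbolic pass above and the numeric pass below visit each scalar
// multiplication in A @ B exactly once; the numeric pass additionally sorts
// the column indices of every output row. Scratch space is O(n_col) per call.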

template<typename index_t_ptr, typename scalar_t_ptr>
void _csr_matmult(
    const int64_t n_row,
    const int64_t n_col,
    const index_t_ptr Ap,
    const index_t_ptr Aj,
    const scalar_t_ptr Ax,
    const index_t_ptr Bp,
    const index_t_ptr Bj,
    const scalar_t_ptr Bx,
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename index_t_ptr::value_type Cp[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename index_t_ptr::value_type Cj[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename scalar_t_ptr::value_type Cx[]) {
  /*
    Compute CSR entries for matrix C = A @ B.

    The matrices `A` and `B` should be in proper CSR structure, and their
    dimensions should be compatible.

    Inputs:
      `n_row`        - number of rows in A
      `n_col`        - number of columns in B
      `Ap[n_row+1]`  - row pointer
      `Aj[nnz(A)]`   - column indices
      `Ax[nnz(A)]`   - nonzeros
      `Bp[?]`        - row pointer
      `Bj[nnz(B)]`   - column indices
      `Bx[nnz(B)]`   - nonzeros
    Outputs:
      `Cp[n_row+1]`  - row pointer
      `Cj[nnz(C)]`   - column indices
      `Cx[nnz(C)]`   - nonzeros

    Note:
      Output arrays Cp, Cj, and Cx must be preallocated.
  */
  using index_t = typename index_t_ptr::value_type;
  using scalar_t = typename scalar_t_ptr::value_type;

  std::vector<index_t> next(n_col, -1);
  std::vector<scalar_t> sums(n_col, 0);
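  // `next` threads a singly linked list through the columns touched in the
  // current output row: `head` is the most recently added column, next[k]
  // points to the previously added one, -2 terminates the list, and
  // next[k] == -1 means column k is not yet present in this row.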

  int64_t nnz = 0;

  Cp[0] = 0;

  for (const auto i : c10::irange(n_row)) {
    index_t head = -2;
    index_t length = 0;

    index_t jj_start = Ap[i];
    index_t jj_end = Ap[i + 1];
    for (const auto jj : c10::irange(jj_start, jj_end)) {
      index_t j = Aj[jj];
      scalar_t v = Ax[jj];

      index_t kk_start = Bp[j];
      index_t kk_end = Bp[j + 1];
      for (const auto kk : c10::irange(kk_start, kk_end)) {
        index_t k = Bj[kk];

        sums[k] += v * Bx[kk];

        if (next[k] == -1) {
          next[k] = head;
          head = k;
          length++;
        }
      }
    }

    for (const auto jj : c10::irange(length)) {
      (void)jj; // Suppress unused variable warning

      // NOTE: the linked list that encodes col indices
      // is not guaranteed to be sorted.
      Cj[nnz] = head;
      Cx[nnz] = sums[head];
      nnz++;

      index_t temp = head;
      head = next[head];

      next[temp] = -1; // clear arrays
      sums[temp] = 0;
    }

    // Make sure that col indices are sorted.
    // TODO: a better approach is to implement a CSR @ CSC kernel.
    // NOTE: Cx arrays are expected to be contiguous!
    auto col_indices_accessor = StridedRandomAccessor<index_t>(Cj + nnz - length, 1);
    auto val_accessor = StridedRandomAccessor<scalar_t>(Cx + nnz - length, 1);
    auto kv_accessor = CompositeRandomAccessorCPU<
        decltype(col_indices_accessor), decltype(val_accessor)
    >(col_indices_accessor, val_accessor);
    std::sort(kv_accessor, kv_accessor + length,
        [](const auto& lhs, const auto& rhs) -> bool {
          return get<0>(lhs) < get<0>(rhs);
        });

    Cp[i + 1] = nnz;
  }
}

template <typename scalar_t>
void sparse_matmul_kernel(
    Tensor& output,
    const Tensor& mat1,
    const Tensor& mat2) {
  /*
    Computes the sparse-sparse matrix multiplication between `mat1` and
    `mat2`, which are sparse tensors in COO format.
  */
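  // The COO inputs are first converted to CSR; the SMMP kernels above produce
  // the product in CSR form, and csr_to_coo expands the row pointer back into
  // COO row indices at the end.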

  auto M = mat1.size(0);
  auto N = mat2.size(1);

  const auto mat1_csr = mat1.to_sparse_csr();
  const auto mat2_csr = mat2.to_sparse_csr();
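
  // Wrap the CSR components in strided accessors so that the generic kernels
  // above can read them regardless of the innermost stride.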
  auto mat1_crow_indices_ptr = StridedRandomAccessor<int64_t>(
      mat1_csr.crow_indices().data_ptr<int64_t>(),
      mat1_csr.crow_indices().stride(-1));
  auto mat1_col_indices_ptr = StridedRandomAccessor<int64_t>(
      mat1_csr.col_indices().data_ptr<int64_t>(),
      mat1_csr.col_indices().stride(-1));
  auto mat1_values_ptr = StridedRandomAccessor<scalar_t>(
      mat1_csr.values().data_ptr<scalar_t>(),
      mat1_csr.values().stride(-1));
  auto mat2_crow_indices_ptr = StridedRandomAccessor<int64_t>(
      mat2_csr.crow_indices().data_ptr<int64_t>(),
      mat2_csr.crow_indices().stride(-1));
  auto mat2_col_indices_ptr = StridedRandomAccessor<int64_t>(
      mat2_csr.col_indices().data_ptr<int64_t>(),
      mat2_csr.col_indices().stride(-1));
  auto mat2_values_ptr = StridedRandomAccessor<scalar_t>(
      mat2_csr.values().data_ptr<scalar_t>(),
      mat2_csr.values().stride(-1));

  const auto nnz = _csr_matmult_maxnnz(
      M,
      N,
      mat1_crow_indices_ptr,
      mat1_col_indices_ptr,
      mat2_crow_indices_ptr,
      mat2_col_indices_ptr);

  auto output_indices = output._indices();
  auto output_values = output._values();

  Tensor output_indptr = at::empty({M + 1}, kLong);
  at::native::resize_output(output_indices, {2, nnz});
  at::native::resize_output(output_values, nnz);

  Tensor output_row_indices = output_indices.select(0, 0);
  Tensor output_col_indices = output_indices.select(0, 1);

  // TODO: replace with a CSR @ CSC kernel for better performance.
  _csr_matmult(
      M,
      N,
      mat1_crow_indices_ptr,
      mat1_col_indices_ptr,
      mat1_values_ptr,
      mat2_crow_indices_ptr,
      mat2_col_indices_ptr,
      mat2_values_ptr,
      output_indptr.data_ptr<int64_t>(),
      output_col_indices.data_ptr<int64_t>(),
      output_values.data_ptr<scalar_t>());

  csr_to_coo(M, output_indptr.data_ptr<int64_t>(), output_row_indices.data_ptr<int64_t>());
  output._coalesced_(true);
}

} // end anonymous namespace

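// A minimal usage sketch (hypothetical call site; `a` and `b` must be 2D
// sparse COO tensors with matching inner dimensions and dtypes):
//
//   at::Tensor a = at::randn({4, 5}).to_sparse();
//   at::Tensor b = at::randn({5, 3}).to_sparse();
//   at::Tensor c = at::native::sparse_sparse_matmul_cpu(a, b);
//
// In practice this kernel is reached through the dispatcher, e.g. via
// torch.sparse.mm on two sparse CPU tensors.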
Tensor sparse_sparse_matmul_cpu(const Tensor& mat1_, const Tensor& mat2_) {
  TORCH_INTERNAL_ASSERT(mat1_.is_sparse());
  TORCH_INTERNAL_ASSERT(mat2_.is_sparse());
  TORCH_CHECK(mat1_.dim() == 2);
  TORCH_CHECK(mat2_.dim() == 2);
  TORCH_CHECK(mat1_.dense_dim() == 0, "sparse_sparse_matmul_cpu: scalar values expected, got ", mat1_.dense_dim(), "D values");
  TORCH_CHECK(mat2_.dense_dim() == 0, "sparse_sparse_matmul_cpu: scalar values expected, got ", mat2_.dense_dim(), "D values");

  TORCH_CHECK(
      mat1_.size(1) == mat2_.size(0), "mat1 and mat2 shapes cannot be multiplied (",
      mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");

  TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
      "mat1 dtype ", mat1_.scalar_type(), " does not match mat2 dtype ", mat2_.scalar_type());

  auto output = at::native::empty_like(mat1_);
  output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
    sparse_matmul_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
  });
  return output;
}

} // namespace at::native