52 lines
1.0 KiB
C++
52 lines
1.0 KiB
C++
#pragma once
|
|
|
|
#include <ATen/Tensor.h>
|
|
#include <ATen/core/Scalar.h>
|
|
|
|
namespace at {
|
|
namespace native {
|
|
namespace sparse {
|
|
namespace impl {
|
|
|
|
TORCH_API Tensor& _compressed_row_strided_mm_out(
|
|
const Tensor& compressed_row_sparse,
|
|
const Tensor& strided,
|
|
Tensor& result);
|
|
|
|
TORCH_API Tensor& _compressed_row_strided_addmm_out(
|
|
const Tensor& self,
|
|
const Tensor& mat1,
|
|
const Tensor& mat2,
|
|
const Scalar& beta,
|
|
const Scalar& alpha,
|
|
Tensor& result);
|
|
|
|
namespace cpu {
|
|
|
|
void addmv_out_sparse_csr(
|
|
const Tensor& mat,
|
|
const Tensor& vec,
|
|
const Scalar& beta,
|
|
const Scalar& alpha,
|
|
const Tensor& result);
|
|
|
|
void add_out_sparse_csr(
|
|
const Tensor& mat1,
|
|
const Tensor& mat2,
|
|
const Scalar& alpha,
|
|
const Tensor& result);
|
|
|
|
void triangular_solve_out_sparse_csr(
|
|
const Tensor& A,
|
|
const Tensor& B,
|
|
const Tensor& X,
|
|
bool upper,
|
|
bool transpose,
|
|
bool unitriangular);
|
|
|
|
} // namespace cpu
|
|
} // namespace impl
|
|
} // namespace sparse
|
|
} // namespace native
|
|
} // namespace at
|