909 lines
36 KiB
Python
909 lines
36 KiB
Python
# Owner(s): ["module: masked operators"]
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import (
|
|
TestCase,
|
|
run_tests,
|
|
make_tensor,
|
|
parametrize,
|
|
instantiate_parametrized_tests,
|
|
)
|
|
from torch.testing._internal.common_device_type import (
|
|
instantiate_device_type_tests,
|
|
ops,
|
|
)
|
|
from torch.testing._internal.common_methods_invocations import (
|
|
SampleInput,
|
|
binary_ufuncs,
|
|
reduction_ops,
|
|
unary_ufuncs,
|
|
)
|
|
|
|
from torch.masked import as_masked_tensor, masked_tensor, _combine_input_and_mask
|
|
from torch.masked.maskedtensor.core import _masks_match, _tensors_match
|
|
from torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS, UNARY_NAMES
|
|
from torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS, BINARY_NAMES
|
|
from torch.masked.maskedtensor.reductions import REDUCE_NAMES
|
|
|
|
|
|
def _compare_mt_t(mt_result, t_result, rtol=1e-05, atol=1e-05):
    """Check that a MaskedTensor agrees with a plain Tensor on every unmasked element.

    Sparse (COO/CSR) data and masks are densified first. Masked-out positions are
    zeroed on *copies* of both operands so only unmasked elements take part in the
    comparison; neither ``mt_result`` nor ``t_result`` is modified.

    Args:
        mt_result: MaskedTensor under test.
        t_result: reference Tensor to compare against.
        rtol: relative tolerance forwarded to ``_tensors_match``.
        atol: absolute tolerance forwarded to ``_tensors_match``.

    Raises:
        ValueError: if the unmasked data does not match within tolerance.
    """
    mask = mt_result.get_mask()
    mt_result_data = mt_result.get_data()
    if mask.layout in {torch.sparse_coo, torch.sparse_csr}:
        mask = mask.to_dense()
    if mt_result_data.layout in {torch.sparse_coo, torch.sparse_csr}:
        mt_result_data = mt_result_data.to_dense()
    # Out-of-place masked_fill: ``detach()`` shares storage with its source, so the
    # in-place variant could overwrite the caller's tensors as a side effect.
    a = mt_result_data.detach().masked_fill(~mask, 0)
    b = t_result.detach().masked_fill(~mask, 0)
    if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol):
        raise ValueError("The data in MaskedTensor a and Tensor b do not match")
|
|
|
|
def _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08):
    """Check that two MaskedTensors have matching layouts, masks, and unmasked data.

    Sparse (COO/CSR) data and masks are densified before the value comparison.
    Masked-out positions are zeroed on *copies* of the data, so neither argument
    is modified.

    Args:
        mt1: first MaskedTensor.
        mt2: second MaskedTensor.
        rtol: relative tolerance forwarded to ``_tensors_match``.
        atol: absolute tolerance forwarded to ``_tensors_match``.

    Raises:
        ValueError: if the data layouts differ, the masks do not match, the mask
            layouts differ, or the unmasked data does not match within tolerance.
    """
    mt_data1 = mt1.get_data()
    mt_data2 = mt2.get_data()
    if mt_data1.layout != mt_data2.layout:
        raise ValueError("mt1's data and mt2's data do not have the same layout. "
                         f"mt1.get_data().layout = {mt_data1.layout} while mt2.get_data().layout = {mt_data2.layout}")

    mask = mt1.get_mask()
    mask2 = mt2.get_mask()
    if not _masks_match(mt1, mt2):
        raise ValueError("mt1 and mt2 must have matching masks")
    if mask.layout != mask2.layout:
        raise ValueError("mt1's mask and mt2's mask do not have the same layout. "
                         f"mt1.get_mask().layout = {mask.layout} while mt2.get_mask().layout = {mask2.layout}")
    if mask.layout in {torch.sparse_coo, torch.sparse_csr}:
        mask = mask.to_dense()

    if mt_data1.layout in {torch.sparse_coo, torch.sparse_csr}:
        mt_data1 = mt_data1.to_dense()
        mt_data2 = mt_data2.to_dense()
    # Out-of-place masked_fill: ``detach()`` shares storage with its source, so the
    # in-place variant could overwrite the MaskedTensors' own data.
    a = mt_data1.detach().masked_fill(~mask, 0)
    b = mt_data2.detach().masked_fill(~mask, 0)

    if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol):
        raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match")
|
|
|
|
|
|
def _create_random_mask(shape, device):
    """Return a random boolean tensor of ``shape`` on ``device``, for use as a mask."""
    random_mask = make_tensor(shape, device=device, dtype=torch.bool)
    return random_mask
|
|
|
|
def _generate_sample_data(
    device="cpu", dtype=torch.float, requires_grad=True, layout=torch.strided
):
    """Build ``SampleInput`` objects that pair random data with a random boolean mask.

    One sample is produced per shape in a fixed list (0-D through 4-D). For sparse
    layouts, both the mask and the data are converted to that layout. The CSR branch
    keeps only 2-D shapes.

    Args:
        device: device to allocate tensors on.
        dtype: dtype of the data tensors.
        requires_grad: whether the data tensors require grad.
        layout: one of ``torch.strided``, ``torch.sparse_coo``, ``torch.sparse_csr``.

    Returns:
        A list of ``SampleInput`` with the data as ``input`` and the mask in
        ``kwargs["mask"]``.
    """
    assert layout in {
        torch.strided,
        torch.sparse_coo,
        torch.sparse_csr,
    }, "Layout must be strided/sparse_coo/sparse_csr"
    shapes = [
        [],
        [2],
        [3, 5],
        [3, 2, 1, 2],
    ]
    inputs = []
    for s in shapes:
        data = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)  # type: ignore[arg-type]
        mask = _create_random_mask(s, device)
        if layout == torch.sparse_coo:
            mask = mask.to_sparse_coo().coalesce()
            data = data.sparse_mask(mask).requires_grad_(requires_grad)
        elif layout == torch.sparse_csr:
            # Skip anything that is not 2-D; same guard as the CSR paths in TestOperators.
            if data.ndim != 2 or mask.ndim != 2:
                continue
            mask = mask.to_sparse_csr()
            # NOTE(review): unlike the COO branch, requires_grad is not re-applied here -- confirm intended.
            data = data.sparse_mask(mask)
        inputs.append(SampleInput(data, kwargs={"mask": mask}))
    return inputs
|
|
|
|
def _fix_fn_name(fn_name):
    """Drop one trailing underscore so an in-place op name maps to its out-of-place name."""
    is_inplace_name = fn_name[-1] == "_"
    return fn_name[:-1] if is_inplace_name else fn_name
|
|
|
|
|
|
class TestBasics(TestCase):
    """Construction, validation, layout-conversion and simple-op tests for MaskedTensor."""

    def test_invalid_tensor_inputs(self, device):
        data = torch.randn((3, 4), device=device)
        mask = _create_random_mask((3, 4), device=device)
        mt = masked_tensor(data, mask)

        with self.assertRaisesRegex(TypeError, "data must be a Tensor"):
            masked_tensor(mt, mask)
        with self.assertRaisesRegex(TypeError, "data must be a Tensor"):
            masked_tensor(0, mask)
        with self.assertRaisesRegex(TypeError, "mask must be a Tensor"):
            masked_tensor(data, mt)
        with self.assertRaisesRegex(TypeError, "mask must be a Tensor"):
            masked_tensor(data, 0)

    def test_diff_layouts(self, device):
        data = torch.randn((3, 4), device=device).to_sparse_coo()
        mask = _create_random_mask((3, 4), device=device)
        with self.assertRaisesRegex(TypeError, "data and mask must have the same layout"):
            masked_tensor(data, mask)

    def test_diff_dim(self, device):
        data = torch.randn((3, 4, 5), device=device)
        mask = _create_random_mask((3, 4), device=device)
        with self.assertRaisesRegex(ValueError, "data.dim\\(\\) must equal mask.dim\\(\\)"):
            masked_tensor(data, mask)

    def test_diff_sizes(self, device):
        data = torch.randn((3, 4), device=device)
        mask = _create_random_mask((3, 3), device=device)
        with self.assertRaisesRegex(ValueError, "data.size\\(\\) must equal mask.size\\(\\)"):
            masked_tensor(data, mask)

    def test_grad_warning(self, device):
        data = torch.randn((3, 4), device=device, requires_grad=True)
        mask = _create_random_mask((3, 4), device=device)
        msg = "It is not recommended to create a MaskedTensor with a tensor that requires_grad."
        with self.assertWarnsRegex(UserWarning, msg):
            masked_tensor(data, mask)

    def test_add(self, device):
        data = torch.arange(5.0, device=device)
        mask = torch.tensor([True, True, False, True, False], device=device)
        m0 = masked_tensor(data, mask)
        m1 = masked_tensor(data, ~mask)
        with self.assertRaisesRegex(ValueError, "Input masks must match."):
            m0 + m1
        _compare_mts(m0 + m0, masked_tensor(torch.tensor([0., 2, 0, 6, 0], device=device), mask))

    def test_softmax(self, device):
        data = torch.randn((3, 4), device=device) * 0.1
        mask = torch.tensor(
            [
                [True, True, True, False],
                [False, True, False, True],
                [True, True, False, False],
            ],
            device=device
        )
        mt = masked_tensor(data, mask, requires_grad=True)
        masked_res = torch.softmax(mt, -1)
        masked_res.sum().backward()
        # Reference: masked-out entries become -inf so softmax assigns them zero weight.
        xinf = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_()
        tensor_res = torch.softmax(xinf, -1)
        tensor_res.sum().backward()

        _compare_mt_t(masked_res, tensor_res)
        _compare_mt_t(mt.grad, xinf.grad, atol=1e-06)

    def test_where(self, device):
        data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device)
        mask = data < 0

        mx = masked_tensor(data, mask, requires_grad=True)
        my = masked_tensor(torch.ones_like(data), ~mask, requires_grad=True)
        masked_res = torch.where(mask, torch.exp(mx), my)
        masked_res.sum().backward()

        x = data.detach().clone().requires_grad_()
        y = torch.ones_like(x, device=device, requires_grad=True)
        tensor_res = torch.where(mask, torch.exp(x), y)
        tensor_res.sum().backward()

        _compare_mt_t(masked_res, tensor_res)
        _compare_mt_t(mx.grad, x.grad)
        _compare_mt_t(my.grad, y.grad)

    def test_to_sparse(self, device):
        for sample in _generate_sample_data(device=device):
            data = sample.input
            mask = sample.kwargs["mask"]
            mt = masked_tensor(data.clone().detach(), mask, requires_grad=True)

            sparse_mt = mt.to_sparse()
            data.to_sparse().to_dense().sum().backward()
            sparse_mt.to_dense().sum().backward()

            _compare_mt_t(sparse_mt, data)
            _compare_mt_t(mt.grad, data.grad)

    def test_to_dense(self, device):
        samples = _generate_sample_data(
            device=device,
            layout=torch.sparse_coo
        ) + _generate_sample_data(device=device, layout=torch.sparse_csr)
        for sample in samples:
            data = sample.input
            mask = sample.kwargs["mask"]
            mt = masked_tensor(data, mask, requires_grad=True)

            dense_data = data.to_dense().detach().clone().requires_grad_(True)
            dense_mt = mt.to_dense()
            dense_data.sum().backward()
            dense_mt.sum().backward()

            _compare_mt_t(dense_mt, dense_data)
            _compare_mt_t(mt.grad.to_dense(), dense_data.grad)

    def test_to_dense_and_sparse_coo(self, device):
        for sample in _generate_sample_data(device=device, layout=torch.strided):
            data = sample.input
            mask = sample.kwargs["mask"]
            ms = mask.to_sparse_coo().coalesce()

            mt = masked_tensor(data, mask, requires_grad=True)
            mts = masked_tensor(data.sparse_mask(ms), ms, requires_grad=True)

            # Round trip dense -> sparse COO -> dense must agree with a natively sparse MaskedTensor.
            converted = mt.to_sparse().to_dense()
            converted.sum().backward()

            converted2 = mts.to_dense()
            converted2.sum().backward()

            _compare_mts(converted, converted2)
            _compare_mts(mt.grad, mts.grad.to_dense())

    def test_to_dense_and_sparse_csr(self, device):
        for sample in _generate_sample_data(device=device, layout=torch.strided):
            data = sample.input
            mask = sample.kwargs["mask"]
            if data.ndim != 2:
                continue
            ms = mask.to_sparse_csr()

            mt = masked_tensor(data, mask, requires_grad=True)
            mts = masked_tensor(data.sparse_mask(ms), ms, requires_grad=True)

            # Round trip dense -> sparse CSR -> dense must agree with a natively sparse MaskedTensor.
            converted = mt.to_sparse_csr().to_dense()
            converted.sum().backward()

            converted2 = mts.to_dense()
            converted2.sum().backward()

            _compare_mts(converted, converted2)
            _compare_mts(mt.grad, mts.grad.to_dense())

    def test_invalid_sparse_layout(self, device):
        data = torch.randn((3, 4), device=device).to_sparse_csc()
        mask = _create_random_mask((3, 4), device=device).to_sparse_csc()
        with self.assertRaisesRegex(TypeError, "data layout of torch.sparse_csc is not supported"):
            masked_tensor(data, mask)

    def test_invalid_sparse_coo_values(self, device):
        v = torch.tensor([3, 4, 5], dtype=torch.float32)
        i1 = torch.tensor([[0, 1, 1], [2, 0, 2]])
        i2 = torch.tensor([[0, 1, 1], [2, 1, 2]])

        t = torch.sparse_coo_tensor(i1, v, (2, 4), device=device)
        mask = torch.sparse_coo_tensor(i2, torch.tensor([True, True, True]), (2, 4), device=device)

        msg = "data and mask are both sparse COO tensors but do not have the same indices."
        with self.assertRaisesRegex(ValueError, msg):
            masked_tensor(t, mask)

    def test_invalid_sparse_csr_values(self, device):
        crow_indices1 = [0, 2, 3]
        crow_indices2 = [0, 1, 3]
        col_indices1 = [0, 1, 2]
        col_indices2 = [1, 2, 3]

        values = [2, 3, 4]
        mask_values = [True, True, True]

        # t1/mask1 differ only in crow indices; t2/mask2 differ only in col indices.
        # Every tensor is placed on ``device`` so the test covers the device it is instantiated for.
        t1 = torch.sparse_csr_tensor(
            torch.tensor(crow_indices1, dtype=torch.int64),
            torch.tensor(col_indices1, dtype=torch.int64),
            torch.tensor(values),
            size=(2, 4),
            device=device,
        )
        mask1 = torch.sparse_csr_tensor(
            torch.tensor(crow_indices2, dtype=torch.int64),
            torch.tensor(col_indices1, dtype=torch.int64),
            torch.tensor(mask_values),
            dtype=torch.bool,
            size=(2, 4),
            device=device,
        )
        t2 = torch.sparse_csr_tensor(
            torch.tensor(crow_indices2, dtype=torch.int64),
            torch.tensor(col_indices1, dtype=torch.int64),
            torch.tensor(values),
            size=(2, 4),
            device=device,
        )
        mask2 = torch.sparse_csr_tensor(
            torch.tensor(crow_indices2, dtype=torch.int64),
            torch.tensor(col_indices2, dtype=torch.int64),
            torch.tensor(mask_values),
            dtype=torch.bool,
            size=(2, 4),
            device=device,
        )

        msg = "data and mask are both sparse CSR tensors but do not share either crow or col indices."
        with self.assertRaisesRegex(ValueError, msg):
            masked_tensor(t1, mask1)
        with self.assertRaisesRegex(ValueError, msg):
            masked_tensor(t2, mask2)

    def test_contiguous(self, device):
        data = torch.randn((3, 3), device=device)

        contiguous_data = data.clone()
        mask1 = contiguous_data > 0
        not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2))
        mask2 = not_contiguous_data > 0

        contiguous_mt = masked_tensor(contiguous_data, mask1)
        not_contiguous_mt = masked_tensor(not_contiguous_data, mask2)

        contiguous_mt_sparse = masked_tensor(
            contiguous_data.to_sparse_coo(), mask1.to_sparse_coo()
        )
        not_contiguous_mt_sparse = masked_tensor(
            not_contiguous_data.to_sparse_coo(), mask2.to_sparse_coo()
        )

        self.assertEqual(contiguous_data.is_contiguous(), True)
        self.assertEqual(not_contiguous_data.is_contiguous(), False)

        self.assertEqual(contiguous_mt.is_contiguous(), True)
        self.assertEqual(not_contiguous_mt.is_contiguous(), False)

        error_msg = "MaskedTensors with sparse data do not have is_contiguous"
        for t in [contiguous_mt_sparse, not_contiguous_mt_sparse]:
            with self.assertRaisesRegex(ValueError, error_msg):
                t.is_contiguous()
            with self.assertRaisesRegex(ValueError, error_msg):
                t.contiguous()

        now_contiguous_mt = not_contiguous_mt.contiguous()

        _compare_mts(not_contiguous_mt, now_contiguous_mt)

        self.assertEqual(now_contiguous_mt.is_contiguous(), True)
        self.assertEqual(now_contiguous_mt.get_data().is_contiguous(), True)
|
|
|
|
class TestUnary(TestCase):
    """Compare native unary ops on MaskedTensors against the same ops on plain Tensors."""

    def _get_test_data(self, fn_name):
        """Return ``(data, mask)`` of shape (10, 10), with ``data`` adjusted into ``fn_name``'s domain."""
        data = torch.randn(10, 10)
        mask = torch.rand(10, 10) > 0.5
        fn_name = _fix_fn_name(fn_name)
        if fn_name in ["log", "log10", "log1p", "log2", "sqrt"]:
            data = data.mul(0.5).abs()
        if fn_name in ["rsqrt"]:
            data = data.abs() + 1  # Avoid division by zero
        if fn_name in ["acos", "arccos", "asin", "arcsin", "logit"]:
            data = data.abs().mul(0.5).clamp(0, 1)
        if fn_name in ["atanh", "arctanh", "erfinv"]:
            data = data.mul(0.5).clamp(-1, 1)
        if fn_name in ["acosh", "arccosh"]:
            data = data.abs() + 1
        if fn_name in ["bitwise_not"]:
            data = data.mul(128).to(torch.int8)
        return data, mask

    def _get_sample_kwargs(self, fn_name):
        """Return the extra keyword arguments ``fn_name`` needs (bounds for clamp/clip)."""
        fn_name = _fix_fn_name(fn_name)
        kwargs = {}
        if fn_name in ["clamp", "clip"]:
            kwargs["min"] = -0.5
            kwargs["max"] = 0.5
        return kwargs

    def _get_sample_args(self, fn_name, data, mask):
        """Return ``(t_args, mt_args)``: positional args for the Tensor and MaskedTensor calls."""
        fn_name = _fix_fn_name(fn_name)
        mt = masked_tensor(data, mask)
        t_args = [data]
        mt_args = [mt]
        if fn_name in ["pow"]:
            t_args += [2.0]
            mt_args += [2.0]
        return t_args, mt_args

    def _check_against_tensor(self, fn):
        """Run ``fn`` on a MaskedTensor and on the plain Tensor, then compare unmasked results."""
        torch.random.manual_seed(0)
        fn_name = fn.__name__
        data, mask = self._get_test_data(fn_name)
        kwargs = self._get_sample_kwargs(fn_name)

        t_args, mt_args = self._get_sample_args(fn_name, data, mask)

        mt_result = fn(*mt_args, **kwargs)
        t_result = fn(*t_args, **kwargs)
        _compare_mt_t(mt_result, t_result)

    @parametrize("fn", NATIVE_UNARY_FNS)
    def test_unary(self, fn):
        self._check_against_tensor(fn)

    @parametrize("fn", NATIVE_INPLACE_UNARY_FNS)
    def test_inplace_unary(self, fn):
        self._check_against_tensor(fn)
|
|
|
|
class TestBinary(TestCase):
    """Compare native binary ops on MaskedTensors against the same ops on plain Tensors."""

    def _get_test_data(self, fn_name):
        """Return ``(data0, data1, mask)``; the integer ops get integer-typed data."""
        fn_name = _fix_fn_name(fn_name)
        data0 = torch.randn(10, 10)
        data1 = torch.randn(10, 10)
        mask = torch.rand(10, 10) > 0.5
        if fn_name in ["bitwise_and", "bitwise_or", "bitwise_xor"]:
            data0 = data0.mul(128).to(torch.int8)
            data1 = data1.mul(128).to(torch.int8)
        if fn_name in ["bitwise_left_shift", "bitwise_right_shift"]:
            data0 = data0.abs().to(torch.int64)
            data1 = data1.abs().to(torch.int64)
        return data0, data1, mask

    def _get_sample_kwargs(self, fn_name):
        """Return extra keyword arguments for ``fn_name``; no binary op here needs any."""
        return {}

    def _yield_sample_args(self, fn_name, data0, data1, mask):
        """ Returns two sets of Tensor and MaskedTensor args for a binary function to compute.
        Tensor args are all the same (just the two provided data tensors),
        while the MaskedTensor args tests both (MaskedTensor, MaskedTensor) and (MaskedTensor, Tensor)
        """
        mt0 = masked_tensor(data0, mask)
        mt1 = masked_tensor(data1, mask)

        t_args = [data0, data1]
        mt_args = [mt0, mt1]
        yield t_args, mt_args

        t_args = [data0, data1]
        mt_args = [mt0, data1]
        yield t_args, mt_args

    def _check_against_tensor(self, fn):
        """Run ``fn`` on each MaskedTensor/Tensor argument pairing and compare unmasked results."""
        torch.random.manual_seed(0)
        fn_name = fn.__name__
        data0, data1, mask = self._get_test_data(fn_name)
        kwargs = self._get_sample_kwargs(fn_name)

        for (t_args, mt_args) in self._yield_sample_args(fn_name, data0, data1, mask):
            mt_result = fn(*mt_args, **kwargs)
            t_result = fn(*t_args, **kwargs)
            _compare_mt_t(mt_result, t_result)

    @parametrize("fn", NATIVE_BINARY_FNS)
    def test_binary(self, fn):
        self._check_against_tensor(fn)

    @parametrize("fn", NATIVE_INPLACE_BINARY_FNS)
    def test_inplace_binary(self, fn):
        self._check_against_tensor(fn)

    @parametrize("fn_name", ["add", "add_"])
    def test_masks_match(self, fn_name):
        torch.random.manual_seed(0)
        fn = getattr(torch.ops.aten, fn_name)
        data0, data1, mask = self._get_test_data(fn_name)
        mask0 = mask
        mask1 = torch.rand(mask.size()) > 0.5
        mt0 = masked_tensor(data0, mask0)
        mt1 = masked_tensor(data1, mask1)
        # assertRaises instead of try/bare-assert: bare asserts are stripped under `python -O`.
        with self.assertRaises(ValueError) as ctx:
            fn(mt0, mt1)
        self.assertEqual(
            "Input masks must match. If you need support for this, please open an issue on Github.",
            str(ctx.exception),
        )
|
|
|
|
class TestReductions(TestCase):
    """Hand-computed checks of MaskedTensor reductions and their gradients."""

    def test_max_not_implemented(self):
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = masked_tensor(d, m)
        with self.assertRaisesRegex(TypeError, "torch._ops.aten.max.default"):
            mt.max()

    def test_sum(self):
        d = torch.tensor([[0, 1, 2, 6], [3, 4, 5.0, 7]])
        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
        mt = masked_tensor(d, m)
        _compare_mts(masked_tensor(torch.tensor(17.0), torch.tensor(True)), mt.sum())
        _compare_mts(
            masked_tensor(
                torch.tensor([0.0, 4.0, 1.0, 13]),
                torch.tensor([True, True, False, True]),
            ),
            mt.sum(dim=0),
        )

    def test_sum_grad(self):
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = masked_tensor(d, m, requires_grad=True)
        mt.sum().backward()
        _compare_mts(mt.grad, masked_tensor(torch.tensor(1.0).expand_as(m), m))

    def test_mean(self):
        d = torch.tensor([[0, 1, 3, 2], [3, 4, 1.0, 4]])
        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
        mt = masked_tensor(d, m)
        _compare_mts(masked_tensor(torch.tensor(2.5), torch.tensor(True)), mt.mean())
        _compare_mts(
            masked_tensor(
                torch.tensor([0.0, 4.0, 1.0, 3]),
                torch.tensor([True, True, False, True]),
            ),
            mt.mean(dim=0),
        )

    # The tests test_mean_grad_case_1a through test_mean_grad_case_1f below exercise the
    # two different ways of constructing MaskedTensors:
    #     masked_tensor(data, mask, requires_grad=True/False) -- NO differentiable constructor and always a leaf
    #     as_masked_tensor(data, mask) -- differentiable constructor
    #
    # Like torch.tensor(data), masked_tensor(data, mask) gives a UserWarning if data.requires_grad=True.
    # as_masked_tensor does not take requires_grad -- it inherits requires_grad from data.
    #
    # That gives 6 cases; `mean` is used as a proxy to test the different combinations.
    # In each case mt.mean().backward() is run after the constructor.
    #
    # test_mean_grad_case_1a:
    #     values.requires_grad = True
    #     mt = masked_tensor(values, mask, requires_grad=True)
    #   yields
    #     - a UserWarning because values.requires_grad=True
    #     - values.grad = None
    #     - mt.grad is a MaskedTensor with the correct gradient
    #
    # test_mean_grad_case_1b:
    #     values.requires_grad = False
    #     mt = masked_tensor(values, mask, requires_grad=True)
    #   yields
    #     - values.grad = None
    #     - mt.grad is a MaskedTensor with the correct gradient
    #
    # test_mean_grad_case_1c / test_mean_grad_case_1d:
    #     values.requires_grad = True / False
    #     mt = masked_tensor(values, mask, requires_grad=False)
    #   Both yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn",
    #   as expected. When values.requires_grad=True (1c), there is also a UserWarning.
    #
    # test_mean_grad_case_1e:
    #     values.requires_grad = True
    #     mt = as_masked_tensor(values, mask)
    #   yields
    #     - values.grad is a MaskedTensor with the correct gradient
    #     - mt.grad is None and gives a UserWarning that
    #       "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad"
    #
    # test_mean_grad_case_1f:
    #     values.requires_grad = False
    #     mt = as_masked_tensor(values, mask)
    #   yields a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn",
    #   as expected.

    def test_mean_grad_case_1a(self):
        """ values.requires_grad = True
            mt = masked_tensor(values, mask, requires_grad=True)
        """
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
        m = torch.tensor([[True, False, False], [False, True, False]])
        with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"):
            mt = masked_tensor(d, m, requires_grad=True)
        mt.mean().backward()
        self.assertIsNone(d.grad)
        _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))

    def test_mean_grad_case_1b(self):
        """ values.requires_grad = False
            mt = masked_tensor(values, mask, requires_grad=True)
        """
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = masked_tensor(d, m, requires_grad=True)
        mt.mean().backward()
        self.assertIsNone(d.grad)
        _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))

    def test_mean_grad_case_1c(self):
        """ values.requires_grad = True
            mt = masked_tensor(values, mask, requires_grad=False)
        """
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
        m = torch.tensor([[True, False, False], [False, True, False]])
        with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"):
            mt = masked_tensor(d, m, requires_grad=False)
        result = mt.mean()
        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
        with self.assertRaisesRegex(RuntimeError, msg):
            result.backward()

    def test_mean_grad_case_1d(self):
        """ values.requires_grad = False
            mt = masked_tensor(values, mask, requires_grad=False)
        """
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = masked_tensor(d, m, requires_grad=False)
        result = mt.mean()
        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
        with self.assertRaisesRegex(RuntimeError, msg):
            result.backward()

    def test_mean_grad_case_1e(self):
        """ values.requires_grad = True
            mt = as_masked_tensor(values, mask)
        """
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = as_masked_tensor(d, m)
        mt.mean().backward()
        _compare_mts(d.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))
        msg = "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad"
        with self.assertWarnsRegex(UserWarning, msg):
            self.assertIsNone(mt.grad)

    def test_mean_grad_case_1f(self):
        """ values.requires_grad = False
            mt = as_masked_tensor(values, mask)
        """
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = as_masked_tensor(d, m)
        result = mt.mean()
        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
        with self.assertRaisesRegex(RuntimeError, msg):
            result.backward()

    def test_mean_dim_grad(self):
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, True, False], [False, True, False]])
        mt = masked_tensor(d, m, requires_grad=True)
        mt.mean(1).sum().backward()
        _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0.5, 0], [0, 1, 0]]), m))

    def test_amax(self):
        d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]])
        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
        mt = masked_tensor(d, m)
        _compare_mts(masked_tensor(torch.tensor(3.0), torch.tensor(True)), mt.amax())
        _compare_mts(
            masked_tensor(
                torch.tensor([0.0, -4.0, 1.0, 3]),
                torch.tensor([True, True, False, True]),
            ),
            mt.amax(dim=0),
        )

    def test_amax_grad(self):
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = masked_tensor(d, m, requires_grad=True)
        mt.amax().backward()
        _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.0, 0, 0], [0, 1, 0]]), m))

    def test_amin(self):
        d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]])
        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
        mt = masked_tensor(d, m)
        _compare_mts(masked_tensor(torch.tensor(-4.0), torch.tensor(True)), mt.amin())
        _compare_mts(
            masked_tensor(
                torch.tensor([0.0, -4.0, 1.0, -3]),
                torch.tensor([True, True, False, True]),
            ),
            mt.amin(dim=0),
        )

    def test_amin_grad(self):
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = masked_tensor(d, m, requires_grad=True)
        mt.amin().backward()
        _compare_mts(mt.grad, masked_tensor(torch.tensor([[1.0, 0, 0], [0, 0, 0]]), m))

    def test_prod(self):
        # The NaN sits at a masked-out position and must not leak into the result.
        d = torch.tensor([[0, 1, 3, 0.0], [float("nan"), 4, 1.0, 5.0]])
        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
        mt = masked_tensor(d, m)
        _compare_mts(masked_tensor(torch.tensor(0.0), torch.tensor(True)), mt.prod())
        _compare_mts(
            masked_tensor(
                torch.tensor([0.0, 4.0, 1.0, 0.0]),
                torch.tensor([True, True, False, True]),
            ),
            mt.prod(dim=0),
        )

    def test_prod_grad(self):
        d = torch.tensor([[2, float("nan"), 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = masked_tensor(d, m, requires_grad=True)
        mt.prod().backward()
        _compare_mts(mt.grad, masked_tensor(torch.tensor([[4.0, 0, 0], [0, 2, 0]]), m))

    def test_all(self):
        d = torch.tensor([[True, True, False, False], [False, True, True, True]])
        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
        mt = masked_tensor(d, m)
        _compare_mts(masked_tensor(torch.tensor(False), torch.tensor(True)), mt.all())
        _compare_mts(
            masked_tensor(
                torch.tensor([True, True, True, False]),
                torch.tensor([True, True, False, True]),
            ),
            mt.all(dim=0),
        )

        m = torch.tensor([[True, False, True, False], [False, True, False, False]])
        mt = masked_tensor(d, m)
        _compare_mts(
            masked_tensor(
                torch.tensor([True, True, False, True]),
                torch.tensor([True, True, True, False]),
            ),
            mt.all(dim=0),
        )

    def test_grad_dtype(self):
        d = torch.tensor([[True, True, False], [False, True, True]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        msg = "Only Tensors of floating point and complex dtype can require gradients"
        with self.assertRaisesRegex(RuntimeError, msg):
            masked_tensor(d, m, requires_grad=True)
|
|
|
|
|
|
def is_unary(op):
    """Return whether the OpInfo ``op`` names a unary op that MaskedTensor implements."""
    op_name = op.name
    return op_name in UNARY_NAMES
|
|
|
|
def is_binary(op):
    """Return whether the OpInfo ``op`` names a binary op that MaskedTensor implements."""
    op_name = op.name
    return op_name in BINARY_NAMES
|
|
|
|
def is_reduction(op):
    """Return whether ``op`` is a MaskedTensor reduction included in the OpInfo sweep.

    ``all``, ``mean``, ``std`` and ``var`` are excluded even though MaskedTensor
    lists them.
    """
    excluded = {"all", "mean", "std", "var"}
    op_name = op.name
    if op_name in excluded:
        return False
    return op_name in REDUCE_NAMES
|
|
|
|
# OpInfo entries from the shared op databases, filtered to those MaskedTensor supports.
mt_unary_ufuncs = list(filter(is_unary, unary_ufuncs))
mt_binary_ufuncs = list(filter(is_binary, binary_ufuncs))
mt_reduction_ufuncs = list(filter(is_reduction, reduction_ops))
|
|
|
|
# Floating-point dtypes that the @ops sweeps in TestOperators are restricted to (allowed_dtypes).
MASKEDTENSOR_FLOAT_TYPES = {torch.float16, torch.float32, torch.float64}
|
|
|
|
class TestOperators(TestCase):
    """OpInfo-driven sweeps comparing MaskedTensor ops with the same ops on plain Tensors."""

    def _convert_mt_args(self, args, mask, layout):
        """Wrap each Tensor in ``args`` as a MaskedTensor with ``mask``; non-Tensors pass through.

        For non-strided layouts each Tensor is first restricted to ``mask`` via ``sparse_mask``.
        """
        return [
            masked_tensor(
                arg.sparse_mask(mask) if layout != torch.strided else arg, mask
            )
            if torch.is_tensor(arg)
            else arg
            for arg in args
        ]

    def _test_unary_binary_equality(self, device, dtype, op, layout=torch.strided):
        """Check a unary or binary OpInfo ``op`` on MaskedTensors against plain Tensors.

        For every OpInfo sample, the input (and, for binary ops, the Tensor args) is
        wrapped in a MaskedTensor with ``layout``. The result must match the plain-Tensor
        result on all unmasked elements.
        """
        samples = op.sample_inputs(device, dtype, requires_grad=True)

        for sample in samples:
            input = sample.input
            sample_args, sample_kwargs = sample.args, sample.kwargs
            # Use the sample's own mask when it supplies one; otherwise draw a random mask.
            mask = (
                _create_random_mask(input.shape, device)
                if "mask" not in sample_kwargs
                else sample_kwargs.pop("mask")
            )

            if layout == torch.sparse_coo:
                mask = mask.to_sparse_coo().coalesce()
                input = input.sparse_mask(mask)
            elif layout == torch.sparse_csr:
                if input.ndim != 2 or mask.ndim != 2:
                    continue
                mask = mask.to_sparse_csr()
                input = input.sparse_mask(mask)

            # Binary operations currently only support same size masks
            if is_binary(op):
                if input.shape != sample_args[0].shape:
                    continue
                # Binary operations also don't support kwargs right now
                sample_kwargs = {}

            mt = masked_tensor(input, mask)
            mt_args = self._convert_mt_args(sample_args, mask, layout)

            mt_result = op(mt, *mt_args, **sample_kwargs)
            t_result = op(sample.input, *sample_args, **sample_kwargs)

            _compare_mt_t(mt_result, t_result)

            # If the operation is binary, check that lhs = masked, rhs = regular tensor also works
            if is_binary(op) and layout == torch.strided:
                mt_result2 = op(mt, *sample_args, **sample_kwargs)
                _compare_mt_t(mt_result2, t_result)

    def _test_reduction_equality(self, device, dtype, op, layout=torch.strided):
        """Check a reduction OpInfo ``op`` on a MaskedTensor against a masked-input reference.

        The reference input comes from ``_combine_input_and_mask`` (built before any sparse
        conversion). Empty, 0-D, and fully-masked samples are skipped.
        """
        samples = op.sample_inputs(device, dtype, requires_grad=True)

        for sample in samples:
            input = sample.input
            # Reduction operations don't support more advanced args/kwargs right now
            sample_args, sample_kwargs = (), {}

            if input.dim() == 0 or input.numel() == 0:
                continue

            mask = _create_random_mask(input.shape, device)

            if torch.count_nonzero(mask) == 0:
                continue

            tensor_input = _combine_input_and_mask(op.op, input, mask)
            if layout == torch.sparse_coo:
                mask = mask.to_sparse_coo().coalesce()
                input = input.sparse_mask(mask)
            elif layout == torch.sparse_csr:
                if input.ndim != 2 or mask.ndim != 2:
                    continue
                mask = mask.to_sparse_csr()
                input = input.sparse_mask(mask)

            mt = masked_tensor(input, mask)
            mt_args = self._convert_mt_args(sample_args, mask, layout)

            mt_result = op(mt, *mt_args, **sample_kwargs)
            t_result = op(tensor_input, *sample_args, **sample_kwargs)

            _compare_mt_t(mt_result, t_result)

    @ops(mt_unary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
    def test_unary_core(self, device, dtype, op, layout):
        # Skip tests that don't have len(kwargs) == 0
        skip_variants = {
            "decimals_0",
            "decimals_3",
            "decimals_neg_3",
        }
        if op.name == "round" and op.variant_test_name in skip_variants:
            return
        # Forward `layout` so the sparse parametrizations actually run on sparse inputs.
        self._test_unary_binary_equality(device, dtype, op, layout)

    @ops(mt_binary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
    def test_binary_core(self, device, dtype, op, layout):
        self._test_unary_binary_equality(device, dtype, op, layout)

    @ops(mt_reduction_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
    def test_reduction_all(self, device, dtype, op, layout):
        # argmin and argmax are not currently supported for torch.sparse_csr
        if op.name in {"argmin", "argmax"} and layout == torch.sparse_csr:
            return

        self._test_reduction_equality(device, dtype, op, layout)
|
|
|
|
|
|
# Device types the device-generic test classes are instantiated for.
only_for = ("cpu", "cuda")
# TestOperators and TestBasics methods take a `device` argument, so they are expanded
# into per-device classes via instantiate_device_type_tests.
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)

instantiate_device_type_tests(TestBasics, globals(), only_for=only_for)
# Device-agnostic classes: only their @parametrize decorators (if any) need expanding.
instantiate_parametrized_tests(TestUnary)
instantiate_parametrized_tests(TestBinary)
instantiate_parametrized_tests(TestReductions)

if __name__ == '__main__':
    run_tests()
|