578 lines
23 KiB
Python
578 lines
23 KiB
Python
# Owner(s): ["module: scatter & gather ops"]
|
|
|
|
from itertools import product
|
|
from functools import partial
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.testing._internal.common_device_type import (
|
|
instantiate_device_type_tests,
|
|
dtypes,
|
|
)
|
|
from torch.testing._internal.common_utils import (
|
|
TestCase,
|
|
run_tests,
|
|
gradcheck,
|
|
parametrize,
|
|
skipIfRocm,
|
|
)
|
|
|
|
|
|
# Reduction names iterated by the loop-based tests below and understood by
# get_default_value().
reductions = ["max", "mean", "min", "sum", "prod"]
|
|
|
|
|
|
def get_default_value(initial_value, reduction):
    """Return the value expected for an empty segment.

    An explicit ``initial_value`` (anything other than ``None``) is returned
    unchanged. Otherwise the fallback for ``reduction`` is looked up; a
    reduction name that is not in the table yields ``None``.
    """
    if initial_value is not None:
        return initial_value

    fallback_by_reduction = {
        "max": -float("Inf"),
        "mean": float("nan"),
        "min": float("Inf"),
        "sum": 0.0,
        "prod": 1.0,
    }
    return fallback_by_reduction.get(reduction)
|
|
|
|
|
|
class TestSegmentReductions(TestCase):
    """Tests for ``torch._segment_reduce``.

    The tests cover the reductions listed in ``reductions`` and describe the
    segments both by ``lengths`` and by ``offsets``.
    """

    def _test_common(
        self,
        reduction,
        device,
        dtype,
        unsafe,
        axis,
        initial_value,
        data_arr,
        lengths_arr,
        expected_arr,
        expected_grad_arr,
        check_backward,
        lengths_dtype=torch.int,
    ):
        """Check forward (and optionally backward) results of ``torch._segment_reduce``.

        The reduction is run twice: once with the segments given as
        ``lengths`` and once with the equivalent ``offsets`` (a leading zero
        followed by the cumulative sum of the lengths along the last
        dimension). Both modes are always checked against ``expected_arr``.

        Args:
            reduction: value passed as ``reduce``.
            device: device on which every tensor is created.
            dtype: dtype of the data and of the expected tensors.
            unsafe: forwarded as ``unsafe``.
            axis: forwarded as ``axis``.
            initial_value: forwarded as ``initial``.
            data_arr: input data accepted by ``torch.tensor``.
            lengths_arr: segment lengths accepted by ``torch.tensor``.
            expected_arr: expected forward output.
            expected_grad_arr: expected gradient of ``output.sum()`` with
                respect to the data; only compared when ``check_backward``
                is true.
            check_backward: when true, also compare ``data.grad`` and run
                ``gradcheck`` for dtypes other than half, bfloat16 and float.
            lengths_dtype: dtype of the lengths (and therefore offsets) tensor.
        """
        lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        # generate offsets from lengths
        zeros_shape = list(lengths.shape)
        zeros_shape[-1] = 1
        offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)

        data = torch.tensor(
            data_arr,
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
        expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
        for mode in ['lengths', 'offsets']:
            segment_reduce_kwargs = dict(
                axis=axis,
                unsafe=unsafe,
                initial=initial_value,
            )
            if mode == 'lengths':
                segment_reduce_kwargs['lengths'] = lengths
            else:
                segment_reduce_kwargs['offsets'] = offsets
            actual_result = torch._segment_reduce(
                data=data,
                reduce=reduction,
                **segment_reduce_kwargs
            )
            self.assertEqual(
                expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
            )

            if not check_backward:
                # Skip only the backward checks; the remaining mode must
                # still have its forward result verified.
                continue

            # Test backward
            actual_result.sum().backward()
            self.assertEqual(
                expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
            )
            # Use a fresh leaf so the next mode starts without an accumulated grad.
            data = data.clone().detach().requires_grad_(True)

            # gradcheck does not work well with bfloat16 or fp16 cpu types
            # also there is small numerical difference with fp32
            if dtype not in [torch.half, torch.bfloat16, torch.float]:
                # gradcheck does not like "nan" input, setting to random 10
                d_non_nan = np.nan_to_num(data_arr, nan=10)
                new_data = torch.tensor(
                    d_non_nan,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                self.assertTrue(
                    gradcheck(
                        lambda x: torch._segment_reduce(
                            data=x,
                            reduce=reduction,
                            **segment_reduce_kwargs
                        ),
                        (new_data,),
                    )
                )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_simple_1d(self, device, dtypes):
        """Reduce 1-D data containing a NaN and a trailing zero-length segment."""
        val_dtype, length_type = dtypes
        lengths = [1, 2, 3, 0]
        data = [1, float("nan"), 3, 4, 5, 5]

        for reduction in reductions:
            for initial in [0, None]:
                # Backward is only checked when an explicit initial value is given.
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [1, float("nan"), 5, default_value]
                    expected_grad = [1, 1, 0, 0, 0.5, 0.5]
                elif reduction == "mean":
                    expected_result = [1, float("nan"), 4.666, default_value]
                    expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [1, float("nan"), 4, default_value]
                    expected_grad = [1.0, 1.0, 0, 1, 0, 0]
                elif reduction == "sum":
                    expected_result = [1, float("nan"), 14, default_value]
                    expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [2, float("nan"), 200, default_value]
                        expected_grad = [2.0, 6.0, float("nan"), 50.0, 40.0, 40.0]
                    else:
                        expected_result = [1, float("nan"), 100, default_value]
                        expected_grad = [1.0, 3.0, float("nan"), 25.0, 20.0, 20.0]
                for axis in [0, -1]:
                    for unsafe in [True, False]:
                        self._test_common(
                            reduction,
                            device,
                            val_dtype,
                            unsafe,
                            axis,
                            initial_value,
                            data,
                            lengths,
                            expected_result,
                            expected_grad,
                            check_backward,
                            length_type,
                        )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_simple_zero_length(self, device, dtypes):
        """Reduce empty data split into two zero-length segments."""
        val_dtype, length_type = dtypes
        lengths = [0, 0]
        # A plain empty list: _test_common builds the tensor itself, so handing
        # it an existing tensor would only trigger a copy-construct warning.
        data = []
        # Explicit initial values substituted for 0 in reductions where 0 is a
        # poor choice (a high number for min; 0 would zero out everything for prod).
        initial_overrides = {"min": 1000, "prod": 2}

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                if initial is not None:
                    initial_value = initial_overrides.get(reduction, initial)
                default_value = get_default_value(initial_value, reduction)
                # Every segment is empty, so each output is the default value
                # and there is no input element to receive a gradient.
                expected_result = [default_value, default_value]
                expected_grad = []
                for axis in [0]:
                    for unsafe in [True, False]:
                        self._test_common(
                            reduction,
                            device,
                            val_dtype,
                            unsafe,
                            axis,
                            initial_value,
                            data,
                            lengths,
                            expected_result,
                            expected_grad,
                            check_backward,
                            length_type,
                        )

    @skipIfRocm
    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d_simple(self, device, dtypes):
        """Reduce 2-D data along axis 0 with NaNs and a trailing empty segment."""
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [1, 2, 3, 0]
        data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]

        for reduction in reductions:
            for initial in [0, None]:
                check_backward = initial is not None
                initial_value = initial
                default_value = get_default_value(initial_value, reduction)
                if reduction == "max":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [4, 3],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1, 1],
                        [1, 0],
                        [0, 1],
                        [1, 0],
                        [0, 0],
                        [0, 1],
                    ]
                elif reduction == "mean":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [3, 2],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [0.5, 0.5],
                        [0.5, 0.5],
                        [0.333, 0.333],
                        [0.333, 0.333],
                        [0.333, 0.333],
                    ]
                elif reduction == "min":
                    if initial is not None:
                        initial_value = 1000  # some high number
                        default_value = get_default_value(initial_value, reduction)
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [2, 1],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1, 0],
                        [0, 1],
                        [0, 1],
                        [0, 0],
                        [1, 0],
                    ]
                elif reduction == "sum":
                    expected_result = [
                        [1, 1],
                        [float("nan"), float("nan")],
                        [9, 6],
                        [default_value, default_value],
                    ]
                    expected_grad = [
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                        [1.0, 1.0],
                    ]
                elif reduction == "prod":
                    if initial is not None:
                        initial_value = 2  # 0 initial_value will zero out everything for prod
                        default_value = get_default_value(initial_value, reduction)
                        expected_result = [
                            [2, 2],
                            [float("nan"), float("nan")],
                            [48, 12],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [2.0, 2.0],
                            [6.0, float("nan")],
                            [float("nan"), 2.0],
                            [12.0, 12.0],
                            [16.0, 6.0],
                            [24.0, 4.0],
                        ]
                    else:
                        expected_result = [
                            [1, 1],
                            [float("nan"), float("nan")],
                            [24, 6],
                            [default_value, default_value],
                        ]
                        expected_grad = [
                            [1.0, 1.0],
                            [3.0, float("nan")],
                            [float("nan"), 1.0],
                            [6.0, 6.0],
                            [8.0, 3.0],
                            [12.0, 2.0],
                        ]
                for unsafe in [True, False]:
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                        # Forward the parametrized lengths dtype so the int64
                        # variant really runs with int64 lengths.
                        length_type,
                    )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    @parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
    def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
        """Check the tabulated ``src``/``indptr`` cases in lengths and offsets mode."""
        val_dtype, length_dtype = dtypes
        # zero-length segments are filled with reduction inits contrary to pytorch_scatter.
        tests = [
            {
                'src': [1, 2, 3, 4, 5, 6],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [3, 12, 0, 6],
                'prod': [2, 60, 1, 6],
                'mean': [1.5, 4, float('nan'), 6],
                'min': [1, 3, float('inf'), 6],
                'max': [2, 5, -float('inf'), 6],
            },
            {
                'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
                'index': [0, 0, 1, 1, 1, 3],
                'indptr': [0, 2, 5, 5, 6],
                'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
                'prod': [[3, 8], [315, 480], [1, 1], [11, 12]],
                'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]],
                'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
                'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
            },
            {
                'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
                'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
                'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
                'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
                'prod': [[3, 315, 1, 11], [48, 80, 12, 1]],
                'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]],
                'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
                'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
            },
            {
                'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
                'index': [[0, 0, 1], [0, 2, 2]],
                'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
                'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
                'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]],
                'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]],
                         [[7, 9], [float('nan'), float('nan')], [11, 12]]],
                'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
                        [[7, 9], [float('inf'), float('inf')], [10, 11]]],
                'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
                        [[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
            },
            {
                'src': [[1, 3], [2, 4]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[4], [6]],
                'prod': [[3], [8]],
                'mean': [[2], [3]],
                'min': [[1], [2]],
                'max': [[3], [4]],
            },
            {
                'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
                'index': [[0, 0], [0, 0]],
                'indptr': [[0, 2], [0, 2]],
                'sum': [[[4, 4]], [[6, 6]]],
                'prod': [[[3, 3]], [[8, 8]]],
                'mean': [[[2, 2]], [[3, 3]]],
                'min': [[[1, 1]], [[2, 2]]],
                'max': [[[3, 3]], [[4, 4]]],
            },
        ]
        for test in tests:
            data = torch.tensor(test['src'], dtype=val_dtype, device=device, requires_grad=True)
            indptr = torch.tensor(test['indptr'], dtype=length_dtype, device=device)
            # The segments lie along the last dimension of indptr.
            dim = indptr.ndim - 1
            # calculate lengths from indptr
            lengths = torch.diff(indptr, dim=dim)
            expected = torch.tensor(test[reduce], dtype=val_dtype, device=device)

            actual_result = torch._segment_reduce(
                data=data,
                reduce=reduce,
                lengths=lengths,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            # test offsets
            actual_result = torch._segment_reduce(
                data=data,
                reduce=reduce,
                offsets=indptr,
                axis=dim,
                unsafe=True,
            )
            self.assertEqual(actual_result, expected)

            if val_dtype == torch.float64:
                def fn(x, mode='lengths'):
                    initial = 1
                    # supply initial values to prevent gradcheck from failing for 0 length segments
                    # where nan/inf are reduction identities that produce nans when calculating the numerical jacobian
                    if reduce == 'min':
                        initial = 1000
                    elif reduce == 'max':
                        initial = -1000
                    # An ordered tuple: the positional order (data, reduce)
                    # must be preserved, which a set cannot guarantee.
                    segment_reduce_args = (x, reduce)
                    segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
                    if mode == 'lengths':
                        segment_reduce_kwargs[mode] = lengths
                    elif mode == 'offsets':
                        segment_reduce_kwargs[mode] = indptr
                    return torch._segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)

                self.assertTrue(
                    gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True),))
                )
                self.assertTrue(
                    gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True),))
                )

    @dtypes(
        *product(
            (torch.half, torch.bfloat16, torch.float, torch.double),
            (torch.int, torch.int64),
        )
    )
    def test_multi_d(self, device, dtypes):
        """Reduce 3-D data along axis 0 and compare against numpy (forward only)."""
        val_dtype, length_type = dtypes
        axis = 0
        lengths = [0, 2, 3, 0]
        data = np.arange(50).reshape(5, 2, 5).tolist()
        expected_grad = []

        # TODO: calculate grad and check correctness
        check_backward = False

        # numpy reference implementation for every reduction name.
        numpy_reductions = {
            "max": np.max,
            "mean": np.mean,
            "min": np.min,
            "sum": np.sum,
            "prod": np.prod,
        }
        # Initial values that differ from the default of 0
        # (min uses some high number).
        initial_overrides = {"min": 1000, "prod": 1}

        for reduction in reductions:
            initial_value = initial_overrides.get(reduction, 0)
            numpy_reduce = numpy_reductions[reduction]
            # Both zero-length segments are expected to hold the initial value.
            empty_segment = np.full((2, 5), initial_value).tolist()
            expected_result = [
                empty_segment,
                numpy_reduce(data[:2], axis=0).tolist(),
                numpy_reduce(data[2:], axis=0).tolist(),
                empty_segment,
            ]
            for unsafe in [True, False]:
                self._test_common(
                    reduction,
                    device,
                    val_dtype,
                    unsafe,
                    axis,
                    initial_value,
                    data,
                    lengths,
                    expected_result,
                    expected_grad,
                    check_backward,
                    # Forward the parametrized lengths dtype so the int64
                    # variant really runs with int64 lengths.
                    length_type,
                )

    @dtypes(torch.int, torch.int64)
    def test_unsafe_flag(self, device, dtype):
        """With ``unsafe=False``, lengths that do not sum to the axis size must raise."""
        length_type = dtype
        lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
        data = torch.arange(6, dtype=torch.float, device=device)

        # test for error on 1-D lengths
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch._segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)

        # test for error on multi-D lengths
        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
        nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
        with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
            torch._segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)
|
|
|
|
|
|
|
|
|
|
# Instantiate the device-type tests for TestSegmentReductions, passing this
# module's globals() to the helper.
instantiate_device_type_tests(TestSegmentReductions, globals())


# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
|