83 lines
3.6 KiB
Python
83 lines
3.6 KiB
Python
# Owner(s): ["module: unknown"]
|
|
|
|
import hypothesis.strategies as st
|
|
from hypothesis import given
|
|
import numpy as np
|
|
import torch
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
import torch.testing._internal.hypothesis_utils as hu
|
|
# Module-level guard run at import time. Presumably this checks that
# Hypothesis's per-example deadline is disabled for these property tests
# (helper lives in torch.testing._internal.hypothesis_utils) — TODO confirm.
hu.assert_deadline_disabled()
|
|
|
|
|
|
class PruningOpTest(TestCase):
    """Check ``torch._rowwise_prune`` against a pure-Python reference."""

    # Weight dtypes exercised by both property tests below. It is defined once
    # in the class body so the two ``@given`` decorators cannot drift apart.
    _WEIGHT_DTYPES = st.sampled_from([torch.float64, torch.float32,
                                      torch.float16, torch.int8,
                                      torch.int16, torch.int32, torch.int64])

    # Generate a rowwise mask vector based on an indicator and a threshold value.
    # ``indicator`` holds one value per weight row and represents that row's
    # importance. A row is masked out (pruned) when its indicator value is
    # less than the threshold, so mask[i] is True exactly for the rows kept.
    def _generate_rowwise_mask(self, embedding_rows):
        indicator = torch.from_numpy(
            np.random.random_sample(embedding_rows).astype(np.float32))
        threshold = float(np.random.random_sample())
        # A vectorized comparison yields a 1-D bool tensor directly. It matches
        # the old per-element ``val >= threshold`` loop, which likewise
        # compared a float32 tensor value against a Python float.
        return indicator >= threshold

    def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype):
        """Run ``torch._rowwise_prune`` on random weights and a random mask and
        compare it with the reference.

        Args:
            embedding_rows: number of weight rows (also the mask length).
            embedding_dims: number of columns per weight row.
            indices_type: dtype expected for the compressed index map.
            weights_dtype: dtype of the randomly generated weight matrix.
        """
        shape = (embedding_rows, embedding_dims)
        if weights_dtype in (torch.int8, torch.int16, torch.int32, torch.int64):
            embedding_weights = torch.randint(0, 100, shape, dtype=weights_dtype)
        else:
            embedding_weights = torch.rand(shape, dtype=weights_dtype)

        mask = self._generate_rowwise_mask(embedding_rows)

        # Reference implementation: keep the rows selected by ``mask``. Map each
        # original row index to its position among the kept rows, or to -1
        # when the row was pruned.
        def get_reference_result(embedding_weights, mask, indices_type):
            num_embeddings = mask.size(0)
            compressed_idx_out = torch.zeros(num_embeddings, dtype=indices_type)
            pruned_weights_out = embedding_weights[mask]
            idx = 0
            for i in range(num_embeddings):
                if mask[i]:
                    compressed_idx_out[i] = idx
                    idx += 1
                else:
                    compressed_idx_out[i] = -1
            return pruned_weights_out, compressed_idx_out

        pt_pruned_weights, pt_compressed_indices_map = torch._rowwise_prune(
            embedding_weights, mask, indices_type)
        ref_pruned_weights, ref_compressed_indices_map = get_reference_result(
            embedding_weights, mask, indices_type)

        torch.testing.assert_close(pt_pruned_weights, ref_pruned_weights)
        self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map)
        # The op must honour the requested index dtype, not just the values.
        self.assertEqual(pt_compressed_indices_map.dtype, indices_type)

    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=_WEIGHT_DTYPES,
    )
    def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype)

    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=_WEIGHT_DTYPES,
    )
    def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
run_tests()
|