# Source: pytorch/test/ao/sparsity/test_sparsity_utils.py
# (150 lines, 5.6 KiB, Python)

# Owner(s): ["module: unknown"]
import logging
import torch
from torch.ao.pruning.sparsifier.utils import (
fqn_to_module,
get_arg_info_from_tensor_fqn,
module_to_fqn,
)
from torch.testing._internal.common_quantization import (
ConvBnReLUModel,
ConvModel,
FunctionalLinear,
LinearAddModel,
ManualEmbeddingBagLinear,
SingleLayerLinearModel,
TwoLayerLinearModel,
)
from torch.testing._internal.common_utils import TestCase
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
model_list = [
ConvModel,
SingleLayerLinearModel,
TwoLayerLinearModel,
LinearAddModel,
ConvBnReLUModel,
ManualEmbeddingBagLinear,
FunctionalLinear,
]
class TestSparsityUtilFunctions(TestCase):
    """Tests for the fqn helpers in ``torch.ao.pruning.sparsifier.utils``.

    Every test iterates over the module-level ``model_list`` and checks each
    model class inside ``subTest`` so a failure reports which model caused it.
    """

    @staticmethod
    def _join_fqn(module_fqn, tensor_name):
        """Return the fqn of attribute ``tensor_name`` on the module at ``module_fqn``.

        The root module's fqn is ``""``; in that case the tensor name is
        returned unchanged instead of gaining a leading ``"."``.
        """
        return f"{module_fqn}.{tensor_name}" if module_fqn else tensor_name

    def test_module_to_fqn(self):
        """
        Tests that module_to_fqn works as expected when compared to known good
        module.get_submodule(fqn) function
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                # modules() yields the root model itself and then every submodule,
                # so the root is covered exactly once.
                for module in model.modules():
                    fqn = module_to_fqn(model, module)
                    self.assertIsNotNone(fqn)
                    self.assertIs(model.get_submodule(fqn), module)

    def test_module_to_fqn_fail(self):
        """
        Tests that module_to_fqn returns None when the target module is not
        part of the model
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                # A freshly constructed Linear cannot be a submodule of `model`.
                fqn = module_to_fqn(model, torch.nn.Linear(3, 3))
                self.assertIsNone(fqn)

    def test_module_to_fqn_root(self):
        """
        Tests that module_to_fqn returns '' when model and target module are the same
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                fqn = module_to_fqn(model, model)
                self.assertEqual(fqn, "")

    def test_fqn_to_module(self):
        """
        Tests that fqn_to_module operates as inverse
        of module_to_fqn
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for module in model.modules():
                    fqn = module_to_fqn(model, module)
                    self.assertIsNotNone(fqn)
                    self.assertIs(fqn_to_module(model, fqn), module)

    def test_fqn_to_module_fail(self):
        """
        Tests that fqn_to_module returns None when it tries to
        find an fqn of a module outside the model
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                check_module = fqn_to_module(model, "foo.bar.baz")
                self.assertIsNone(check_module)

    def test_fqn_to_module_for_tensors(self):
        """
        Tests that fqn_to_module works for tensors, actually all parameters
        of the model. This is tested by identifying a module with a tensor,
        and generating the tensor_fqn using module_to_fqn on the module +
        the name of the tensor.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for module in model.modules():
                    module_fqn = module_to_fqn(model, module)
                    self.assertIsNotNone(module_fqn)
                    # recurse=False: only parameters owned directly by this
                    # module, so each one is paired with its own parent.
                    for tensor_name, tensor in module.named_parameters(recurse=False):
                        tensor_fqn = self._join_fqn(module_fqn, tensor_name)
                        # Identity check: the lookup must return the model's own
                        # Parameter object, not merely an equal-valued tensor.
                        self.assertIs(fqn_to_module(model, tensor_fqn), tensor)

    def test_get_arg_info_from_tensor_fqn(self):
        """
        Tests that get_arg_info_from_tensor_fqn works for all parameters of the model.
        Generates a tensor_fqn in the same way as test_fqn_to_module_for_tensors and
        then compares with known (parent) module and tensor_name as well as module_fqn
        from module_to_fqn.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for module in model.modules():
                    module_fqn = module_to_fqn(model, module)
                    self.assertIsNotNone(module_fqn)
                    for tensor_name, _ in module.named_parameters(recurse=False):
                        tensor_fqn = self._join_fqn(module_fqn, tensor_name)
                        arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                        self.assertIs(arg_info["module"], module)
                        self.assertEqual(arg_info["module_fqn"], module_fqn)
                        self.assertEqual(arg_info["tensor_name"], tensor_name)
                        self.assertEqual(arg_info["tensor_fqn"], tensor_fqn)

    def test_get_arg_info_from_tensor_fqn_fail(self):
        """
        Tests that get_arg_info_from_tensor_fqn works as expected for invalid tensor_fqn
        inputs. The string outputs still work but the output module is expected to be None.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                tensor_fqn = "foo.bar.baz"
                arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                # The string fields are derived by splitting the fqn on its last
                # ".", independent of whether the path exists in the model.
                self.assertIsNone(arg_info["module"])
                self.assertEqual(arg_info["module_fqn"], "foo.bar")
                self.assertEqual(arg_info["tensor_name"], "baz")
                self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz")