150 lines
5.6 KiB
Python
150 lines
5.6 KiB
Python
# Owner(s): ["module: unknown"]
|
|
|
|
|
|
import logging
|
|
|
|
import torch
|
|
from torch.ao.pruning.sparsifier.utils import (
|
|
fqn_to_module,
|
|
get_arg_info_from_tensor_fqn,
|
|
module_to_fqn,
|
|
)
|
|
|
|
from torch.testing._internal.common_quantization import (
|
|
ConvBnReLUModel,
|
|
ConvModel,
|
|
FunctionalLinear,
|
|
LinearAddModel,
|
|
ManualEmbeddingBagLinear,
|
|
SingleLayerLinearModel,
|
|
TwoLayerLinearModel,
|
|
)
|
|
from torch.testing._internal.common_utils import TestCase
|
|
|
|
# Configure the root logger at import time: INFO level, with each record
# prefixed by timestamp, logger name and level. (Per the logging docs,
# basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
|
|
|
|
# Model classes (each constructed with no arguments) that every test in
# TestSparsityUtilFunctions iterates over.
model_list = [
    ConvModel,
    SingleLayerLinearModel,
    TwoLayerLinearModel,
    LinearAddModel,
    ConvBnReLUModel,
    ManualEmbeddingBagLinear,
    FunctionalLinear,
]
|
|
|
|
|
|
class TestSparsityUtilFunctions(TestCase):
    """Tests for the fqn helpers in ``torch.ao.pruning.sparsifier.utils``.

    Each test instantiates every class in ``model_list`` (with no arguments)
    and checks ``module_to_fqn``, ``fqn_to_module`` and
    ``get_arg_info_from_tensor_fqn`` against ``nn.Module`` ground truth
    (``named_modules``, ``named_parameters``, ``get_submodule``). Every model
    runs inside its own ``subTest`` so a failure reports which model broke.
    """

    @staticmethod
    def _tensor_fqn(module_fqn: str, tensor_name: str) -> str:
        """Return the fqn of attribute ``tensor_name`` on the module at ``module_fqn``.

        The root module's fqn is ``""``; no leading ``"."`` is added in that case.
        """
        return f"{module_fqn}.{tensor_name}" if module_fqn else tensor_name

    def test_module_to_fqn(self):
        """
        Tests that module_to_fqn works as expected when compared to known good
        module.get_submodule(fqn) function
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                # named_modules() already yields the root model itself (under
                # the name ""), so it does not need to be appended separately.
                for _, module in model.named_modules():
                    fqn = module_to_fqn(model, module)
                    # get_submodule returns the registered object, so the round
                    # trip must produce the very same module instance.
                    self.assertIs(model.get_submodule(fqn), module)

    def test_module_to_fqn_fail(self):
        """
        Tests that module_to_fqn returns None when the target module is not
        part of the model
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                # A freshly constructed module cannot be registered in model.
                self.assertIsNone(module_to_fqn(model, torch.nn.Linear(3, 3)))

    def test_module_to_fqn_root(self):
        """
        Tests that module_to_fqn returns '' when model and target module are the same
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                self.assertEqual(module_to_fqn(model, model), "")

    def test_fqn_to_module(self):
        """
        Tests that fqn_to_module operates as inverse
        of module_to_fqn
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for _, module in model.named_modules():
                    fqn = module_to_fqn(model, module)
                    self.assertIs(fqn_to_module(model, fqn), module)

    def test_fqn_to_module_fail(self):
        """
        Tests that fqn_to_module returns None when it tries to
        find an fqn of a module outside the model
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                self.assertIsNone(fqn_to_module(model, "foo.bar.baz"))

    def test_fqn_to_module_for_tensors(self):
        """
        Tests that fqn_to_module works for tensors, actually all parameters
        of the model. This is tested by identifying a module with a tensor,
        and generating the tensor_fqn using module_to_fqn on the module +
        the name of the tensor.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for _, module in model.named_modules():
                    module_fqn = module_to_fqn(model, module)
                    # recurse=False: only parameters owned directly by module,
                    # so module_fqn is exactly their parent's fqn.
                    for tensor_name, tensor in module.named_parameters(recurse=False):
                        tensor_fqn = self._tensor_fqn(module_fqn, tensor_name)
                        self.assertIs(fqn_to_module(model, tensor_fqn), tensor)

    def test_get_arg_info_from_tensor_fqn(self):
        """
        Tests that get_arg_info_from_tensor_fqn works for all parameters of the model.
        Generates a tensor_fqn in the same way as test_fqn_to_module_for_tensors and
        then compares with known (parent) module and tensor_name as well as module_fqn
        from module_to_fqn.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for _, module in model.named_modules():
                    module_fqn = module_to_fqn(model, module)
                    for tensor_name, _ in module.named_parameters(recurse=False):
                        tensor_fqn = self._tensor_fqn(module_fqn, tensor_name)
                        arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                        self.assertIs(arg_info["module"], module)
                        self.assertEqual(arg_info["module_fqn"], module_fqn)
                        self.assertEqual(arg_info["tensor_name"], tensor_name)
                        self.assertEqual(arg_info["tensor_fqn"], tensor_fqn)

    def test_get_arg_info_from_tensor_fqn_fail(self):
        """
        Tests that get_arg_info_from_tensor_fqn works as expected for invalid tensor_fqn
        inputs. The string outputs still work but the output module is expected to be None.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                tensor_fqn = "foo.bar.baz"
                arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                self.assertIsNone(arg_info["module"])
                self.assertEqual(arg_info["module_fqn"], "foo.bar")
                self.assertEqual(arg_info["tensor_name"], "baz")
                self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz")
|