pytorch/test/test_module_init.py

# Owner(s): ["module: nn"]
import inspect
import torch
from unittest import mock
from unittest.mock import MagicMock, patch
from torch.testing._internal.common_dtype import floating_types
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_utils import TestCase, run_tests
# Returns a database of args & kwargs that can be used to construct each module.
# Each entry is in class -> (args, kwargs) format.
# Example: torch.nn.Linear -> ([10, 5], {})
# TODO: Merge this in with the initial ModuleInfo implementation.
def build_constructor_arg_db():
return {
torch.nn.AdaptiveAvgPool1d: ((5,), {}),
torch.nn.AdaptiveAvgPool2d: ((5,), {}),
torch.nn.AdaptiveAvgPool3d: ((5,), {}),
torch.nn.AdaptiveLogSoftmaxWithLoss: ((100, 20, [5, 10, 15]), {}),
torch.nn.AdaptiveMaxPool1d: ((5,), {}),
torch.nn.AdaptiveMaxPool2d: ((5,), {}),
torch.nn.AdaptiveMaxPool3d: ((5,), {}),
torch.nn.AlphaDropout: ((), {}),
torch.nn.AvgPool1d: ((3,), {}),
torch.nn.AvgPool2d: ((3,), {}),
torch.nn.AvgPool3d: ((3,), {}),
torch.nn.BCELoss: ((), {}),
torch.nn.BCEWithLogitsLoss: ((), {}),
torch.nn.BatchNorm1d: ((5,), {}),
torch.nn.BatchNorm2d: ((5,), {}),
torch.nn.BatchNorm3d: ((5,), {}),
torch.nn.Bilinear: ((2, 3, 4), {}),
torch.nn.CELU: ((), {}),
torch.nn.CTCLoss: ((), {}),
torch.nn.ChannelShuffle: ((4,), {}),
torch.nn.ConstantPad1d: ((2, 3.5), {}),
torch.nn.ConstantPad2d: ((2, 3.5), {}),
torch.nn.ConstantPad3d: ((2, 3.5), {}),
torch.nn.CircularPad1d: ((2,), {}),
torch.nn.CircularPad2d: ((2,), {}),
torch.nn.CircularPad3d: ((2,), {}),
torch.nn.Conv1d: ((3, 3, 3), {}),
torch.nn.Conv2d: ((3, 3, 3), {}),
torch.nn.Conv3d: ((3, 3, 3), {}),
torch.nn.ConvTranspose1d: ((3, 3, 3), {}),
torch.nn.ConvTranspose2d: ((3, 3, 3), {}),
torch.nn.ConvTranspose3d: ((3, 3, 3), {}),
torch.nn.CosineEmbeddingLoss: ((), {}),
torch.nn.CosineSimilarity: ((), {}),
torch.nn.CrossEntropyLoss: ((), {}),
torch.nn.CrossMapLRN2d: ((5,), {}),
torch.nn.Dropout1d: ((), {}),
torch.nn.Dropout2d: ((), {}),
torch.nn.Dropout3d: ((), {}),
torch.nn.Dropout: ((), {}),
torch.nn.ELU: ((), {}),
torch.nn.Embedding: ((10, 5), {}),
torch.nn.EmbeddingBag: ((10, 5), {}),
torch.nn.FeatureAlphaDropout: ((), {}),
torch.nn.Flatten: ((), {}),
torch.nn.Fold: ((5, 2), {}),
torch.nn.FractionalMaxPool2d: ((5, 2), {}),
torch.nn.FractionalMaxPool3d: ((5, 2), {}),
torch.nn.GELU: ((), {}),
torch.nn.GLU: ((), {}),
torch.nn.GRU: ((5, 10), {}),
torch.nn.GRUCell: ((5, 10), {}),
torch.nn.GaussianNLLLoss: ((), {}),
torch.nn.GroupNorm: ((3, 6, 1e-5, True), {}),
torch.nn.Hardshrink: ((), {}),
torch.nn.Hardsigmoid: ((), {}),
torch.nn.Hardswish: ((), {}),
torch.nn.Hardtanh: ((), {}),
torch.nn.HingeEmbeddingLoss: ((), {}),
torch.nn.HuberLoss: ((), {}),
torch.nn.Identity: ((), {}),
torch.nn.InstanceNorm1d: ((5, 1e-5, 0.1, True), {}),
torch.nn.InstanceNorm2d: ((5, 1e-5, 0.1, True), {}),
torch.nn.InstanceNorm3d: ((5, 1e-5, 0.1, True), {}),
torch.nn.KLDivLoss: ((), {}),
torch.nn.L1Loss: ((), {}),
torch.nn.LPPool1d: ((2, 3), {}),
torch.nn.LPPool2d: ((2, 3), {}),
torch.nn.LSTM: ((5, 10), {}),
torch.nn.LSTMCell: ((5, 10), {}),
torch.nn.LayerNorm: ((2,), {}),
torch.nn.LazyBatchNorm1d: ((), {}),
torch.nn.LazyBatchNorm2d: ((), {}),
torch.nn.LazyBatchNorm3d: ((), {}),
torch.nn.LazyConv1d: ((5, 2), {}),
torch.nn.LazyConv2d: ((5, 2), {}),
torch.nn.LazyConv3d: ((5, 2), {}),
torch.nn.LazyConvTranspose1d: ((5, 2), {}),
torch.nn.LazyConvTranspose2d: ((5, 2), {}),
torch.nn.LazyConvTranspose3d: ((5, 2), {}),
torch.nn.LazyInstanceNorm1d: ((), {}),
torch.nn.LazyInstanceNorm2d: ((), {}),
torch.nn.LazyInstanceNorm3d: ((), {}),
torch.nn.LazyLinear: ((5,), {}),
torch.nn.LeakyReLU: ((), {}),
torch.nn.Linear: ((10, 5), {}),
torch.nn.LocalResponseNorm: ((2,), {}),
torch.nn.LogSigmoid: ((), {}),
torch.nn.LogSoftmax: ((), {}),
torch.nn.MSELoss: ((), {}),
torch.nn.MarginRankingLoss: ((), {}),
torch.nn.MaxPool1d: ((3,), {}),
torch.nn.MaxPool2d: ((3,), {}),
torch.nn.MaxPool3d: ((3,), {}),
torch.nn.MaxUnpool1d: ((5,), {}),
torch.nn.MaxUnpool2d: ((5,), {}),
torch.nn.MaxUnpool3d: ((5,), {}),
torch.nn.Mish: ((), {}),
torch.nn.ModuleDict: ((), {}),
torch.nn.ModuleList: ((), {}),
torch.nn.MultiLabelMarginLoss: ((), {}),
torch.nn.MultiLabelSoftMarginLoss: ((), {}),
torch.nn.MultiMarginLoss: ((), {}),
torch.nn.MultiheadAttention: ((100, 2), {}),
torch.nn.NLLLoss2d: ((), {}),
torch.nn.NLLLoss: ((), {}),
torch.nn.PReLU: ((), {}),
torch.nn.PairwiseDistance: ((), {}),
torch.nn.ParameterDict: ((), {}),
torch.nn.ParameterList: ((), {}),
torch.nn.PixelShuffle: ((2,), {}),
torch.nn.PixelUnshuffle: ((2,), {}),
torch.nn.PoissonNLLLoss: ((), {}),
torch.nn.RNN: ((5, 10), {}),
torch.nn.RNNBase: (('LSTM', 5, 10), {}),
torch.nn.RNNCell: ((5, 10), {}),
torch.nn.RNNCellBase: ((5, 10, True, 2), {}),
torch.nn.RReLU: ((), {}),
torch.nn.ReLU6: ((), {}),
torch.nn.ReLU: ((), {}),
torch.nn.ReflectionPad1d: ((2,), {}),
torch.nn.ReflectionPad2d: ((2,), {}),
torch.nn.ReflectionPad3d: ((2,), {}),
torch.nn.ReplicationPad1d: ((2,), {}),
torch.nn.ReplicationPad2d: ((2,), {}),
torch.nn.ReplicationPad3d: ((2,), {}),
torch.nn.SELU: ((), {}),
torch.nn.Sequential: ((), {}),
torch.nn.SiLU: ((), {}),
torch.nn.Sigmoid: ((), {}),
torch.nn.SmoothL1Loss: ((), {}),
torch.nn.SoftMarginLoss: ((), {}),
torch.nn.Softmax2d: ((), {}),
torch.nn.Softmax: ((), {}),
torch.nn.Softmin: ((), {}),
torch.nn.Softplus: ((), {}),
torch.nn.Softshrink: ((), {}),
torch.nn.Softsign: ((), {}),
torch.nn.SyncBatchNorm: ((5,), {}),
torch.nn.Tanh: ((), {}),
torch.nn.Tanhshrink: ((), {}),
torch.nn.Threshold: ((0.1, 20), {}),
torch.nn.Transformer: ((), {}),
torch.nn.TransformerDecoder: ((torch.nn.TransformerDecoderLayer, 3), {}),
torch.nn.TransformerDecoderLayer: ((10, 2), {}),
torch.nn.TransformerEncoder: ((torch.nn.TransformerEncoderLayer, 3), {}),
torch.nn.TransformerEncoderLayer: ((10, 2), {}),
torch.nn.TripletMarginLoss: ((), {}),
torch.nn.TripletMarginWithDistanceLoss: ((), {}),
torch.nn.Unflatten: ((1, (2, 5, 5)), {}),
torch.nn.Unfold: ((3,), {}),
torch.nn.Upsample: ((), {}),
torch.nn.UpsamplingBilinear2d: ((), {}),
torch.nn.UpsamplingNearest2d: ((), {}),
torch.nn.ZeroPad1d: ((0,), {}),
torch.nn.ZeroPad2d: ((0,), {}),
torch.nn.ZeroPad3d: ((0,), {}),
torch.ao.nn.qat.Conv1d: ((3, 3, 3), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.ao.nn.qat.Conv2d: ((3, 3, 3), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.ao.nn.qat.Conv3d: ((3, 3, 3), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.ao.nn.qat.Linear: ((5, 2), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.ao.nn.qat.Embedding: ((10, 12), {
'qconfig': torch.ao.quantization.float_qparams_weight_only_qconfig,
}),
torch.ao.nn.qat.EmbeddingBag: ((10, 12), {
'qconfig': torch.ao.quantization.float_qparams_weight_only_qconfig,
}),
torch.ao.nn.quantizable.LSTM: ((5, 6), {}),
torch.ao.nn.quantizable.LSTMCell: ((5, 6), {}),
torch.ao.nn.quantizable.MultiheadAttention: ((10, 2), {}),
torch.ao.nn.quantized.BatchNorm2d: ((2,), {}),
torch.ao.nn.quantized.BatchNorm3d: ((2,), {}),
torch.ao.nn.quantized.Dropout: ((), {}),
torch.ao.nn.quantized.Conv1d: ((3, 3, 3), {}),
torch.ao.nn.quantized.Conv2d: ((3, 3, 3), {}),
torch.ao.nn.quantized.Conv3d: ((3, 3, 3), {}),
torch.ao.nn.quantized.ConvTranspose1d: ((3, 3, 3), {}),
torch.ao.nn.quantized.ConvTranspose2d: ((3, 3, 3), {}),
torch.ao.nn.quantized.ConvTranspose3d: ((16, 33, (3, 3, 5)), {
'stride': (2, 1, 1),
'padding': (4, 2, 2),
'output_padding': (2, 2, 2),
'dilation': (1, 1, 1),
}),
torch.ao.nn.quantized.DeQuantize: ((), {}),
torch.ao.nn.quantized.ELU: ((0.01, 0), {}),
torch.ao.nn.quantized.Embedding: ((10, 3), {
'factory_kwargs': {},
}),
torch.ao.nn.quantized.EmbeddingBag: ((10, 3), {
'factory_kwargs': {},
}),
torch.ao.nn.quantized.GroupNorm: ((2, 4, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.ao.nn.quantized.Hardswish: ((0.1, 0,), {}),
torch.ao.nn.quantized.InstanceNorm1d: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.ao.nn.quantized.InstanceNorm2d: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.ao.nn.quantized.InstanceNorm3d: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.ao.nn.quantized.LayerNorm: ((2, torch.nn.Parameter(torch.tensor(2.)),
torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
torch.ao.nn.quantized.LeakyReLU: ((0.01, 0), {}),
torch.ao.nn.quantized.Linear: ((5, 2), {
'factory_kwargs': {},
}),
torch.ao.nn.quantized.MaxPool2d: ((3,), {}),
torch.ao.nn.quantized.Quantize: ((0.1, 0), {
'dtype': torch.int16,
'factory_kwargs': {},
}),
torch.ao.nn.quantized.ReLU6: ((), {}),
torch.ao.nn.quantized.Sigmoid: ((0.1, 0), {}),
torch.ao.nn.quantized.Softmax: ((), {}),
torch.ao.nn.quantized.FloatFunctional: ((), {}),
torch.ao.nn.quantized.FXFloatFunctional: ((), {}),
torch.ao.nn.quantized.QFunctional: ((), {}),
torch.ao.nn.quantized.PReLU: ((0.01, 0), {}),
}
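# Illustrative sketch (hypothetical helper, not used by the tests below): a DB entry is
# consumed by splatting its (args, kwargs) pair into the module constructor; extra factory
# kwargs can be merged in via instantiate_class() defined further down.
def _example_build_linear_from_db():
    args, kwargs = build_constructor_arg_db()[torch.nn.Linear]
    # Equivalent to torch.nn.Linear(10, 5); e.g.
    # instantiate_class(torch.nn.Linear, args, kwargs, {'dtype': torch.float64})
    # would additionally pass dtype through to the constructor.
    return torch.nn.Linear(*args, **kwargs)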
# Instantiates the given class with the given args / kwargs, merging in any extra kwargs (e.g. device / dtype).
def instantiate_class(cls, args, kwargs, extra_kwargs):
return cls(*args, **kwargs) if extra_kwargs is None else cls(*args, **kwargs, **extra_kwargs)
# Returns a function that calls the real implementation of a method
# in addition to passing args to a mock object.
def mock_wrapper(method):
mock = MagicMock()
def wrapper(self, *args, **kwargs):
mock(*args, **kwargs)
return method(self, *args, **kwargs)
wrapper.mock = mock
return wrapper
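# Illustrative sketch of the pattern used in run_test() below (hypothetical helper, not
# part of the suite): patch a real method with the wrapper so the real logic still runs
# while the attached MagicMock records whether / how it was called.
def _example_mock_wrapper_usage():
    wrapped = mock_wrapper(torch.nn.Module.register_buffer)
    with patch.object(torch.nn.Module, 'register_buffer', wrapped):
        torch.nn.BatchNorm1d(5)  # registers running_mean / running_var via register_buffer
    return wrapped.mock.called  # True: the real register_buffer ran and the mock observed it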
# Returns a set of args / kwargs that can be used to construct the module.
def get_example_args(module_cls, constructor_arg_db, extra_kwargs=None):
assert module_cls in constructor_arg_db, \
f"No entry for {module_cls} in the constructor arg DB. Please add it to pass these tests."
args, kwargs = constructor_arg_db[module_cls]
extra_kwargs = {} if extra_kwargs is None else extra_kwargs
# Recursively instantiate args / kwargs that are class objects.
args = [instantiate_class(arg, *get_example_args(arg, constructor_arg_db), extra_kwargs=extra_kwargs)
if inspect.isclass(arg) else torch.nn.Parameter(arg.to(**extra_kwargs))
if isinstance(arg, torch.nn.Parameter) else arg for arg in args]
kwargs = {k: instantiate_class(v, *get_example_args(v, constructor_arg_db), extra_kwargs=extra_kwargs)
if inspect.isclass(v) else torch.nn.Parameter(v.to(**extra_kwargs))
if isinstance(v, torch.nn.Parameter) else v for k, v in kwargs.items()}
kwargs.update(extra_kwargs)
return args, kwargs
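# Illustrative sketch (hypothetical helper): class-valued args in the DB are instantiated
# recursively, so torch.nn.TransformerEncoder's entry (torch.nn.TransformerEncoderLayer, 3)
# expands to (TransformerEncoderLayer(10, 2), 3) before the outer module is constructed.
def _example_recursive_instantiation():
    db = build_constructor_arg_db()
    args, kwargs = get_example_args(torch.nn.TransformerEncoder, db)
    return torch.nn.TransformerEncoder(*args, **kwargs)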
def generate_test_func(test_cls, module_cls, constructor_arg_db,
verify_kwargs=True, module_is_lazy=False, check_nonexistent_arg=True):
# Generate a function for testing the given module.
@dtypes(*floating_types())
def run_test(test_cls, device, dtype, module_cls=module_cls):
# Check if this module creates parameters or registers buffers.
# The mock magic here passes through to the real Parameter / register_buffer
# logic and is only used to check for calls.
args, kwargs = get_example_args(module_cls, constructor_arg_db)
# Some modules need to pass factory_kwargs so as not to conflict with existing args such as dtype.
module_needs_factory_kwargs = 'factory_kwargs' in kwargs
if module_needs_factory_kwargs:
del kwargs['factory_kwargs']
extra_kwargs = {
'factory_kwargs': {
'device': device,
'dtype': dtype,
}
}
else:
extra_kwargs = {
'device': device,
'dtype': dtype,
}
parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
with patch.object(torch.nn.Parameter, '__new__', parameter_new):
register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
m = module_cls(*args, **kwargs)
module_creates_params_or_buffers = parameter_new.mock.called or register_buffer.mock.called
# == Verify factory kwargs are supported. ==
if verify_kwargs and module_creates_params_or_buffers:
args, kwargs = get_example_args(module_cls, constructor_arg_db,
extra_kwargs=extra_kwargs)
if module_is_lazy:
# Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new):
uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
m = module_cls(*args, **kwargs)
uninit_param_new.mock.assert_has_calls(
[mock.call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
uninit_buffer_new.mock.assert_has_calls(
[mock.call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
else:
# Check device placement and dtype for parameters and buffers.
# Only verify floating point dtypes since that's what the kwarg applies to.
# Note that dtype verification is also skipped if the module requires factory_kwargs.
m = module_cls(*args, **kwargs)
for name, param in m.named_parameters():
test_cls.assertEqual(
str(param.device), device,
f'Parameter {name} is on {param.device.type} instead of the expected device {device}')
if param.dtype.is_floating_point and not module_needs_factory_kwargs:
test_cls.assertEqual(
param.dtype, dtype,
f'Parameter {name} is of dtype {param.dtype} instead of the expected dtype {dtype}')
for name, buffer in m.named_buffers():
test_cls.assertEqual(
str(buffer.device), device,
f'Buffer {name} is on {buffer.device.type} instead of the expected device {device}')
if buffer.dtype.is_floating_point and not module_needs_factory_kwargs:
test_cls.assertEqual(
buffer.dtype, dtype,
f'Buffer {name} is of dtype {buffer.dtype} instead of the expected dtype {dtype}')
# == Verify passing a nonexistent arg errors out. ==
if check_nonexistent_arg:
with test_cls.assertRaises(TypeError):
m = module_cls(*args, **kwargs, nonexistent_arg='foo')
return run_test
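# Minimal stand-alone sketch of the check run_test() performs for non-lazy modules
# (hypothetical helper with assumed defaults; the real test drives this through the
# constructor arg DB, the device/dtype decorators, and the mocks above).
def _example_factory_kwargs_check(device='cpu', dtype=torch.float64):
    m = torch.nn.Linear(10, 5, device=device, dtype=dtype)
    for name, param in m.named_parameters():
        assert str(param.device) == device, f'{name} is on {param.device}, expected {device}'
        assert param.dtype == dtype, f'{name} has dtype {param.dtype}, expected {dtype}'
    return m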
def generate_tests(test_cls, constructor_arg_db):
# test all modules underneath these namespaces...
NAMESPACES = [
torch.nn,
torch.ao.nn.qat,
torch.ao.nn.quantizable,
torch.ao.nn.quantized,
]
# ...except these
MODULES_TO_SKIP = {
torch.nn.Module,
torch.nn.Container, # deprecated
torch.nn.NLLLoss2d, # deprecated
# TODO: Remove these 4 from this list once the ASan issue is fixed.
# See https://github.com/pytorch/pytorch/issues/55396
torch.ao.nn.quantized.Embedding,
torch.ao.nn.quantized.EmbeddingBag,
torch.ao.nn.quantized.LSTM,
torch.ao.nn.quantized.MultiheadAttention,
}
# no need to support kwargs for these modules even though
# they have parameters / buffers, because those are passed in
# already instantiated
MODULES_WITHOUT_KWARGS_SUPPORT = {
torch.nn.BCELoss,
torch.nn.BCEWithLogitsLoss,
torch.nn.CrossEntropyLoss,
torch.nn.FractionalMaxPool2d,
torch.nn.FractionalMaxPool3d,
torch.nn.MultiLabelSoftMarginLoss,
torch.nn.MultiMarginLoss,
torch.nn.NLLLoss,
torch.nn.TransformerDecoder,
torch.nn.TransformerEncoder,
}
# modules that supported arbitrary kwargs before, so don't check that a nonexistent arg raises
MODULES_WITH_PREVIOUS_KWARGS = {
torch.nn.Identity,
}
# lazy modules don't instantiate parameters right away
LAZY_MODULES = {
torch.nn.LazyBatchNorm1d,
torch.nn.LazyBatchNorm2d,
torch.nn.LazyBatchNorm3d,
torch.nn.LazyConv1d,
torch.nn.LazyConv2d,
torch.nn.LazyConv3d,
torch.nn.LazyConvTranspose1d,
torch.nn.LazyConvTranspose2d,
torch.nn.LazyConvTranspose3d,
torch.nn.LazyInstanceNorm1d,
torch.nn.LazyInstanceNorm2d,
torch.nn.LazyInstanceNorm3d,
torch.nn.LazyLinear,
}
# these modules require the FBGEMM backend to instantiate
MODULES_THAT_REQUIRE_FBGEMM = {
torch.ao.nn.quantized.Conv1d,
torch.ao.nn.quantized.Conv2d,
torch.ao.nn.quantized.Conv3d,
torch.ao.nn.quantized.ConvTranspose1d,
torch.ao.nn.quantized.ConvTranspose2d,
torch.ao.nn.quantized.ConvTranspose3d,
torch.ao.nn.quantized.Linear,
}
for namespace in NAMESPACES:
# the "nn" in "torch.nn"
namespace_basename = namespace.__name__.split('.')[-1]
for module_name in namespace.modules.__all__:
# class object for this module (e.g. torch.nn.Linear)
module_cls = getattr(namespace.modules, module_name)
if module_cls in MODULES_TO_SKIP:
continue
verify_kwargs = module_cls not in MODULES_WITHOUT_KWARGS_SUPPORT
module_is_lazy = module_cls in LAZY_MODULES
check_nonexistent_arg = module_cls not in MODULES_WITH_PREVIOUS_KWARGS
# Generate a function for testing this module and setattr it onto the test class.
run_test = generate_test_func(test_cls, module_cls, constructor_arg_db,
verify_kwargs=verify_kwargs,
module_is_lazy=module_is_lazy,
check_nonexistent_arg=check_nonexistent_arg)
test_name = f'test_{namespace_basename}_{module_name}'
if module_cls in MODULES_THAT_REQUIRE_FBGEMM:
run_test = skipIfNoFBGEMM(run_test)
setattr(test_cls, test_name, run_test)
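# Illustrative sketch of what the lazy-module branch in run_test() verifies (hypothetical
# helper, not part of the suite): lazy modules defer allocation, so device / dtype must
# reach the uninitialized parameters rather than an already materialized tensor.
def _example_lazy_factory_kwargs(device='cpu', dtype=torch.float64):
    m = torch.nn.LazyLinear(5, device=device, dtype=dtype)
    return isinstance(m.weight, torch.nn.UninitializedParameter)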
class TestModuleInit(TestCase):
_ignore_not_implemented_error = False
generate_tests(TestModuleInit, build_constructor_arg_db())
instantiate_device_type_tests(TestModuleInit, globals())
if __name__ == '__main__':
run_tests()