358 lines
15 KiB
Python
358 lines
15 KiB
Python
# Owner(s): ["module: autograd"]
|
|
|
|
from torch.testing._internal.common_utils import TestCase, run_tests, IS_JETSON, IS_WINDOWS
|
|
import pkgutil
|
|
import torch
|
|
import sys
|
|
from typing import Callable
|
|
import inspect
|
|
import json
|
|
import os
|
|
import unittest
|
|
|
|
|
|
# TODO(jansel): we should remove this workaround once this is fixed:
# https://github.com/pytorch/pytorch/issues/86619
#
# Module names whose public-API mismatches are not reported by
# TestPublicBindings.test_correct_module_names. Every entry lives under
# `torch.fx`, so the set is spelled as suffixes appended to that prefix.
NOT_IMPORTED_WHEN_TEST_WRITTEN = {
    f"torch.fx.{suffix}"
    for suffix in (
        "experimental.normalize",
        "experimental.proxy_tensor",
        "experimental.schema_type_annotation",
        "experimental.symbolic_shapes",
        "passes.backends.cudagraphs",
        "passes.infra.partitioner",
        "passes.utils.fuser_utils",
    )
}
|
|
|
|
|
|
class TestPublicBindings(TestCase):
    def test_no_new_bindings(self):
        """
        This test aims to stop the introduction of new JIT bindings into torch._C
        whose names do not start with _. Such bindings are made available as
        torch.XXX, which may not be desirable.

        If your change causes this test to fail, add your new binding to a relevant
        submodule of torch._C, such as torch._C._jit (or other relevant submodule of
        torch._C). If your binding really needs to be available as torch.XXX, add it
        to torch._C and add it to the allowlist below.

        If you have removed a binding, remove it from the allowlist as well.
        """
        # This allowlist contains every binding in torch._C that is copied into torch at
        # the time of writing. It was generated with
        #
        #   {elem for elem in dir(torch._C) if not elem.startswith("_")}
        #
        torch_C_allowlist_superset = {
            "AggregationType",
            "AliasDb",
            "AnyType",
            "Argument",
            "ArgumentSpec",
            "AwaitType",
            "autocast_decrement_nesting",
            "autocast_increment_nesting",
            "AVG",
            "BenchmarkConfig",
            "BenchmarkExecutionStats",
            "Block",
            "BoolType",
            "BufferDict",
            "StorageBase",
            "CallStack",
            "Capsule",
            "ClassType",
            "clear_autocast_cache",
            "Code",
            "CompilationUnit",
            "CompleteArgumentSpec",
            "ComplexType",
            "ConcreteModuleType",
            "ConcreteModuleTypeBuilder",
            "cpp",
            "CudaBFloat16TensorBase",
            "CudaBoolTensorBase",
            "CudaByteTensorBase",
            "CudaCharTensorBase",
            "CudaComplexDoubleTensorBase",
            "CudaComplexFloatTensorBase",
            "CudaDoubleTensorBase",
            "CudaFloatTensorBase",
            "CudaHalfTensorBase",
            "CudaIntTensorBase",
            "CudaLongTensorBase",
            "CudaShortTensorBase",
            "DeepCopyMemoTable",
            "default_generator",
            "DeserializationStorageContext",
            "device",
            "DeviceObjType",
            "DictType",
            "DisableTorchFunction",
            "DisableTorchFunctionSubclass",
            "DispatchKey",
            "DispatchKeySet",
            "dtype",
            "EnumType",
            "ErrorReport",
            "ExcludeDispatchKeyGuard",
            "ExecutionPlan",
            "FatalError",
            "FileCheck",
            "finfo",
            "FloatType",
            "fork",
            "FunctionSchema",
            "Future",
            "FutureType",
            "Generator",
            "get_autocast_cpu_dtype",
            "get_autocast_ipu_dtype",
            "get_default_dtype",
            "get_num_interop_threads",
            "get_num_threads",
            "Gradient",
            "Graph",
            "GraphExecutorState",
            "has_cuda",
            "has_cudnn",
            "has_lapack",
            "has_mkl",
            "has_mkldnn",
            "has_mps",
            "has_openmp",
            "has_spectral",
            "iinfo",
            "import_ir_module_from_buffer",
            "import_ir_module",
            "InferredType",
            "init_num_threads",
            "InterfaceType",
            "IntType",
            "SymFloatType",
            "SymBoolType",
            "SymIntType",
            "IODescriptor",
            "is_anomaly_enabled",
            "is_anomaly_check_nan_enabled",
            "is_autocast_cache_enabled",
            "is_autocast_cpu_enabled",
            "is_autocast_ipu_enabled",
            "is_autocast_enabled",
            "is_grad_enabled",
            "is_inference_mode_enabled",
            "JITException",
            "layout",
            "ListType",
            "LiteScriptModule",
            "LockingLogger",
            "LoggerBase",
            "memory_format",
            "merge_type_from_type_comment",
            "ModuleDict",
            "Node",
            "NoneType",
            "NoopLogger",
            "NumberType",
            "OperatorInfo",
            "OptionalType",
            "ParameterDict",
            "parse_ir",
            "parse_schema",
            "parse_type_comment",
            "PyObjectType",
            "PyTorchFileReader",
            "PyTorchFileWriter",
            "qscheme",
            "read_vitals",
            "RRefType",
            "ScriptClass",
            "ScriptClassFunction",
            "ScriptDict",
            "ScriptDictIterator",
            "ScriptDictKeyIterator",
            "ScriptList",
            "ScriptListIterator",
            "ScriptFunction",
            "ScriptMethod",
            "ScriptModule",
            "ScriptModuleSerializer",
            "ScriptObject",
            "ScriptObjectProperty",
            "SerializationStorageContext",
            "set_anomaly_enabled",
            "set_autocast_cache_enabled",
            "set_autocast_cpu_dtype",
            "set_autocast_ipu_dtype",
            "set_autocast_cpu_enabled",
            "set_autocast_ipu_enabled",
            "set_autocast_enabled",
            "set_flush_denormal",
            "set_num_interop_threads",
            "set_num_threads",
            "set_vital",
            "Size",
            "StaticModule",
            "Stream",
            "StreamObjType",
            "StringType",
            "SUM",
            "SymFloat",
            "SymInt",
            "TensorType",
            "ThroughputBenchmark",
            "TracingState",
            "TupleType",
            "Type",
            "unify_type_list",
            "UnionType",
            "Use",
            "Value",
            "set_autocast_gpu_dtype",
            "get_autocast_gpu_dtype",
            "vitals_enabled",
            "wait",
            "Tag",
            "set_autocast_xla_enabled",
            "set_autocast_xla_dtype",
            "get_autocast_xla_dtype",
            "is_autocast_xla_enabled",
        }

        torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}

        # Check that the torch._C bindings are all in the allowlist. Since
        # bindings can change based on how PyTorch was compiled (e.g. with/without
        # CUDA), the two may not be an exact match but the bindings should be
        # a subset of the allowlist, i.e. the difference must be empty.
        difference = torch_C_bindings - torch_C_allowlist_superset
        # Sorted so the failure message is deterministic (set iteration order is not).
        msg = f"torch._C had bindings that are not present in the allowlist:\n{sorted(difference)}"
        self.assertFalse(difference, msg)

    # AttributeError: module 'torch.distributed' has no attribute '_shard'
    @unittest.skipIf(IS_WINDOWS or IS_JETSON, "Distributed Attribute Error")
    def test_correct_module_names(self):
        '''
        An API is considered public, if its `__module__` starts with `torch.`
        and there is no name in `__module__` or the object itself that starts with `_`.
        Each public package should either:
        - (preferred) Define `__all__` and all callables and classes in there must have their
          `__module__` start with the current submodule's path. Things not in `__all__` should
          NOT have their `__module__` start with the current submodule.
        - (for simple python-only modules) Not define `__all__` and all the elements in `dir(submod)` must have their
          `__module__` that start with the current submodule.
        '''
        failure_list = []
        allowlist_path = os.path.join(os.path.dirname(__file__), 'allowlist_for_publicAPI.json')
        with open(allowlist_path, encoding="utf-8") as json_file:
            # no new entries should be added to this allow_dict.
            # New APIs must follow the public API guidelines.
            allow_dict = json.load(json_file)
        # Because we want minimal modifications to the `allowlist_for_publicAPI.json`,
        # we are adding the entries for the migrated modules here from the original
        # locations.
        for modname in allow_dict["being_migrated"]:
            if modname in allow_dict:
                allow_dict[allow_dict["being_migrated"][modname]] = allow_dict[modname]

        def test_module(modname):
            """Check every callable/class attribute of the already-imported module `modname`.

            Modules with any dotted component starting with `_` are private and skipped.
            Modules absent from `sys.modules` (never imported) are skipped as well.
            Problems are appended to `failure_list`.
            """
            if any(part.startswith("_") for part in modname.split('.')):
                return
            mod = sys.modules.get(modname)
            if mod is None:
                return

            # verifies that each public API has the correct module name and naming semantics
            def check_one_element(elem, modname, mod, *, is_public, is_all):
                """Append a diagnostic to `failure_list` when `elem` on `mod` is misclassified.

                `is_public` is whether the element is declared public: listed in
                `__all__` when `is_all` is True, otherwise simply not `_`-prefixed.
                The element "looks public" when its name has no leading `_` and its
                `__module__` starts with `modname` and contains no `._` component.
                A mismatch between the two is reported unless `modname` is in
                NOT_IMPORTED_WHEN_TEST_WRITTEN or the element is in `allow_dict`.
                Attributes that are neither callable nor classes are ignored.
                """
                obj = getattr(mod, elem)
                if not (isinstance(obj, Callable) or inspect.isclass(obj)):
                    return
                elem_module = getattr(obj, '__module__', None)
                # Only used for nice error message below
                why_not_looks_public = ""
                if elem_module is None:
                    why_not_looks_public = "because it does not have a `__module__` attribute"
                # If a module is being migrated from foo.a to bar.a (that is entry {"foo": "bar"}),
                # the module's starting package would be referred to as the new location even
                # if there is a "from foo import a" inside the "bar.py".
                modname = allow_dict["being_migrated"].get(modname, modname)
                elem_modname_starts_with_mod = (
                    elem_module is not None
                    and elem_module.startswith(modname)
                    and '._' not in elem_module
                )
                if not why_not_looks_public and not elem_modname_starts_with_mod:
                    why_not_looks_public = f"because its `__module__` attribute (`{elem_module}`) is not within the " \
                        f"torch library or does not start with the submodule where it is defined (`{modname}`)"
                # elem's name must NOT begin with an `_` and its module name
                # SHOULD start with its current module since it's a public API
                looks_public = not elem.startswith('_') and elem_modname_starts_with_mod
                if not why_not_looks_public and not looks_public:
                    why_not_looks_public = f"because it starts with `_` (`{elem}`)"

                if is_public == looks_public:
                    return
                if modname in NOT_IMPORTED_WHEN_TEST_WRITTEN:
                    return
                if modname in allow_dict and elem in allow_dict[modname]:
                    return

                if is_public:
                    why_is_public = (
                        f"it is inside the module's (`{modname}`) `__all__`" if is_all else
                        "it is an attribute that does not start with `_` on a module that "
                        "does not have `__all__` defined"
                    )
                    fix_is_public = (
                        f"remove it from the module's (`{modname}`) `__all__`" if is_all else
                        f"either define a `__all__` for `{modname}` or add a `_` at the beginning of the name"
                    )
                else:
                    # Without `__all__`, every non-`_` name is public, and check_one_element
                    # is only called for non-`_` names, so `is_public` can only be False here
                    # when `__all__` exists.
                    assert is_all
                    why_is_public = f"it is not inside the module's (`{modname}`) `__all__`"
                    fix_is_public = f"add it to the module's (`{modname}`) `__all__`"

                if looks_public:
                    why_looks_public = "it does look public because it follows the rules from the doc above " \
                        "(does not start with `_` and has a proper `__module__`)."
                    fix_looks_public = "make its name start with `_`"
                else:
                    why_looks_public = why_not_looks_public
                    if not elem_modname_starts_with_mod:
                        fix_looks_public = "make sure the `__module__` is properly set and points to a submodule " \
                            f"of `{modname}`"
                    else:
                        fix_looks_public = "remove the `_` at the beginning of the name"

                failure_list.append(f"# {modname}.{elem}:")
                is_public_str = "" if is_public else " NOT"
                failure_list.append(f" - Is{is_public_str} public: {why_is_public}")
                looks_public_str = "" if looks_public else " NOT"
                failure_list.append(f" - Does{looks_public_str} look public: {why_looks_public}")
                # The NOT strings are deliberately swapped below. Applying `fix_is_public` makes
                # is-public match `looks_public`, and vice versa, so neither NOT is rebuilt.
                failure_list.append(" - You can do either of these two things to fix this problem:")
                failure_list.append(f" - To make it{looks_public_str} public: {fix_is_public}")
                failure_list.append(f" - To make it{is_public_str} look public: {fix_looks_public}")

            if hasattr(mod, '__all__'):
                # Built once per module so each membership test below is O(1)
                # rather than a linear scan of the `__all__` sequence.
                public_api = set(mod.__all__)
                for elem in dir(mod):
                    check_one_element(elem, modname, mod, is_public=elem in public_api, is_all=True)
            else:
                for elem in dir(mod):
                    if not elem.startswith('_'):
                        check_one_element(elem, modname, mod, is_public=True, is_all=False)

        for _, modname, _ in pkgutil.walk_packages(path=torch.__path__, prefix=torch.__name__ + '.'):
            test_module(modname)

        test_module('torch')

        msg = "All the APIs below do not meet our guidelines for public API from " \
              "https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation.\n"
        msg += "Make sure that everything that is public is expected (in particular that the module " \
               "has a properly populated `__all__` attribute) and that everything that is supposed to be public " \
               "does look public (it does not start with `_` and has a `__module__` that is properly populated)."
        msg += "\n\nFull list:\n"
        msg += "\n".join(failure_list)

        # empty lists are considered false in python
        self.assertFalse(failure_list, msg)
|
|
|
|
if __name__ == '__main__':
    # Executed directly (not imported): hand off to the shared test runner.
    run_tests()
|