import contextlib
import importlib
import sys

import torch
import torch.testing
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    TEST_WITH_CROSSREF,
    TEST_WITH_TORCHDYNAMO,
    TestCase as TorchTestCase,
)

from . import config, reset, utils


def run_tests(needs=()):
    """Run the calling module's tests unless the environment cannot support them.

    The run is skipped entirely (the function returns without running anything)
    when any of the following holds:

    * ``TEST_WITH_TORCHDYNAMO``, ``IS_WINDOWS`` or ``TEST_WITH_CROSSREF`` is set,
      or the interpreter is Python 3.12 or newer;
    * a requirement listed in ``needs`` is not satisfied.

    Args:
        needs: A single requirement name or an iterable of names. The special
            name ``"cuda"`` requires ``torch.cuda.is_available()``; every other
            name is treated as a module that must be importable.
    """
    # Imported locally; within this function the name refers to the common_utils
    # runner rather than to this wrapper.
    from torch.testing._internal.common_utils import run_tests

    if (
        TEST_WITH_TORCHDYNAMO
        or IS_WINDOWS
        or TEST_WITH_CROSSREF
        or sys.version_info >= (3, 12)
    ):
        return  # skip testing

    if isinstance(needs, str):
        needs = (needs,)
    for need in needs:
        if need == "cuda":
            # "cuda" is a capability check, not a module name. It must not fall
            # through to the import probe below: there is normally no importable
            # top-level ``cuda`` module, so probing it would skip the tests even
            # on machines where CUDA is available.
            if not torch.cuda.is_available():
                return
            continue
        try:
            importlib.import_module(need)
        except ImportError:
            return
    run_tests()


class TestCase(TorchTestCase):
    """Base test case that patches config per class and resets state per test."""

    @classmethod
    def tearDownClass(cls):
        """Undo the config patch installed by ``setUpClass``."""
        cls._exit_stack.close()
        super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        """Enter a class-wide config patch, held open on ``cls._exit_stack``."""
        super().setUpClass()
        # The ExitStack keeps the patch active until tearDownClass closes it.
        cls._exit_stack = contextlib.ExitStack()
        cls._exit_stack.enter_context(
            config.patch(raise_on_ctx_manager_usage=True, suppress_errors=False),
        )

    def setUp(self):
        """Call the package-level ``reset()`` and clear counters before each test."""
        super().setUp()
        reset()
        utils.counters.clear()

    def tearDown(self):
        """Print the counters collected by the test, then reset and clear them."""
        for k, v in utils.counters.items():
            print(k, v.most_common())
        reset()
        utils.counters.clear()
        super().tearDown()