# Owner(s): ["oncall: mobile"] import unittest import io import tempfile import torch import torch.utils.show_pickle from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS class TestShowPickle(TestCase): @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows") def test_scripted_model(self): class MyCoolModule(torch.nn.Module): def __init__(self, weight): super().__init__() self.weight = weight def forward(self, x): return x * self.weight m = torch.jit.script(MyCoolModule(torch.tensor([2.0]))) with tempfile.NamedTemporaryFile() as tmp: torch.jit.save(m, tmp) tmp.flush() buf = io.StringIO() torch.utils.show_pickle.main(["", tmp.name + "@*/data.pkl"], output_stream=buf) output = buf.getvalue() self.assertRegex(output, "MyCoolModule") self.assertRegex(output, "weight") if __name__ == '__main__': run_tests()