286 lines
10 KiB
Python
286 lines
10 KiB
Python
# Owner(s): ["oncall: package/deploy"]
|
|
|
|
import os
|
|
import zipfile
|
|
from sys import version_info
|
|
from tempfile import TemporaryDirectory
|
|
from textwrap import dedent
|
|
from unittest import skipIf
|
|
|
|
import torch
|
|
from torch.package import PackageExporter, PackageImporter
|
|
from torch.testing._internal.common_utils import (
|
|
IS_FBCODE,
|
|
IS_SANDCASTLE,
|
|
IS_WINDOWS,
|
|
run_tests,
|
|
)
|
|
|
|
# Probe for the optional torchvision dependency; tests that need a real
# vision model are skipped when it cannot be imported.
try:
    from torchvision.models import resnet18
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")
|
|
|
|
|
|
try:
|
|
from .common import PackageTestCase
|
|
except ImportError:
|
|
# Support the case where we run this file directly.
|
|
from common import PackageTestCase
|
|
|
|
from pathlib import Path
|
|
|
|
# Directory that contains this test module.
packaging_directory = Path(__file__).parents[0]
|
|
|
|
|
|
@skipIf(
    IS_FBCODE or IS_SANDCASTLE or IS_WINDOWS,
    "Tests that use temporary files are disabled in fbcode",
)
class DirectoryReaderTest(PackageTestCase):
    """Tests use of DirectoryReader as accessor for opened packages.

    Each test exports a package archive to ``self.temp()``, unzips it into a
    ``TemporaryDirectory`` and reads it back through a ``PackageImporter``
    pointed at the extracted directory instead of the archive file.
    """

    def _import_from_extracted(self, filename, temp_dir):
        """Unzip the package archive ``filename`` into ``temp_dir`` and return a
        ``PackageImporter`` over the extracted copy.

        The extracted records are expected under a folder named after the
        archive file; that folder is what gets handed to ``PackageImporter``.
        The ``ZipFile`` is closed as soon as extraction finishes so no archive
        handle outlives the test.
        """
        with zipfile.ZipFile(filename, "r") as zip_file:
            zip_file.extractall(path=temp_dir)
        return PackageImporter(Path(temp_dir) / Path(filename).name)

    @skipIfNoTorchVision
    @skipIf(True, "Does not work with latest TorchVision, see https://github.com/pytorch/pytorch/issues/81115")
    def test_loading_pickle(self):
        """
        Test basic saving and loading of modules and pickles from a DirectoryReader.
        """
        resnet = resnet18()

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        with TemporaryDirectory() as temp_dir:
            importer = self._import_from_extracted(filename, temp_dir)
            dir_mod = importer.load_pickle("model", "model.pkl")
            example_input = torch.rand(1, 3, 224, 224)
            self.assertEqual(dir_mod(example_input), resnet(example_input))

    def test_loading_module(self):
        """
        Test basic saving and loading of a package from a DirectoryReader.
        """
        import package_a

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_module("package_a")

        with TemporaryDirectory() as temp_dir:
            dir_importer = self._import_from_extracted(filename, temp_dir)
            dir_mod = dir_importer.import_module("package_a")
            self.assertEqual(dir_mod.result, package_a.result)

    def test_loading_has_record(self):
        """
        Test DirectoryReader's has_record().
        """
        import package_a  # noqa: F401

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_module("package_a")

        with TemporaryDirectory() as temp_dir:
            dir_importer = self._import_from_extracted(filename, temp_dir)
            self.assertTrue(dir_importer.zip_reader.has_record("package_a/__init__.py"))
            # A bare package directory name is not itself a record.
            self.assertFalse(dir_importer.zip_reader.has_record("package_a"))

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_resource_reader(self):
        """Tests DirectoryReader as the base for get_resource_reader."""
        filename = self.temp()
        with PackageExporter(filename) as pe:
            # Layout looks like:
            #    package
            #    ├── one/
            #    │   ├── a.txt
            #    │   ├── b.txt
            #    │   ├── c.txt
            #    │   └── three/
            #    │       ├── d.txt
            #    │       └── e.txt
            #    └── two/
            #       ├── f.txt
            #       └── g.txt
            pe.save_text("one", "a.txt", "hello, a!")
            pe.save_text("one", "b.txt", "hello, b!")
            pe.save_text("one", "c.txt", "hello, c!")

            pe.save_text("one.three", "d.txt", "hello, d!")
            pe.save_text("one.three", "e.txt", "hello, e!")

            pe.save_text("two", "f.txt", "hello, f!")
            pe.save_text("two", "g.txt", "hello, g!")

        with TemporaryDirectory() as temp_dir:
            importer = self._import_from_extracted(filename, temp_dir)
            reader_one = importer.get_resource_reader("one")

            # Different behavior from still-zipped archives: the extracted
            # package yields a concrete filesystem path under temp_dir.
            resource_path = os.path.join(
                Path(temp_dir), Path(filename).name, "one", "a.txt"
            )
            self.assertEqual(reader_one.resource_path("a.txt"), resource_path)

            self.assertTrue(reader_one.is_resource("a.txt"))
            self.assertEqual(
                reader_one.open_resource("a.txt").getbuffer(), b"hello, a!"
            )
            # Sub-packages show up in contents() but are not resources.
            self.assertFalse(reader_one.is_resource("three"))
            reader_one_contents = list(reader_one.contents())
            reader_one_contents.sort()
            self.assertSequenceEqual(
                reader_one_contents, ["a.txt", "b.txt", "c.txt", "three"]
            )

            reader_two = importer.get_resource_reader("two")
            self.assertTrue(reader_two.is_resource("f.txt"))
            self.assertEqual(
                reader_two.open_resource("f.txt").getbuffer(), b"hello, f!"
            )
            reader_two_contents = list(reader_two.contents())
            reader_two_contents.sort()
            self.assertSequenceEqual(reader_two_contents, ["f.txt", "g.txt"])

            reader_one_three = importer.get_resource_reader("one.three")
            self.assertTrue(reader_one_three.is_resource("d.txt"))
            self.assertEqual(
                reader_one_three.open_resource("d.txt").getbuffer(), b"hello, d!"
            )
            reader_one_three_contents = list(reader_one_three.contents())
            reader_one_three_contents.sort()
            self.assertSequenceEqual(reader_one_three_contents, ["d.txt", "e.txt"])

            self.assertIsNone(importer.get_resource_reader("nonexistent_package"))

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_package_resource_access(self):
        """Packaged modules should be able to use the importlib.resources API to access
        resources saved in the package.
        """
        mod_src = dedent(
            """\
            import importlib.resources
            import my_cool_resources

            def secret_message():
                return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
            """
        )
        filename = self.temp()
        with PackageExporter(filename) as pe:
            pe.save_source_string("foo.bar", mod_src)
            pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")

        with TemporaryDirectory() as temp_dir:
            dir_importer = self._import_from_extracted(filename, temp_dir)
            self.assertEqual(
                dir_importer.import_module("foo.bar").secret_message(),
                "my sekrit plays",
            )

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_importer_access(self):
        """Packaged code can load text and binary resources through the
        ``torch_package_importer`` module when the package is a directory.
        """
        filename = self.temp()
        with PackageExporter(filename) as he:
            he.save_text("main", "main", "my string")
            he.save_binary("main", "main_binary", b"my string")
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                t = resources.load_text('main', 'main')
                b = resources.load_binary('main', 'main_binary')
                """
            )
            he.save_source_string("main", src, is_package=True)

        with TemporaryDirectory() as temp_dir:
            dir_importer = self._import_from_extracted(filename, temp_dir)
            m = dir_importer.import_module("main")
            self.assertEqual(m.t, "my string")
            self.assertEqual(m.b, b"my string")

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_resource_access_by_path(self):
        """
        Tests that packaged code can use importlib.resources.path.
        """
        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_binary("string_module", "my_string", b"my string")
            src = dedent(
                """\
                import importlib.resources
                import string_module

                with importlib.resources.path(string_module, 'my_string') as path:
                    with open(path, mode='r', encoding='utf-8') as f:
                        s = f.read()
                """
            )
            e.save_source_string("main", src, is_package=True)

        with TemporaryDirectory() as temp_dir:
            dir_importer = self._import_from_extracted(filename, temp_dir)
            m = dir_importer.import_module("main")
            self.assertEqual(m.s, "my string")

    def test_scriptobject_failure_message(self):
        """
        Test basic saving and loading of a ScriptModule in a directory.
        Currently not supported.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        with TemporaryDirectory() as temp_dir:
            # Extraction and importer construction must succeed; only the
            # unpickling of the ScriptObject is expected to raise.
            dir_importer = self._import_from_extracted(filename, temp_dir)
            with self.assertRaisesRegex(
                RuntimeError,
                "Loading ScriptObjects from a PackageImporter created from a "
                "directory is not supported. Use a package archive file instead.",
            ):
                dir_importer.load_pickle("res", "mod.pkl")
|
|
|
|
|
|
if __name__ == "__main__":
    # Support running this test file directly, outside of a test runner.
    run_tests()
|