384 lines
13 KiB
Python
384 lines
13 KiB
Python
# Owner(s): ["oncall: package/deploy"]
|
|
|
|
import importlib
import importlib.machinery
import importlib.util
from io import BytesIO
from sys import version_info
from textwrap import dedent
from unittest import skipIf

import torch.nn

from torch.package import EmptyMatchError, Importer, PackageExporter, PackageImporter
from torch.package.package_exporter import PackagingError
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests

try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase
|
|
|
|
|
|
class TestDependencyAPI(PackageTestCase):
    """Dependency management API tests.
    - mock()
    - extern()
    - deny()
    """

    def test_extern(self):
        """Externed modules should resolve to the host environment's module objects."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.extern(["package_a.subpackage", "module_a"])
            he.save_source_string("foo", "import package_a.subpackage; import module_a")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob(self):
        """Same checks as test_extern, with the externed modules selected by glob patterns."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.extern(["package_a.*", "module_*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob_allow_empty(self):
        """
        Test that an error is thrown when an extern glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer) as exporter:
                exporter.extern(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    def test_deny(self):
        """
        Test marking packages as "deny" during export.
        """
        buffer = BytesIO()

        with self.assertRaisesRegex(PackagingError, "denied"):
            with PackageExporter(buffer) as exporter:
                exporter.deny(["package_a.subpackage", "module_a"])
                exporter.save_source_string("foo", "import package_a.subpackage")

    def test_deny_glob(self):
        """
        Test marking packages as "deny" using globs instead of package names.
        """
        buffer = BytesIO()
        with self.assertRaises(PackagingError):
            with PackageExporter(buffer) as exporter:
                exporter.deny(["package_a.*", "module_*"])
                exporter.save_source_string(
                    "test_module",
                    dedent(
                        """\
                        import package_a.subpackage
                        import module_a
                        """
                    ),
                )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock(self):
        """Calling an attribute of a mocked module should raise NotImplementedError."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.subpackage", "module_a"])
            # Import something that depends on package_a.subpackage
            he.save_source_string("foo", "import package_a.subpackage")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock_glob(self):
        """Same check as test_mock, with the mocked modules selected by glob patterns."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.*", "module*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    def test_mock_glob_allow_empty(self):
        """
        Test that an error is thrown when a mock glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer) as exporter:
                exporter.mock(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked(self):
        """Pickling an object whose class lives in a mocked module should raise a PackagingError."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with self.assertRaises(PackagingError):
            with PackageExporter(buffer) as he:
                he.mock(include="package_a.subpackage")
                he.intern("**")
                he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked_all(self):
        """Interning package_a.** and mocking everything else should let the pickle be saved."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.intern(include="package_a.**")
            he.mock("**")
            he.save_pickle("obj", "obj.pkl", obj2)

    def test_allow_empty_with_error(self):
        """If an error occurs during packaging, it should not be shadowed by the allow_empty error."""
        buffer = BytesIO()
        with self.assertRaises(ModuleNotFoundError):
            with PackageExporter(buffer) as pe:
                # Even though we did not extern a module that matches this
                # pattern, we want to show the save_module error, not the allow_empty error.

                pe.extern("foo", allow_empty=False)
                pe.save_module("aodoifjodisfj")  # will error

                # we never get here, so technically the allow_empty check
                # should raise an error. However, the error above is more
                # informative to what's actually going wrong with packaging.
                pe.save_source_string("bar", "import foo\n")

    def test_implicit_intern(self):
        """The save_module APIs should implicitly intern the module being saved."""
        import package_a  # noqa: F401

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.save_module("package_a")

    def test_intern_error(self):
        """Failure to handle all dependencies should lead to an error."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer) as he:
                he.save_pickle("obj", "obj.pkl", obj2)

        # NOTE(review): the leading whitespace inside this expected message was
        # reconstructed; confirm it against the PackagingError message layout.
        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module did not match against any action pattern. Extern, mock, or intern it.
                    package_a
                    package_a.subpackage

                Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from!
                """
            ),
        )

        # Interning all dependencies should work. Use a fresh buffer so this
        # export does not share a stream with the failed attempt above.
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.intern(["package_a", "package_a.subpackage"])
            he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(IS_WINDOWS, "extension modules have a different file extension on windows")
    def test_broken_dependency(self):
        """An unpackageable dependency should raise a PackagingError."""

        def create_module(name):
            # Build a module object whose __file__ ends in ".so"; the test case
            # itself stands in as the loader.
            spec = importlib.machinery.ModuleSpec(name, self, is_package=False)  # type: ignore[arg-type]
            module = importlib.util.module_from_spec(spec)
            ns = module.__dict__
            ns["__spec__"] = spec
            ns["__loader__"] = self
            ns["__file__"] = f"{name}.so"
            ns["__cached__"] = None
            return module

        class BrokenImporter(Importer):
            def __init__(self):
                self.modules = {
                    "foo": create_module("foo"),
                    "bar": create_module("bar"),
                }

            def import_module(self, module_name):
                return self.modules[module_name]

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer, importer=BrokenImporter()) as exporter:
                exporter.intern(["foo", "bar"])
                exporter.save_source_string("my_module", "import foo; import bar")

        # NOTE(review): the leading whitespace inside this expected message was
        # reconstructed; confirm it against the PackagingError message layout.
        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module is a C extension module. torch.package supports Python modules only.
                    foo
                    bar

                Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from!
                """
            ),
        )

    def test_invalid_import(self):
        """An incorrectly-formed import should raise a PackagingError."""
        buffer = BytesIO()
        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer) as exporter:
                # This import will fail to load.
                exporter.save_source_string("foo", "from ........ import lol")

        # NOTE(review): the leading whitespace inside this expected message was
        # reconstructed; confirm it against the PackagingError message layout.
        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Dependency resolution failed.
                    foo
                      Context: attempted relative import beyond top-level package

                Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from!
                """
            ),
        )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_repackage_mocked_module(self):
        """Re-packaging a package that contains a mocked module should work correctly."""
        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.mock("package_a")
            exporter.save_source_string("foo", "import package_a")

        buffer.seek(0)
        importer = PackageImporter(buffer)
        foo = importer.import_module("foo")

        # "package_a" should be mocked out.
        with self.assertRaises(NotImplementedError):
            foo.package_a.get_something()

        # Re-package the model, but intern the previously-mocked module and mock
        # everything else.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=importer) as exporter:
            exporter.intern("package_a")
            exporter.mock("**")
            exporter.save_source_string("foo", "import package_a")

        buffer2.seek(0)
        importer2 = PackageImporter(buffer2)
        foo2 = importer2.import_module("foo")

        # "package_a" should still be mocked out.
        with self.assertRaises(NotImplementedError):
            foo2.package_a.get_something()

    def test_externing_c_extension(self):
        """Externing c extensions modules should allow us to still access them especially those found in torch._C."""

        buffer = BytesIO()
        # The C extension module in question is F.gelu which comes from torch._C._nn
        model = torch.nn.TransformerEncoderLayer(
            d_model=64,
            nhead=2,
            dim_feedforward=64,
            dropout=1.0,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        with PackageExporter(buffer) as e:
            e.extern("torch.**")
            e.intern("**")

            e.save_pickle("model", "model.pkl", model)
        buffer.seek(0)
        imp = PackageImporter(buffer)
        imp.load_pickle("model", "model.pkl")
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point used when this test file is executed directly as a script.
    run_tests()
|