113 lines
3.7 KiB
Python
113 lines
3.7 KiB
Python
# Owner(s): ["oncall: package/deploy"]
|
|
|
|
from io import BytesIO
|
|
|
|
from torch.package import PackageExporter, PackageImporter
|
|
from torch.package._mangling import (
|
|
demangle,
|
|
get_mangle_prefix,
|
|
is_mangled,
|
|
PackageMangler,
|
|
)
|
|
from torch.testing._internal.common_utils import run_tests
|
|
|
|
try:
|
|
from .common import PackageTestCase
|
|
except ImportError:
|
|
# Support the case where we run this file directly.
|
|
from common import PackageTestCase
|
|
|
|
|
|
class TestMangling(PackageTestCase):
    """Tests for the helpers in ``torch.package._mangling`` and for the
    module names seen on objects loaded through ``PackageImporter``."""

    def test_unique_manglers(self):
        """Two separate manglers must not mangle the same name identically."""
        first, second = PackageMangler(), PackageMangler()
        self.assertNotEqual(first.mangle("foo.bar"), second.mangle("foo.bar"))

    def test_mangler_is_consistent(self):
        """One mangler asked twice for the same name gives the same answer."""
        mangler = PackageMangler()
        once = mangler.mangle("abc.def")
        again = mangler.mangle("abc.def")
        self.assertEqual(once, again)

    def test_roundtrip_mangling(self):
        """``demangle`` undoes ``mangle`` for a simple name."""
        mangler = PackageMangler()
        mangled = mangler.mangle("foo")
        self.assertEqual("foo", demangle(mangled))

    def test_is_mangled(self):
        """``is_mangled`` accepts mangled names and rejects plain ones."""
        manglers = (PackageMangler(), PackageMangler())
        for mangler in manglers:
            self.assertTrue(is_mangled(mangler.mangle("foo.bar")))

        self.assertFalse(is_mangled("foo.bar"))
        round_tripped = demangle(manglers[0].mangle("foo.bar"))
        self.assertFalse(is_mangled(round_tripped))

    def test_demangler_multiple_manglers(self):
        """
        The module-level ``demangle`` recovers the original name no matter
        which PackageMangler instance produced the mangled one.
        """
        cases = (
            (PackageMangler(), "foo.bar"),
            (PackageMangler(), "bar.foo"),
        )
        for mangler, name in cases:
            self.assertEqual(name, demangle(mangler.mangle(name)))

    def test_mangle_empty_errors(self):
        """Mangling the empty string is expected to raise AssertionError."""
        mangler = PackageMangler()
        self.assertRaises(AssertionError, mangler.mangle, "")

    def test_demangle_base(self):
        """
        Demangling only the leading component of a mangled name (the text
        before the first ".") is currently expected to give an empty string.
        """
        mangler = PackageMangler()
        mangle_parent, _, _ = mangler.mangle("foo").partition(".")
        self.assertEqual("", demangle(mangle_parent))

    def test_mangle_prefix(self):
        """The mangle prefix joined to the plain name by "." is the mangled name."""
        mangled = PackageMangler().mangle("foo.bar")
        prefix = get_mangle_prefix(mangled)
        self.assertEqual(".".join((prefix, "foo.bar")), mangled)

    def test_unique_module_names(self):
        """
        An object loaded from a package must report a module name different
        from the original's, and different again for each separate importer.
        """
        import package_a.subpackage

        inner = package_a.subpackage.PackageASubpackageObject()
        original = package_a.PackageAObject(inner)

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_pickle("obj", "obj.pkl", original)

        # Load the same archive twice through independent importers. The
        # importers are kept referenced for the rest of the test, matching
        # the lifetime they had as named locals.
        importers, loaded = [], []
        for _ in range(2):
            buffer.seek(0)
            importer = PackageImporter(buffer)
            importers.append(importer)
            loaded.append(importer.load_pickle("obj", "obj.pkl"))

        # Modules from loaded packages should not shadow the names of modules.
        # See mangling.md for more info.
        original_module = type(original).__module__
        first_module, second_module = (type(item).__module__ for item in loaded)
        self.assertNotEqual(original_module, first_module)
        self.assertNotEqual(first_module, second_module)

    def test_package_mangler(self):
        """
        ``PackageMangler.demangle`` only reverses names mangled by the same
        instance; names from another instance come back unchanged.
        """
        owner = PackageMangler()
        stranger = PackageMangler()
        mangled = owner.mangle("foo.bar")

        # The mangler that produced the name restores it.
        self.assertEqual(owner.demangle(mangled), "foo.bar")
        # A different mangler leaves the foreign name untouched.
        self.assertEqual(stranger.demangle(mangled), mangled)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed only when this file is run as a script: hand control to the
    # run_tests entry point imported from torch's common test utilities.
    run_tests()
|