# Owner(s): ["oncall: package/deploy"]
from io import BytesIO
|
|
from textwrap import dedent
|
|
from unittest import skipIf
|
|
|
|
import torch
|
|
from torch.package import PackageExporter, PackageImporter, sys_importer
|
|
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests
# torchvision is an optional dependency for these tests; when it is missing,
# the model tests are skipped rather than failing at import time.
try:
    from torchvision.models import resnet18
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")
# PackageTestCase lives next to this file; the absolute-import fallback keeps
# the file runnable directly (outside the test package).
try:
    from .common import PackageTestCase
except ImportError:
    from common import PackageTestCase
@skipIf(True, "Does not work with recent torchvision, see https://github.com/pytorch/pytorch/issues/81115")
@skipIfNoTorchVision
class ModelTest(PackageTestCase):
    """End-to-end tests that package an entire torchvision model."""

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_resnet(self):
        """Round-trip a pickled resnet18 through a package, then re-export it."""
        resnet = resnet18()

        package_path = self.temp()

        # Package the model together with its code: pickling the object also
        # captures every code file referenced from the pickle.
        with PackageExporter(package_path) as exporter:
            exporter.intern("**")
            exporter.save_pickle("model", "model.pkl", resnet)

        # Load the saved model back out of the archive.
        importer = PackageImporter(package_path)
        reloaded = importer.load_pickle("model", "model.pkl")

        # The reloaded copy must behave exactly like the original.
        example = torch.rand(1, 3, 224, 224)
        expected = resnet(example)
        self.assertEqual(reloaded(example), expected)

        # The importer also exposes the private modules inside the package.
        importer.import_module("torchvision")

        repackaged = BytesIO()
        # Re-saving something that was loaded from a package (e.g. after
        # transfer learning) requires telling the new exporter which importers
        # produced it, so that class names like
        # torchvision.models.resnet.ResNet can be resolved back to source.
        with PackageExporter(repackaged, importer=(importer, sys_importer)) as exporter:
            # Importers are searched in order; the first that resolves a
            # module wins. Under a name collision (a ResNet from two
            # different packages) only objects from the first matching
            # importer will work — this avoids name mangling in the saved
            # source. To truly mix ResNet objects, rebuild the models from a
            # single package's code and transfer weights with
            # save_state_dict/load_state_dict instead.
            exporter.intern("**")
            exporter.save_pickle("model", "model.pkl", reloaded)

        repackaged.seek(0)

        reimporter = PackageImporter(repackaged)
        twice_reloaded = reimporter.load_pickle("model", "model.pkl")
        self.assertEqual(twice_reloaded(example), expected)

    @skipIfNoTorchVision
    def test_model_save(self):
        """Two packaging conventions that a server loads through one API.

        The convention is that every package provides a 'model' module whose
        load() function reads the model out of the archive; how load() is
        implemented is entirely up to the packager.
        """
        resnet = resnet18()

        pickled_pkg = BytesIO()
        # Option 1: pickle the whole model.
        #   + a single line, similar to torch.jit.save
        #   - harder to edit the code after the model is created
        with PackageExporter(pickled_pkg) as exporter:
            exporter.intern("**")
            exporter.save_pickle("model", "pickled", resnet)
            # This loader source is identical for every model packaged this
            # way, so a generic packaging API could emit it.
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                # server knows to call model.load() to get the model,
                # maybe in the future it passes options as arguments by convension
                def load():
                    return resources.load_pickle('model', 'pickled')
                """
            )
            exporter.save_source_string("model", src, is_package=True)

        state_dict_pkg = BytesIO()
        # Option 2: save just the state dict.
        #   - more save/load code to write
        #   + that code can be edited later to adapt the model
        with PackageExporter(state_dict_pkg) as exporter:
            exporter.intern("**")
            exporter.save_pickle("model", "state_dict", resnet.state_dict())
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                from torchvision.models.resnet import resnet18
                def load():
                    # if you want, you can later edit how resnet is constructed here
                    # to edit the model in the package, while still loading the original
                    # state dict weights
                    r = resnet18()
                    state_dict = resources.load_pickle('model', 'state_dict')
                    r.load_state_dict(state_dict)
                    return r
                """
            )
            exporter.save_source_string("model", src, is_package=True)

        # Whichever option was chosen, the server consumes both packages the
        # same way, and both must produce the same output.
        example = torch.rand(1, 3, 224, 224)
        outputs = []
        for archive in [pickled_pkg, state_dict_pkg]:
            archive.seek(0)
            importer = PackageImporter(archive)
            model = importer.import_module("model").load()
            outputs.append(model(example))

        self.assertEqual(*outputs)

    @skipIfNoTorchVision
    def test_script_resnet(self):
        """A model loaded from a package must still be TorchScript-able."""
        resnet = resnet18()

        package_buf = BytesIO()
        # Pickle the whole model into a package.
        with PackageExporter(package_buf) as exporter:
            exporter.intern("**")
            exporter.save_pickle("model", "pickled", resnet)

        package_buf.seek(0)

        importer = PackageImporter(package_buf)
        loaded = importer.load_pickle("model", "pickled")

        # Model should script successfully.
        scripted = torch.jit.script(loaded)

        # And the scripted module should survive a jit save/load round trip.
        script_buf = BytesIO()
        torch.jit.save(scripted, script_buf)
        script_buf.seek(0)
        rescripted = torch.jit.load(script_buf)

        example = torch.rand(1, 3, 224, 224)
        self.assertEqual(rescripted(example), resnet(example))
# Standard PyTorch test entry point: only runs when executed directly.
if __name__ == "__main__":
    run_tests()