pytorch/test/package/test_package_fx.py

192 lines
5.9 KiB
Python

# Owner(s): ["oncall: package/deploy"]
from io import BytesIO
import torch
from torch.fx import Graph, GraphModule, symbolic_trace
from torch.package import (
ObjMismatchError,
PackageExporter,
PackageImporter,
sys_importer,
)
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase
# Pass the name "len" to torch.fx.wrap at module scope; test_package_fx_wrap
# below symbolically traces a forward() that calls len() on its input.
torch.fx.wrap("len")
# Call it a second time to make sure repeating it doesn't affect anything.
torch.fx.wrap("len")
class TestPackageFX(PackageTestCase):
    """Tests for compatibility between ``torch.package`` and ``torch.fx``.

    Each test pickles an FX artifact (or a module that is traced afterwards)
    into an in-memory package, loads it back, and compares the outputs of the
    original and the loaded object on random inputs.
    """

    def _round_trip(self, obj, *, intern_all):
        """Package ``obj`` into an in-memory buffer and load it back.

        Args:
            obj: Object pickled as resource ``model.pkl`` of package ``model``.
            intern_all: When true, ``intern("**")`` is called on the exporter
                before the object is pickled.

        Returns:
            A ``(importer, loaded)`` tuple: the ``PackageImporter`` reading the
            buffer and the object it unpickled. Callers keep the importer
            bound so it lives for the rest of the test.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            if intern_all:
                exporter.intern("**")
            exporter.save_pickle("model", "model.pkl", obj)
        # Rewind so the importer reads the package from the start.
        buffer.seek(0)
        importer = PackageImporter(buffer)
        return importer, importer.load_pickle("model", "model.pkl")

    def test_package_fx_simple(self):
        """A symbolically traced module gives the same output after packaging."""

        class SimpleTest(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x + 3.0)

        traced = symbolic_trace(SimpleTest())
        _importer, loaded_traced = self._round_trip(traced, intern_all=False)

        sample = torch.rand(2, 3)
        self.assertEqual(loaded_traced(sample), traced(sample))

    def test_package_then_fx(self):
        """A module loaded from a package can be symbolically traced."""
        from package_a.test_module import SimpleTest

        _importer, loaded = self._round_trip(SimpleTest(), intern_all=True)
        traced = symbolic_trace(loaded)

        sample = torch.rand(2, 3)
        self.assertEqual(loaded(sample), traced(sample))

    def test_package_fx_package(self):
        """A trace of a packaged module can be re-packaged via its importer."""
        from package_a.test_module import SimpleTest

        importer, loaded = self._round_trip(SimpleTest(), intern_all=True)
        traced = symbolic_trace(loaded)

        # Re-save the traced module with a new exporter.
        buffer = BytesIO()
        # This should fail, because we are referencing some globals that are
        # only in the package.
        with self.assertRaises(ObjMismatchError):
            with PackageExporter(buffer) as exporter:
                exporter.intern("**")
                exporter.save_pickle("model", "model.pkl", traced)

        buffer.seek(0)
        # Make the package available to the exporter's environment by listing
        # its importer ahead of the system importer.
        with PackageExporter(buffer, importer=(importer, sys_importer)) as exporter:
            exporter.intern("**")
            exporter.save_pickle("model", "model.pkl", traced)
        buffer.seek(0)
        reimporter = PackageImporter(buffer)
        reloaded = reimporter.load_pickle("model", "model.pkl")

        sample = torch.rand(2, 3)
        self.assertEqual(loaded(sample), reloaded(sample))

    def test_package_fx_with_imports(self):
        """A hand-built graph calling an interned leaf function round-trips."""
        import package_a.subpackage

        # Manually construct a graph that invokes a leaf function
        graph = Graph()
        x_node = graph.placeholder("x")
        y_node = graph.placeholder("y")
        leaf_call = graph.call_function(
            package_a.subpackage.leaf_function, (x_node, y_node)
        )
        sin_call = graph.call_function(torch.sin, (leaf_call,))
        graph.output(sin_call)
        gm = GraphModule(torch.nn.Module(), graph)

        importer, loaded_gm = self._round_trip(gm, intern_all=True)

        sample_x = torch.rand(2, 3)
        sample_y = torch.rand(2, 3)
        self.assertEqual(loaded_gm(sample_x, sample_y), gm(sample_x, sample_y))

        # Check that the packaged version of the leaf_function dependency is
        # not the same as in the outer env.
        packaged_dependency = importer.import_module("package_a.subpackage")
        self.assertIsNot(packaged_dependency, package_a.subpackage)

    def test_package_fx_custom_tracer(self):
        """A GraphModule subclass built from a custom tracer's graph round-trips.

        Checks that the extra ``info`` attribute and the forward output are
        unchanged after loading.
        """
        from package_a.test_all_leaf_modules_tracer import TestAllLeafModulesTracer
        from package_a.test_module import ModWithTwoSubmodsAndTensor, SimpleTest

        class SpecialGraphModule(torch.fx.GraphModule):
            def __init__(self, root, graph, info):
                super().__init__(root, graph)
                self.info = info

        # The same submodule instance is passed for both submodule slots.
        sub_module = SimpleTest()
        module = ModWithTwoSubmodsAndTensor(torch.ones(3), sub_module, sub_module)

        graph = TestAllLeafModulesTracer().trace(module)
        self.assertEqual(graph._tracer_cls, TestAllLeafModulesTracer)

        gm = SpecialGraphModule(module, graph, "secret")
        self.assertEqual(gm._tracer_cls, TestAllLeafModulesTracer)

        _importer, loaded_gm = self._round_trip(gm, intern_all=True)

        # NOTE(review): ``X.__class__`` on a class is its metaclass, so both
        # sides compare metaclass names rather than the class names
        # themselves. Left as-is; confirm the intended comparison before
        # tightening this check.
        self.assertEqual(
            type(loaded_gm).__class__.__name__, SpecialGraphModule.__class__.__name__
        )
        self.assertEqual(loaded_gm.info, "secret")

        sample = torch.randn(3)
        self.assertEqual(loaded_gm(sample), gm(sample))

    def test_package_fx_wrap(self):
        """A traced module whose forward calls ``len`` round-trips."""

        class TestModule(torch.nn.Module):
            def forward(self, a):
                return len(a)

        traced = symbolic_trace(TestModule())
        _importer, loaded_traced = self._round_trip(traced, intern_all=False)

        sample = torch.rand(2, 3)
        self.assertEqual(loaded_traced(sample), traced(sample))
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    run_tests()