# Owner(s): ["module: fx"]

import math
import numbers
import operator
import pickle
import sys
import sympy
import tempfile
import unittest
from types import BuiltinFunctionType
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch
import torch.fx.experimental.meta_tracer
import torch.fx.experimental.optimization as optimization
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.experimental import merge_matmul
from torch.fx.experimental.accelerator_partitioner import Partitioner
from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators
from torch.fx.experimental.partitioner_utils import (
    Device,
    get_latency_of_partitioned_graph,
    get_partition_to_latency_mapping,
    NodeLatency,
    PartitionerConfig,
    PartitionMode,
)
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.operator_schemas import (
    _torchscript_type_to_python_type,
    create_type_hint,
    normalize_function,
    normalize_module,
    type_matches,
)
from torch.fx.passes import graph_manipulation
from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx.passes.split_module import split_module
from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
from torch.testing._internal.jit_utils import JitTestCase

try:
    import torchvision.models
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
skipIfNoMkldnn = unittest.skipIf(
    not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()),
    "no MKLDNN",
)


def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule:
    return GraphModule(
        root if isinstance(root, torch.nn.Module) else torch.nn.Module(),
        RewritingTracer().trace(root),
    )

class TestFXExperimental(JitTestCase):
    def test_find_single_partition(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 150, 1),
            Device("dev_2", 125, 2),
        ]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        assert dag.nodes[0].logical_device_ids == [1]

    def test_lack_of_devices(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        catch_runtime_error = False
        try:
            ret = partitioner.partition_graph(traced, m, partitioner_config)
        except RuntimeError:
            catch_runtime_error = True
        assert catch_runtime_error

    def test_large_node_error(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                linear = self.linear(a)
                add = linear + a
                return add

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 40, 0),
            Device("dev_1", 40, 0),
            Device("dev_2", 40, 0),
            Device("dev_3", 40, 0),
            Device("dev_4", 40, 0),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        catch_runtime_error = False
        try:
            ret = partitioner.partition_graph(traced, m, partitioner_config)
        except RuntimeError:
            catch_runtime_error = True
        assert catch_runtime_error

    def test_partition_node_manipulation(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                add_1 = a + b
                add_2 = add_1 + torch.rand(4)
                add_3 = add_2 + torch.rand(4)
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a, b = torch.rand(4), torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device("dev_0", 1000, 0)]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        partition = partitioner.partitions[0]
        assert partition.used_mem_bytes == 112
        # Select add_2 node to remove
        selected_node = None
        for node in partition.nodes:
            if node.name == "add_2":
                selected_node = node
        partition.remove_node(selected_node)
        assert partition.used_mem_bytes == 80

    def test_size_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.c = torch.rand(4)

            def forward(self, a, b):
                add_1 = a + b
                linear = self.linear(add_1)
                add_2 = linear + self.c
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        for i, node in enumerate(dag.nodes):
            assert node.logical_device_ids == [i]

    def test_partition_device_mapping(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                b = torch.rand(4)
                add_1 = a + b
                linear_1 = self.linear(add_1)
                add_2 = torch.rand(4) + a
                add_3 = add_2 + linear_1
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        for i, node in enumerate(dag.nodes):
            if i == 1:
                assert node.logical_device_ids == [1]
            else:
                assert node.logical_device_ids == [0]

    def test_sparse_nn_partition(self):
        class MyRecommendationModule(torch.nn.Module):
            def create_mlp(self, num_of_layers: int, input_size: int, output_size: int):
                layers = torch.nn.ModuleList()
                for _ in range(num_of_layers):
                    ll = torch.nn.Linear(input_size, output_size)
                    layers.append(ll)
                    layers.append(torch.nn.ReLU())
                return layers

            def __init__(self):
                super().__init__()
                layers = self.create_mlp(4, 4, 4)
                self.bottom_layers = torch.nn.Sequential(*layers)
                layers = self.create_mlp(3, 24, 24)
                self.top_layers = torch.nn.Sequential(*layers)
                self.embedding_layers = torch.nn.ModuleList()
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)
                for i in range(3):
                    el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True)
                    self.embedding_layers.append(el)
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)

            def forward(self, a, b, offset):
                x = self.bottom_layers(a)
                y = []
                c = []
                for i in range(len(self.embedding_layers)):
                    temp = torch.randint(10, (8,))
                    c.append(temp + b)
                for i in range(len(self.embedding_layers)):
                    if i % 2 == 0:
                        y.append(self.embedding_layers[i](c[i], offset))
                    else:
                        y.append(
                            self.embedding_layers[i](torch.randint(10, (8,)), offset)
                        )
                z = torch.cat([x] + y, dim=1)
                p = self.top_layers(z)
                return p

        m = MyRecommendationModule()
        a = torch.rand(2, 4)
        b = torch.randint(10, (8,))
        offset = torch.randint(1, (2,))
        traced = symbolic_trace(m)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset])
        devices = [
            Device("dev_0", 33000000, 0),
            Device("dev_1", 33000000, 1),
            Device("dev_2", 33000000, 2),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
        assert len(module_with_submodules.graph.nodes) == 24

    def test_partition_latency(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Given a fx module, generate node latency for each node
            based on the size of each node
            """
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 2.0 * node.size_bytes.total_size
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)]
        partitioner = Partitioner()
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
        for p in partition_to_latency_mapping:
            if p.partition_id == 0:
                assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0)
            else:
                assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0)
        transfer_rate_bytes_per_sec = 2
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
        )
        assert critical_path_latency_sec == 208.0

    def test_cost_aware_partition(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 1
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2),
            Device("dev_3", 125, 3),
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            mode=PartitionMode.cost_aware,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping,
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions,
            partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec,
        )
        assert critical_path_latency_sec == 160.0

    def test_aot_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.b = torch.rand(4)
                self.c = torch.rand(4)

            def forward(self, a):
                add_1 = a + self.b
                add_2 = self.c + add_1
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        node_to_partition_id = {}
        partition_to_logical_devices = {}
        count = 0
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        for node in traced.graph.nodes:
            if node.op not in {"placeholder", "get_attr", "output"}:
                node_to_partition_id[node] = count
                partition_to_logical_devices[count] = [0]
                count += 1
        devices = [Device("dev_0", 200, 0)]
        partitioner_config = PartitionerConfig(
            devices=devices,
            mode=PartitionMode.aot_based,
            node_to_partition_mapping=node_to_partition_id,
            partition_to_logical_device_mapping=partition_to_logical_devices,
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(module_with_submodules(a), traced(a))
        for node in dag.nodes:
            assert node.size_bytes == 48
            assert node.logical_device_ids == [0]

    def test_replace_target_nodes_with(self):
        class testModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = testModule()
        traced = symbolic_trace(m)
        input1 = torch.randn(1)
        input2 = torch.randn(1)
        assert (input1 + input2) == traced(input1, input2)
        graph_manipulation.replace_target_nodes_with(
            fx_module=traced,
            old_op="call_function",
            old_target=operator.add,
            new_op="call_function",
            new_target=operator.mul,
        )
        assert (input1 * input2) == traced(input1, input2)

    def test_saturate_host(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device("dev_0", 200, 0),
            Device("dev_1", 200, 1),
            Device("dev_2", 100, 2),
            Device("dev_3", 100, 3),
            Device("dev_4", 200, 4),
            Device("dev_5", 100, 5),
        ]
        partitioner = Partitioner()
        # Without host saturation, the model is split into two partitions:
        # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes.
        partitioner_config = PartitionerConfig(devices, saturate_host=True)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))

        partitions = partitioner.partitions
        self.assertEqual(len(partitions), 2)
        # With host saturation, partition 0 is replicated to dev_4 and partition 1
        # is replicated to dev_2.
        self.assertEqual(partitions[0].logical_device_ids, [0, 4])
        self.assertEqual(partitions[1].logical_device_ids, [1, 2])

    @skipIfNoTorchVision
    def test_conv_bn_fusion(self):
        rn18 = resnet18().eval()
        traced = symbolic_trace(rn18)
        fused = optimization.fuse(traced)

        self.assertTrue(
            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )

        N, C, H, W = 20, 3, 224, 224
        inp = torch.randn(N, C, H, W)

        self.assertEqual(fused(inp), rn18(inp))

    def test_conv_bn_fusion_not_running_state(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(32, 64, 3, stride=2)
                self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        model = M().eval()

        traced = symbolic_trace(model)
        fused = optimization.fuse(traced)
        inp = torch.randn([1, 32, 50, 50])

        # The bn must not be folded into the conv since it does not track running stats
        self.assertTrue(
            any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )
        self.assertEqual(fused(inp), model(inp))

    def test_conv_bn_fusion_mixed_dtype(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dtype=torch.bfloat16)
                self.bn = torch.nn.BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        model = M().eval()

        traced = symbolic_trace(model)
        fused = optimization.fuse(traced)
        inp = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)

        self.assertTrue(
            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )
        self.assertEqual(fused(inp), model(inp))

    def test_call_to_assert_no_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == "Assert"
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_meta_tracer(self):
        class MetaTracerTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16)
                self.layernorm = torch.nn.LayerNorm(16)

            def forward(self, x):
                emb = self.emb(x)
                emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device)
                lol = self.layernorm(emb)
                return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol)

        mttm = MetaTracerTestModule()
        for BS in [15, 35]:
            x = torch.zeros(BS, dtype=torch.long).random_(42)
            meta_args = {'x' : x.to(device='meta')}
            gm = torch.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args)
            torch.testing.assert_close(gm(x), mttm(x))

            # Test serialization/deserialization
            with tempfile.TemporaryDirectory() as tmp_dir:
                with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f:
                    pickle.dump(gm, f)

                with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f:
                    loaded = pickle.load(f)

                torch.testing.assert_close(loaded(x), mttm(x))

    def test_call_to_assert_with_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, "test message"
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == "Assert"
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, "test message"):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_empty_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, ""
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == "Assert"
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_multiline_message(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                error_msg = """
An error message with
terrible spacing
                """
                assert a == b, error_msg
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == "Assert"
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        error_msg = """
An error message with
terrible spacing
        """
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, error_msg):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_subgraph_creation(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x, y):
                z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                w = self.linear(y).clamp(min=0.0, max=1.0)
                return z + w

        # symbolically trace model
        my_module = MyModule()
        my_module_traced = symbolic_trace(my_module)

        # random mod partitioning
        partition_counter = 0
        NPARTITIONS = 3

        # Add some random meta info to make sure it is kept around.
        for node in my_module_traced.graph.nodes:
            if node.op != "output":
                node.meta["test_meta_info"] = True

        def mod_partition(node: Node):
            nonlocal partition_counter
            partition = partition_counter % NPARTITIONS
            partition_counter = (partition_counter + 1) % NPARTITIONS
            return partition

        # split module into a module with submodules
        module_with_submodules = split_module(
            my_module_traced, my_module, mod_partition
        )

        # Check that test_meta_info was still on all nodes.
        submodules = dict(module_with_submodules.named_modules())
        for node in module_with_submodules.graph.nodes:
            if node.op == "call_module":
                submod = submodules[node.target]
                self.assertTrue(isinstance(submod, torch.fx.GraphModule))
                for submod_node in submod.graph.nodes:
                    if submod_node.op != "output":
                        stored_op = submod_node.meta.get("test_meta_info")
                        self.assertTrue(stored_op is not None and stored_op)

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)

        orig_out = my_module_traced(x, y)
        submodules_out = module_with_submodules(x, y)

        self.assertEqual(orig_out, submodules_out)

    def test_split_module_dead_code(self):
        class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2  # we want this
                dead_line = x + 2  # this is dead
                return output

        mod = ModWithDeadCode()
        traced = torch.fx.symbolic_trace(mod)

        # split into before (0), target (1), and after (2)
        saw_mul = False

        def split_callback(n):
            nonlocal saw_mul
            if n.target == operator.mul:
                saw_mul = True
                return 1

            if not saw_mul:
                return 0
            if saw_mul:
                return 2

        split = split_module(traced, mod, split_callback)

        x = torch.randn((5,))
        torch.testing.assert_close(
            split(x), traced(x)
        )

    def test_split_module_kwargs_expansion(self):
        class ModuleWithKwargsExpansion(torch.nn.Module):
            def forward(self, x, **kwargs):
                return x + kwargs['foo']

        mod = ModuleWithKwargsExpansion()
        traced = torch.fx.symbolic_trace(mod)

        seen_getitem = False

        def split_callback(n):
            nonlocal seen_getitem
            split_idx = int(seen_getitem)
            if n.target == operator.getitem:
                seen_getitem = True
            return split_idx

        split = split_module(traced, mod, split_callback)

        x = torch.randn(5, 3)
        foo = torch.randn(5, 3)
        torch.testing.assert_close(split(x, foo=foo), traced(x, foo=foo))

    @skipIfNoTorchVision
    def test_subgraph_trivial_resnet(self):
        # Smoke test that trivially splitting ResNet into one partition works.
        # There was previously an issue that caused submodule names to be aliased.
        m = resnet18()
        traced = symbolic_trace(m)
        a = torch.rand(64, 3, 7, 7)
        module_with_submodules = split_module(traced, m, lambda node: 0)
        module_with_submodules(a)

    def test_split_module_default_arg(self):
        class ModelToTrace(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = torch.nn.Linear(512, 512)

            def forward(self, x, targets=None):
                x = self.lin(x)

                if targets is not None:
                    x = x + targets

                return x

        mtt = ModelToTrace()
        traced = torch.fx.symbolic_trace(mtt, concrete_args={'targets': None})

        split = split_module(traced, mtt, lambda node: 0)

        x = torch.randn(50, 512)
        torch.testing.assert_close(split(x), traced(x))

    def test_normalize_binary_operators(self):
        ops_to_test = {
            torch.add,
            torch.mul,
            torch.sub,
            torch.div,
            torch.floor_divide,
            torch.remainder,
            torch.eq,
            torch.ne,
            torch.lt,
            torch.le,
            torch.gt,
            torch.ge,
        }

        # Test Tensor/Tensor callsite
        for op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x, y):
                    return op(x, y)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            x, y = torch.randn(3, 4), torch.randn(3, 4)
            torch.testing.assert_close(traced(x, y), normalized(x, y))
            self.assertFalse(
                any(n.target in ops_to_test for n in normalized.graph.nodes)
            )

        # Test Tensor/scalar callsite
        for op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x):
                    return op(x, 42)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            x = torch.randn(3, 4)
            torch.testing.assert_close(traced(x), normalized(x))
            self.assertFalse(
                any(n.target in ops_to_test for n in normalized.graph.nodes)
            )

    @skipIfNoTorchVision
    def test_normalize_args(self):
        m = resnet18()

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = {torch.nn.BatchNorm2d}
                return type(m) in leaves

        traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        input = torch.randn(5, 3, 224, 224)
        ref_outs = traced(input)

        ShapeProp(traced).propagate(input)
        traced = NormalizeArgs(traced).transform()

        modules = dict(traced.named_modules())

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target != operator.add:
                self.assertEqual(len(node.args), 0)
            elif node.op == "call_module":
                submod_class = modules[node.target].__class__
                nn_class = getattr(torch.nn, submod_class.__name__)
                if submod_class == nn_class:
                    self.assertEqual(len(node.args), 0)
        traced(input)
        self.assertEqual(traced(input), ref_outs)

    def test_normalize_modules_exhaustive(self):
        """
        Exhaustively test `Node.normalized_arguments` on all standard
        torch.nn Module classes
        """
        for test_params in module_tests + new_module_tests:
            if "constructor" not in test_params:
                constructor = getattr(torch.nn, test_params["module_name"])
            else:
                constructor = test_params["constructor"]

            if "constructor_args" not in test_params:
                args = ()
            else:
                args = test_params["constructor_args"]

            mod = constructor(*args)
            # Skip modules that are not standard `torch.nn`
            # instances, including functionals. (functionals
            # are tested in test_normalize_args)
            if mod.__class__.__name__ not in dir(torch.nn):
                continue

            if "input_fn" not in test_params:
                inputs = torch.randn(test_params["input_size"])
            else:
                inputs = test_params["input_fn"]()

            if not isinstance(inputs, (tuple, list)):
                inputs = (inputs,)

            params = ", ".join(f"v{i}" for i in range(len(inputs)))

            # Generate a class to wrap this standard `nn.Module` instance
            test_classname = f"Test{mod.__class__.__name__}"
            test_mod_code = f"""
class {test_classname}(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, {params}):
        return self.mod({params})
            """

            gbls = {"torch": torch}
            exec(test_mod_code, gbls)

            test_instance = gbls[test_classname](mod)
            traced = symbolic_trace(test_instance)

            # Use `Node.normalized_arguments` to get a new set of arguments
            # to feed to the Module. Then, rewrite the node to only take
            # in those arguments as kwargs
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        normalized_args = node.normalized_arguments(traced)
                        normalized_args2 = normalize_module(
                            traced, node.target, node.args, node.kwargs
                        )
                        assert normalized_args == normalized_args2
                        assert normalized_args
                        node.args = normalized_args.args
                        node.kwargs = normalized_args.kwargs

            traced.recompile()

            # These Modules have an RNG in their forward, so testing
            # correctness by comparing outputs is not correct. Skip that
            # check for these
            stochastic_modules = {"FractionalMaxPool2d", "FractionalMaxPool3d", "RReLU"}

            if mod.__class__.__name__ not in stochastic_modules:
                self.assertEqual(traced(*inputs), mod(*inputs))

            traced = NormalizeArgs(symbolic_trace(test_instance)).transform()
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        self.assertEqual(len(node.args), 0)

    def test_normalize_args_preserve_meta(self):
        class MyModule(torch.nn.Module):
            def forward(self, a):
                return torch.add(a, 3)

        m = MyModule()
        traced = symbolic_trace(m)

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.meta["my_key"] = 7
                break
        else:
            self.fail("Didn't find call_function torch.add")

        input = torch.randn(2, 3)
        ShapeProp(traced).propagate(input)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                self.assertTrue("my_key" in node.meta)
                self.assertEqual(node.meta["my_key"], 7)
                break
        else:
            self.fail("Didn't find call_function torch.add")

    def test_normalize_args_preserve_type(self):
        class MyModule(torch.nn.Module):
            def forward(self, a: List[torch.Tensor]):
                return torch.add(a[0], a[1])

        m = MyModule()
        traced = symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(node.type, List[torch.Tensor])

    @skipIfNoTorchVision
    def test_annotate_returns_with_schema(self):
        m = resnet18()

        traced_modules = symbolic_trace(m)
        traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform()
        for node in traced_modules_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                self.assertIn(
                    check,
                    {
                        ("placeholder", "x"),
                        ("call_module", "maxpool"),
                        ("call_function", operator.add),
                        ("call_function", torch.flatten),
                        ("output", "output"),
                    }
                )

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_modules_annotated)

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = {torch.nn.BatchNorm2d}
                return type(m) in leaves

        traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        traced_functionals_annotated = AnnotateTypesWithSchema(
            traced_functionals
        ).transform()
        for node in traced_functionals_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                excluded_nodes = {
                    ("placeholder", "x"),
                    # Return type differs based on boolean dispatch :(
                    ("call_function", torch.nn.functional.max_pool2d),
                    ("output", "output"),
                }
                # AnnotateTypesWithSchema doesn't work with bound C++ functions
                if not isinstance(node.target, BuiltinFunctionType):
                    self.assertIn(check, excluded_nodes)

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_functionals_annotated)

    def test_annotate_getitem_node(self):
        class CustomType:
            pass

        class CustomNamedTuple(NamedTuple):
            x: int
            y: float

        class MyModule(torch.nn.Module):
            def forward(self, inp: Tuple[CustomType, torch.Tensor], inp2: List[CustomType], inp3: CustomNamedTuple):
                inp_0 = inp[0]
                inp_1 = inp[1]
                inp2_0 = inp2[0]
                inp3_x = inp3.x
                inp3_y = inp3.y
                return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y

        my_module = MyModule()
        my_module_traced = torch.fx.symbolic_trace(my_module)

        # by default, fx transform loses type annotation of getitem nodes.
        for node in my_module_traced.graph.nodes:
            if node.target == operator.getitem:
                assert node.type is None

        annotate_getitem_nodes(my_module_traced.graph)

        for node in my_module_traced.graph.nodes:
            if node.target == operator.getitem:
                self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")

    def test_subgraph_uniquename(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b, c, d):
                add_1 = a + b
                add_2 = add_1 + c
                linear_1 = self.linear(add_1)
                add_3 = add_2 + d
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4)
        mm = MyModule()
        traced = symbolic_trace(mm)

        def split_cb(node: torch.fx.Node):
            if node.name == "a" or node.name == "b" or node.name == "add":
                return 0
            else:
                return 1

        module_with_submodule = split_module(traced, mm, split_cb)
        self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d))

    def test_split_qualname_mapping(self):
        d_hid = 4

        class ExampleCode(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.lin = torch.nn.Linear(d_hid, d_hid)

            def forward(self, x):
                x = torch.mm(x, self.mm_param)
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param)
                x = self.lin(x)
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param2)
                x = self.lin(x)
                return x

        my_module = ExampleCode()
        my_module_traced = symbolic_trace(my_module)

        part_idx = 0

        def split_callback(n : torch.fx.Node):
            nonlocal part_idx
            if (n.op, n.target) == ('call_module', 'lin'):
                part_idx += 1
            return part_idx

        # split module into a module with submodules
        qualname_map : Dict[str, str] = {}
        module_with_submodules = split_module(
            my_module_traced, my_module, split_callback, qualname_map
        )
        expected_qualname_map = {
            'submod_1.lin': 'lin', 'submod_2.lin': 'lin'
        }
        self.assertEqual(qualname_map, expected_qualname_map)

    def test_traceable_function_with_nonstandard_name(self):
        def foo(x):
            return torch.relu(x)

        traced = symbolic_trace_with_rewrite(foo)

    def test_to_folder(self):
        class Test(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.W = torch.nn.Parameter(torch.randn(2))
                self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
                self.linear = torch.nn.Linear(2, 2)
                self.attr = torch.randn(2)
                self.register_buffer("attr2", torch.randn(2))
                self.register_buffer("attr3", torch.ones(2, dtype=torch.int32))

            def forward(self, x):
                return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x))

        mod = symbolic_trace(Test())
        module_name = "Foo"
        import tempfile
        from pathlib import Path

        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_dir = Path(tmp_dir)
            mod.to_folder(tmp_dir, module_name)
            # Recipe taken from here:
            # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
            import importlib.util

            spec = importlib.util.spec_from_file_location(
                module_name, tmp_dir / "__init__.py"
            )
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
            t = torch.randn(2, 2)
            self.assertEqual(module.Foo()(t), mod(t))

    def test_fetch(self):
        attrs_for_lowering: Dict[str, List[str]] = {
            "torch.nn.modules.conv.Conv2d": [
                "weight",
                "bias",
                "kernel_size",
                "stride",
                "padding",
                "dilation",
                "groups",
                "padding_mode",
            ],
            "torch.nn.modules.batchnorm.BatchNorm2d": [
                "weight",
                "bias",
                "running_mean",
                "running_var",
                "eps",
            ],
        }

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 2)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, a):
                a = self.conv(a)
                a += a
                return self.bn(a)

        mod = TestModule()
        traced = symbolic_trace(mod)
        lift_lowering_attrs_to_nodes(traced)

        for node in traced.graph.nodes:
            if node.op == "call_module":
                assert hasattr(node, "attrs_for_lowering")
                para_list = attrs_for_lowering[node.attrs_for_lowering["name"]]

                # node.attrs_for_lowering has an additional field for the class name
                assert len(para_list) + 1 == len(node.attrs_for_lowering)
                for p_name in para_list:
                    assert p_name in node.attrs_for_lowering

    def test_merge_matmuls(self):
        """
        A collection of test cases for torch.fx.experimental.merge_matmul,
        a graph transformation that merges matrix multiplication operations.
        """
        # Utility function for counting matmuls for test assertions.
        def _count_matmuls(mod):
            gm = torch.fx.symbolic_trace(mod)

            num_matmuls = 0
            for node in gm.graph.nodes:
                if node.target == torch.matmul:
                    num_matmuls += 1

            return num_matmuls

        # Simple test case in which there are two matmuls of the same size to merge.
        class SimpleMergeMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x, y):
                a = torch.matmul(x, self.rhs)
                b = torch.matmul(y, self.rhs)
                return a + b

        # Initialize inputs.
        a = torch.randn(3, 3)
        b = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct SimpleMergeMatmulModule and call merge_matmul on it.
        module = SimpleMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a, b)
        after = opt_module(a, b)
        before.allclose(after)

        # Basic graph structure check; original module should have 2 matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), 2)
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Test case in which there are multiple matmuls of different sizes to merge.
        class FiveMergeMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, a, b, c, d, e):
                s = torch.tensor([])
                matmuls = []

                # For some reason using a list comprehension or for-loop for this
                # doesn't work.
                matmuls.append(torch.matmul(a, self.rhs))
                matmuls.append(torch.matmul(b, self.rhs))
                matmuls.append(torch.matmul(c, self.rhs))
                matmuls.append(torch.matmul(d, self.rhs))
                matmuls.append(torch.matmul(e, self.rhs))

                for m in matmuls:
                    s += torch.sum(m)

                return s

        # Initialize inputs.
        inputs = [torch.randn(2 * i + 1, 5) for i in range(5)]

        # Initialize RHS.
        rhs = torch.randn(5, 4)

        # Construct FiveMergeMatmulModule and call merge_matmul on it.
        module = FiveMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(*inputs)
        after = opt_module(*inputs)
        before.allclose(after)

        # Basic graph structure check; original module should have len(inputs) matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), len(inputs))
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Simple test case in which two matmuls cannot be merged due to a data dependency between
        # the LHS operands.
        class UnmergeableMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x):
                a = torch.matmul(x, self.rhs)
                a_abs = torch.abs(a)
                b = torch.matmul(a_abs.transpose(1, 0), self.rhs)
                return b

        # Initialize inputs.
        a = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct UnmergeableMatmulModule and call merge_matmul on it.
        module = UnmergeableMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a)
        after = opt_module(a)
        before.allclose(after)

        # Basic graph structure check; the number of matrix multiplications should not have changed.
        self.assertEqual(_count_matmuls(module), 2)
        self.assertEqual(_count_matmuls(opt_module), 2)

    def test_type_matches(self):
        should_be_equal = [
            (int, int),
            (numbers.Number, int),
            (numbers.Number, float),
            (int, type(torch.float)),
            (Union[int, float], int),
            (Union[int, float], float),
            (List[int], int),
            (List[int], create_type_hint([int, int])),
            (List[int], create_type_hint((int, int))),
            (List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
            (
                List[torch.Tensor],
                create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
            ),
            (torch.Tensor, torch.nn.Parameter),
            (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
            (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
            (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
            (
                List[torch.Tensor],
                create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
            ),
            (torch.Tensor, torch.nn.Parameter),
            (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
            (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
            (Optional[List[torch.Tensor]], List[torch.Tensor]),
            (Optional[List[int]], List[int]),
        ]
        for sig_type, arg_type in should_be_equal:
            self.assertTrue(type_matches(sig_type, arg_type))

        should_fail = [
            (int, float),
            (Union[int, float], str),
            (List[torch.Tensor], List[int]),
        ]

        for sig_type, arg_type in should_fail:
            self.assertFalse(type_matches(sig_type, arg_type))

    @skipIfNoMkldnn
    def test_optimize_for_inference_cpu(self):
        import torch.nn as nn

        class Foo(nn.Module):
            def __init__(self):
                super().__init__()
                layers = []
                layers2 = []
                for _ in range(10):
                    layers.append(nn.Conv2d(3, 3, 1))
                    layers.append(nn.BatchNorm2d(3))
                    layers.append(nn.ReLU())

                    layers2.append(nn.Conv2d(3, 3, 1))
                    layers2.append(nn.BatchNorm2d(3))
                    layers2.append(nn.ReLU())
                self.model = nn.Sequential(*layers)
                self.model2 = nn.Sequential(*layers2)

            def forward(self, x):
                return self.model(x) + self.model2(x)

        N, C, H, W, = (
            1,
            3,
            224,
            224,
        )
        inp = torch.randn(N, C, H, W)
        with torch.no_grad():
            model = Foo().eval()
            optimized_model = optimization.optimize_for_inference(model)
            torch.testing.assert_close(model(inp), optimized_model(inp))

            optimized_model2 = optimization.optimize_for_inference(
                model, pass_config={"remove_dropout": False}
            )
            torch.testing.assert_close(model(inp), optimized_model2(inp))

    @skipIfNoTorchVision
    @skipIfNoMkldnn
    def test_optimize_for_inference_cpu_torchvision(self):
        models = [
            torchvision.models.resnet18,
            torchvision.models.resnet50,
            torchvision.models.densenet121,
            torchvision.models.shufflenet_v2_x1_0,
            torchvision.models.vgg16,
            torchvision.models.mobilenet_v2,
            torchvision.models.mnasnet1_0,
            torchvision.models.resnext50_32x4d,
        ]
        with torch.no_grad():
            for model_type in models:
                model = model_type()
                C, H, W, = (
                    3,
                    224,
                    224,
                )
                inp = torch.randn(3, C, H, W)
                model(inp)
                model.eval()
                inp = torch.randn(1, C, H, W)
                heuristic = optimization.gen_mkl_autotuner(inp, iters=0, warmup=0)
                optimized_model = optimization.optimize_for_inference(model)

                orig_out = model(inp)
                new_out = optimized_model(inp)
                torch.testing.assert_close(orig_out, new_out)


class TestNormalizeOperators(JitTestCase):
    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_normalize_operator_exhaustive(self, device, dtype, op):
        # These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors)
        fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot", "_upsample_bilinear2d_aa"}
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        if isinstance(op.op, torch._ops.OpOverload):
            self.skipTest("normalize operator doesn't work on torch.ops")
        for sample_input in sample_inputs_itr:
            unsupported_arg_type = False
            arg_values = [sample_input.input] + list(sample_input.args)
            kwarg_values = sample_input.kwargs
            arg_types = []
            kwarg_types = {}

            def jit_infer_type(v):
                inferred_arg_type = torch._C._jit_try_infer_type(v)
                assert inferred_arg_type.success()
                t = _torchscript_type_to_python_type(inferred_arg_type.type())
                return t

            for v in arg_values:
                if isinstance(v, torch.Tensor):
                    arg_types.append(type(v))
                else:
                    if isinstance(v, complex):
                        # Complex type not supported in FX
                        unsupported_arg_type = True
                    arg_types.append(jit_infer_type(v))

            for k, v in kwarg_values.items():
                if isinstance(v, torch.Tensor):
                    kwarg_types[k] = type(v)
                else:
                    if isinstance(v, complex):
                        # Complex type not supported in FX
                        unsupported_arg_type = True
                    kwarg_types[k] = jit_infer_type(v)

            if unsupported_arg_type:
                continue
            # Test normalize_function by itself
            ref_out = op.op(*arg_values, **kwarg_values)
            norm_args_and_kwargs = normalize_function(
                op.op, arg_values, kwarg_values, arg_types, kwarg_types
            )
            if norm_args_and_kwargs is None:
                raise RuntimeError(
                    """
                    FX failed to normalize op - add the op to the op_skip list.
                    A common reason is if your OpInfo was implemented with a lambda
                    - otherwise, file an issue
                    """
                )
            test_out = op.op(*norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs)
            self.assertEqual(test_out, ref_out)

            # Test normalized_arguments as part of FX
            if op.name in fx_fail:
                continue
            param_names = []
            param_values = []
            fx_args = []
            for idx, v in enumerate(arg_values):
                if isinstance(v, torch.Tensor):
                    param_names.append(f"arg_{idx}")
                    param_values.append(v)
                    fx_args.append(param_names[-1])
                else:
                    fx_args.append(f"{repr(v)}")

            for k, v in kwarg_values.items():
                if isinstance(v, torch.Tensor):
                    param_names.append(k)
                    param_values.append(v)
                    fx_args.append(f"{k} = {k}")
                else:
                    fx_args.append(f"{k} = {repr(v)}")

            code = f"""
class TestModule(torch.nn.Module):
    def forward(self, {', '.join(param_names)}):
        return torch.{op.name}({', '.join(fx_args)})
            """

            g = {"torch": torch, "inf": math.inf}
            exec(code, g)
            TestModule = g["TestModule"]

            m = TestModule()
            traced = torch.fx.symbolic_trace(m)
            ref_out = traced(*param_values)

            for node in traced.graph.nodes:
                if node.op == "call_function":
                    normalized_args = node.normalized_arguments(
                        traced, arg_types, kwarg_types
                    )
                    assert normalized_args
                    node.args = normalized_args.args
                    node.kwargs = normalized_args.kwargs
            traced.recompile()

            test_out = traced(*param_values)
            self.assertEqual(test_out, ref_out)

    def test_normalize_quantized_eb(self):
        target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets
        args = (
            torch.empty((2, 3), dtype=torch.uint8),
            torch.empty((2,), dtype=torch.int64),
            torch.empty((2,), dtype=torch.int64),
        )
        norm_args_and_kwargs = normalize_function(
            target, args, normalize_to_only_use_kwargs=True
        )
        self.assertTrue(norm_args_and_kwargs is not None)
        self.assertEqual(
            set(norm_args_and_kwargs.kwargs.keys()),
            {
                "weight",
                "indices",
                "offsets",
                "scale_grad_by_freq",
                "mode",
                "pruned_weights",
                "per_sample_weights",
                "compressed_indices_mapping",
                "include_last_offset",
            },
        )
        self.assertEqual(norm_args_and_kwargs.args, tuple())

    def test_normalize_args_op_overload(self):
        for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]:
            inp1 = torch.rand([1])
            inp2 = torch.rand([4])
            args, kwargs = normalize_function(target, (inp1,), {"the_template": inp2}, normalize_to_only_use_kwargs=True)
            self.assertIs(kwargs["input"], inp1)
            self.assertIs(kwargs["the_template"], inp2)


if TEST_Z3:
    import z3

    import torch._dynamo.config

    from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str
    from torch.utils._sympy.functions import FloorDiv, Mod

    class TestTranslationValidation(TestCase):
        def _prepare_for_translation_validation(self):
            validator = TranslationValidator()

            # SymPy symbols.
            s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)

            # Z3 symbols.
            [validator.add_var(s, int) for s in (s0, s1, s2)]
            z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))

            return (s0, s1, s2), (z0, z1, z2), validator

        def test_sympy_to_z3(self):

            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            test_cases = [
                # Integer constants.
                (sympy.S.Zero, z3.IntVal(0)),
                (sympy.S.One, z3.IntVal(1)),
                (sympy.S.NegativeOne, z3.IntVal(-1)),
                (sympy.Integer(2), z3.IntVal(2)),
                (
                    s0,
                    z0,
                ),
                # Arithmetic operations.
                *[
                    (op(s0, s1), op(z0, z1))
                    for op in (
                        operator.add,
                        operator.mul,
                        operator.pow,
                    )
                ],
                # Logical operations.
                *[
                    (sympy_op(s0, s1), z3_op(z0, z1))
                    for sympy_op, z3_op in (
                        (sympy.Eq, operator.eq),
                        (sympy.Ne, operator.ne),
                        (sympy.Lt, operator.lt),
                        (sympy.Le, operator.le),
                        (sympy.Gt, operator.gt),
                        (sympy.Ge, operator.ge),
                    )
                ],
                # Other operations.
                (
                    s0 - s1,
                    z0 + z3.IntVal(-1) * z1,
                ),
                (
                    s0 / s1,
                    z3.ToReal(z0) * (z1**-1),
                ),
                (FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
                (Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
                (
                    Mod(s2, (s0 / s1)),
                    z2
                    - z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
                    * (z3.ToReal(z0) * z1**-1),
                ),
                (
                    Mod(s2, s0**3),
                    z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
                ),
            ]

            toZ3 = SympyToZ3(validator)
            for sympy_expr, z3_expr in test_cases:
                result = toZ3.run(sympy_expr)
                self.assertTrue(
                    z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
                )

        def test_sat(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z0 > 5)
            validator.add_source_expr(z1 / 2 > z0)

            # Solutions for the target are a subset of the solutions for the source.
            validator.add_target_expr(s0 > 20)
            validator.add_target_expr(s1 > s0**2)

            validator.validate()

        def test_unsat(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z0 > 5)
            validator.add_source_expr(z1 / 2 > z0)

            # Solutions for the target are NOT a subset of the solutions for the source.
            validator.add_target_expr(s0 > 20)
            # This expression is less restrictive than its counterpart.
            validator.add_target_expr(s1 > s0 + 2)

            with self.assertRaisesRegex(ValidationException, "translation validation failed."):
                validator.validate()

        def test_z3str(self):
            a = z3.Int("a")
            b = z3.Int("b")
            special = z3.Real("this.size()[2]")

            test_cases = [
                (z3.IntVal(42), "42"),
                # Variable.
                (a, "a"),
                # Name with special characters.
                (special, "this.size()[2]"),
                # Renamed function applications.
                (a != b, "(!= a b)"),
                (a ** b, "(pow a b)"),
                # Chain of associative operations.
                *[
                    (op(op(a, 5), b), f"({opstr} 5 a b)")
                    for op, opstr in [
                        (operator.add, "+"),
                        (operator.mul, "*")
                    ]
                ],
                # Revert 'Not' conversions.
                (a != b, "(!= a b)"),
                (a < b, "(> b a)"),
                (a > b, "(> a b)"),
                # Ignore 'ToInt' and 'ToReal' functions.
                (z3.ToInt(special) + a, "(+ this.size()[2] a)"),
                (z3.ToReal(a + b), "(+ a b)"),
                # Convert to floor division: 'idiv'.
                (z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
            ]

            for expr, expected in test_cases:
                self.assertEqual(z3str(expr), expected)


instantiate_device_type_tests(TestNormalizeOperators, globals())

if __name__ == "__main__":
    run_tests()