pytorch/torch/_dynamo/replay_record.py

120 lines
3.5 KiB
Python

import dataclasses
from dataclasses import field
from types import CodeType, ModuleType
from typing import Any, Dict
try:
import dill
except ImportError:
dill = None
@dataclasses.dataclass
class ModuleRecord:
    """Pair a live module object with the attribute values recorded for it.

    ``accessed_attrs`` maps an attribute name to the value stored for that
    attribute. Every instance starts with its own empty mapping unless one
    is passed in explicitly.
    """

    module: ModuleType
    accessed_attrs: Dict[str, Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class DummyModule:
    """Lightweight stand-in for a module, identified only by ``name``.

    Instances accept arbitrary extra attributes;
    ``ExecutionRecorder._resolve_modules`` attaches one per recorded
    attribute when it converts a ``ModuleRecord``.
    """

    name: str
@dataclasses.dataclass
class ExecutionRecord:
    """Serializable snapshot of a code object and the state recorded for it.

    Holds ``code`` together with four dictionaries (``globals``, ``locals``,
    ``builtins`` and ``code_options``), each defaulting to a fresh empty
    dict. Persistence goes through the optional ``dill`` package.
    """

    code: CodeType
    globals: Dict[str, Any] = field(default_factory=dict)
    locals: Dict[str, Any] = field(default_factory=dict)
    builtins: Dict[str, Any] = field(default_factory=dict)
    code_options: Dict[str, Any] = field(default_factory=dict)

    def dump(self, f):
        """Write this record to the file object ``f`` with ``dill.dump``.

        Raises:
            AssertionError: if ``dill`` could not be imported.
        """
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        if dill is None:
            raise AssertionError("replay_record requires `pip install dill`")
        dill.dump(self, f)

    @classmethod
    def load(cls, f):
        """Read and return whatever object ``dill.load`` yields from ``f``.

        Raises:
            AssertionError: if ``dill`` could not be imported.
        """
        if dill is None:
            raise AssertionError("replay_record requires `pip install dill`")
        # NOTE(review): like pickle, dill can execute arbitrary code while
        # loading -- only call this on files from a trusted source.
        return dill.load(f)
@dataclasses.dataclass
class ExecutionRecorder:
    """Accumulate the variables and module attributes recorded for ``code``.

    Modules are stored as ``ModuleRecord`` objects, one per module name,
    shared through ``name_to_modrec``. ``get_record`` converts those records
    into ``DummyModule`` trees and packages everything as an
    ``ExecutionRecord``. Modules named in ``MOD_EXCLUDES`` are skipped when
    added as variables or when they are the owner of an attribute access.
    """

    # Module names that are skipped (exact match against ``__name__``).
    MOD_EXCLUDES = ["torch", "torch.fx.passes"]
    # Not referenced inside this class; presumably consumed by callers that
    # name local modules -- TODO confirm.
    LOCAL_MOD_PREFIX = "___local_mod_"

    code: CodeType
    globals: Dict[str, Any] = field(default_factory=dict)
    locals: Dict[str, Any] = field(default_factory=dict)
    builtins: Dict[str, Any] = field(default_factory=dict)
    code_options: Dict[str, Any] = field(default_factory=dict)
    name_to_modrec: Dict[str, Any] = field(default_factory=dict)

    def add_local_var(self, name, var):
        """Record local variable ``name``; excluded modules are ignored."""
        self._record_var(self.locals, name, var)

    def add_global_var(self, name, var):
        """Record global variable ``name``; excluded modules are ignored."""
        self._record_var(self.globals, name, var)

    def add_local_mod(self, name, mod):
        """Record module ``mod`` under ``name`` in the globals table.

        Raises:
            AssertionError: if ``mod`` is not a module object.
        """
        assert isinstance(mod, ModuleType)
        # add_global_var applies the MOD_EXCLUDES filter itself.
        self.add_global_var(name, mod)

    def record_module_access(self, mod, name, val):
        """Store ``val`` as attribute ``name`` on the record kept for ``mod``.

        Nothing is recorded when ``mod`` is excluded or has no
        ``ModuleRecord`` yet. A module-valued ``val`` is stored as its own
        shared ``ModuleRecord``.
        """
        if self._is_excl(mod):
            return
        owner = self.name_to_modrec.get(mod.__name__)
        if owner is None:
            # Unregistered owner: skip for module and non-module values
            # alike, instead of raising KeyError for module values only.
            return
        if isinstance(val, ModuleType):
            val = self._add_mod(val)
        owner.accessed_attrs[name] = val

    def get_record(self):
        """Build an ``ExecutionRecord`` from the state collected so far.

        Globals and locals have their ``ModuleRecord`` entries converted to
        ``DummyModule`` trees; ``builtins`` and ``code_options`` are
        shallow-copied.
        """
        return ExecutionRecord(
            self.code,
            ExecutionRecorder._resolve_modules(self.globals),
            ExecutionRecorder._resolve_modules(self.locals),
            self.builtins.copy(),
            self.code_options.copy(),
        )

    def _record_var(self, table, name, var):
        """Store ``var`` in ``table[name]``, wrapping non-excluded modules."""
        if isinstance(var, ModuleType):
            if self._is_excl(var):
                return
            var = self._add_mod(var)
        table[name] = var

    def _add_mod(self, mod):
        """Return the shared ``ModuleRecord`` for ``mod``, creating it once."""
        record = self.name_to_modrec.get(mod.__name__)
        if record is None:
            record = ModuleRecord(mod)
            self.name_to_modrec[mod.__name__] = record
        return record

    @classmethod
    def _is_excl(cls, mod):
        """Return True when ``mod.__name__`` is listed in ``MOD_EXCLUDES``."""
        return mod.__name__ in cls.MOD_EXCLUDES

    # Convert ModuleRecords -> DummyModule tree
    @classmethod
    def _resolve_modules(cls, vars):
        """Return a new dict with each ``ModuleRecord`` value in ``vars``
        replaced by a ``DummyModule`` carrying the recorded attributes.

        Non-``ModuleRecord`` values are passed through unchanged. Each
        record is converted once per call, so records that reference each
        other (directly or in a cycle) resolve without unbounded recursion.
        """
        built = {}  # id(ModuleRecord) -> DummyModule created for it

        def resolve_module(var):
            if not isinstance(var, ModuleRecord):
                return var
            key = id(var)
            if key in built:
                return built[key]
            dummy_mod = DummyModule(var.module.__name__)
            # Register before descending so cyclic references terminate.
            built[key] = dummy_mod
            for attr_name, attr_value in var.accessed_attrs.items():
                setattr(dummy_mod, attr_name, resolve_module(attr_value))
            return dummy_mod

        return {k: resolve_module(v) for k, v in vars.items()}