120 lines
3.5 KiB
Python
120 lines
3.5 KiB
Python
import dataclasses
|
|
from dataclasses import field
|
|
from types import CodeType, ModuleType
|
|
from typing import Any, Dict
|
|
|
|
try:
|
|
import dill
|
|
except ImportError:
|
|
dill = None
|
|
|
|
|
|
@dataclasses.dataclass
class ModuleRecord:
    """Pair a module object with the attribute values recorded for it."""

    # The module object this record describes.
    module: ModuleType
    # Recorded attribute values keyed by attribute name. Each record gets its
    # own fresh dict, so records never share state through the default.
    accessed_attrs: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
|
|
|
|
|
@dataclasses.dataclass
class DummyModule:
    """Lightweight stand-in for a module, identified only by its name.

    Instances are ordinary objects, so extra attributes can be attached to
    them after construction; only ``name`` takes part in equality and repr.
    """

    # Name of the module this object stands in for.
    name: str
|
|
|
|
|
|
@dataclasses.dataclass
class ExecutionRecord:
    """Bundle a code object with the dictionaries captured alongside it.

    The record can be written to and read back from a binary file object
    with ``dill``, which is an optional dependency of this module.
    """

    code: CodeType
    globals: Dict[str, Any] = field(default_factory=dict)
    locals: Dict[str, Any] = field(default_factory=dict)
    builtins: Dict[str, Any] = field(default_factory=dict)
    code_options: Dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _require_dill():
        """Return the ``dill`` module, or fail if it could not be imported.

        Raises:
            AssertionError: if ``dill`` is not installed.
        """
        # Raised explicitly instead of via an ``assert`` statement: asserts are
        # stripped under ``python -O``, which would turn a missing dependency
        # into an opaque ``AttributeError`` on ``None``. The exception type is
        # kept as AssertionError so existing callers see the same error.
        if dill is None:
            raise AssertionError("replay_record requires `pip install dill`")
        return dill

    def dump(self, f):
        """Serialize this record into the writable binary file object ``f``.

        Raises:
            AssertionError: if ``dill`` is not installed.
        """
        self._require_dill().dump(self, f)

    @classmethod
    def load(cls, f):
        """Deserialize and return whatever object was stored in ``f``.

        The result is returned as-is; it is not checked to be an instance
        of ``cls``.

        SECURITY: dill uses the pickle protocol, so loading can execute
        arbitrary code. Only load files from a trusted source.

        Raises:
            AssertionError: if ``dill`` is not installed.
        """
        return cls._require_dill().load(f)
|
|
|
|
|
|
@dataclasses.dataclass
class ExecutionRecorder:
    """Accumulate variables and module accesses, then emit an ExecutionRecord.

    Module objects are never stored directly. Each one is wrapped in a
    ``ModuleRecord`` (one per module name, shared through ``name_to_modrec``)
    that collects the attributes recorded for it. ``get_record`` converts
    those records into ``DummyModule`` objects.
    """

    # Modules whose ``__name__`` matches one of these exactly are not recorded.
    MOD_EXCLUDES = ["torch", "torch.fx.passes"]
    # Not referenced inside this class; presumably callers use it to build the
    # names passed to ``add_local_mod`` -- confirm against call sites.
    LOCAL_MOD_PREFIX = "___local_mod_"

    code: CodeType
    globals: Dict[str, Any] = field(default_factory=dict)
    locals: Dict[str, Any] = field(default_factory=dict)
    builtins: Dict[str, Any] = field(default_factory=dict)
    code_options: Dict[str, Any] = field(default_factory=dict)
    # Maps module ``__name__`` -> ModuleRecord for every module seen so far.
    name_to_modrec: Dict[str, Any] = field(default_factory=dict)

    def add_local_var(self, name, var):
        """Record ``var`` under ``name`` in ``locals``.

        Modules are stored as their ``ModuleRecord``; excluded modules are
        skipped entirely.
        """
        self._record_var(self.locals, name, var)

    def add_global_var(self, name, var):
        """Record ``var`` under ``name`` in ``globals``.

        Modules are stored as their ``ModuleRecord``; excluded modules are
        skipped entirely.
        """
        self._record_var(self.globals, name, var)

    def add_local_mod(self, name, mod):
        """Record the module ``mod`` under ``name`` in ``globals``.

        Excluded modules are skipped; ``add_global_var`` performs that check.
        """
        assert isinstance(mod, ModuleType)
        self.add_global_var(name, mod)

    def record_module_access(self, mod, name, val):
        """Record that attribute ``name`` of module ``mod`` had value ``val``.

        Nothing is recorded when ``mod`` is excluded or has not been
        registered before. A module-valued ``val`` is stored as its
        ``ModuleRecord`` (registering it if needed).
        """
        if self._is_excl(mod):
            return

        mod_name = mod.__name__
        # Guard both the module-valued and plain-valued cases: an access on a
        # module that was never registered is ignored rather than raising
        # KeyError.
        if mod_name not in self.name_to_modrec:
            return

        if isinstance(val, ModuleType):
            val = self._add_mod(val)
        self.name_to_modrec[mod_name].accessed_attrs[name] = val

    def get_record(self):
        """Build an ``ExecutionRecord`` from the state collected so far.

        ``globals`` and ``locals`` have their ``ModuleRecord`` values replaced
        by ``DummyModule`` objects; ``builtins`` and ``code_options`` are
        shallow-copied.
        """
        return ExecutionRecord(
            code=self.code,
            globals=ExecutionRecorder._resolve_modules(self.globals),
            locals=ExecutionRecorder._resolve_modules(self.locals),
            builtins=self.builtins.copy(),
            code_options=self.code_options.copy(),
        )

    def _record_var(self, target, name, var):
        """Store ``var`` in the dict ``target``, wrapping or skipping modules."""
        if not isinstance(var, ModuleType):
            target[name] = var
        elif not self._is_excl(var):
            target[name] = self._add_mod(var)

    def _add_mod(self, mod):
        """Return the ``ModuleRecord`` for ``mod``, creating it on first use."""
        mod_name = mod.__name__
        if mod_name not in self.name_to_modrec:
            self.name_to_modrec[mod_name] = ModuleRecord(mod)
        return self.name_to_modrec[mod_name]

    @classmethod
    def _is_excl(cls, mod):
        """Return True when ``mod.__name__`` exactly matches an excluded name."""
        return mod.__name__ in cls.MOD_EXCLUDES

    # Convert ModuleRecords -> DummyModule tree
    @classmethod
    def _resolve_modules(cls, vars):
        """Return a copy of ``vars`` with ModuleRecords turned into DummyModules.

        Each ``ModuleRecord`` becomes a ``DummyModule`` carrying the recorded
        attributes, resolved recursively. Within one call a given record maps
        to a single ``DummyModule``, so modules that reference each other
        resolve to a finite, cyclic object graph instead of recursing forever.
        Non-record values are passed through unchanged.
        """
        # id(ModuleRecord) -> DummyModule. Keying by id is safe because every
        # record stays alive (reachable from ``vars``) for the whole call.
        resolved = {}

        def resolve_module(var):
            if not isinstance(var, ModuleRecord):
                return var

            key = id(var)
            if key in resolved:
                return resolved[key]

            dummy_mod = DummyModule(var.module.__name__)
            # Register before descending so a cycle finds this placeholder.
            resolved[key] = dummy_mod
            for attr_name, attr_value in var.accessed_attrs.items():
                setattr(dummy_mod, attr_name, resolve_module(attr_value))

            return dummy_mod

        return {k: resolve_module(v) for k, v in vars.items()}
|