pytorch/torch/_jit_internal.py

1504 lines
52 KiB
Python

"""
The weak_script annotation needs to be here instead of inside torch/jit/ so it
can be used in other places in torch/ (namely torch.nn) without running into
circular dependency problems
"""
import ast
import builtins
import collections
import contextlib
import enum
import inspect
import io
import pickle
import sys
import threading
import typing
import warnings
import weakref
from textwrap import dedent
from typing import ( # noqa: F401
Any,
Callable,
Dict,
Final,
ForwardRef,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
import torch
# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
# Explicitly ask to import `torch.distributed.__init__` first.
# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
import torch.distributed.rpc
import torch.package._mangling as package_mangling
from torch._awaits import _Await
from torch._C import _Await as CAwait, Future as CFuture
from torch._sources import fake_range, get_source_lines_and_file, parse_def
from torch.futures import Future
# Lock type provided by the low-level `_thread` module.
LockType: Type
try:
    import _thread
except ImportError:
    # `_dummy_thread` only exists on old Pythons built without thread support.
    import _dummy_thread

    LockType = _dummy_thread.LockType
else:
    LockType = _thread.LockType
# Registry for wrapper functions (built by `boolean_dispatch`) that call one of
# two functions depending on a boolean argument. Each wrapper maps to a dict with
# keys "if_true", "if_false", "index", "default" and "arg_name". Weak keys mean
# this registry does not by itself keep the wrappers alive.
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = weakref.WeakKeyDictionary()  # noqa: T484

# Filename prefix constant; judging by its name it is presumably used for
# synthesized TorchScript dataclass sources (its users are outside this view).
FAKE_FILENAME_PREFIX = "__torch_jit_dataclass"
class SourceLoader:
    """In-memory cache that maps a function object to source text registered for it."""

    def __init__(self):
        self.content = {}

    def cache(self, fn, source):
        """Remember ``source`` as the source of ``fn``, replacing any earlier entry."""
        self.content[fn] = source

    def get_source(self, fn):
        """Return the source cached for ``fn``, or ``None`` when nothing was cached."""
        try:
            return self.content[fn]
        except KeyError:
            return None
# Module-wide source cache; get_type_hint_captures consults it before inspect.getsource.
loader = SourceLoader()
# True on Python >= 3.9, where ast.Subscript.slice is no longer wrapped in ast.Index.
IS_PY39_PLUS = sys.version_info[:2] >= (3, 9)
def createResolutionCallbackFromEnv(lookup_base):
    """
    Creates a resolution callback that will look up qualified names in an
    environment, starting with `lookup_base` for the base of any qualified
    names, then proceeding down the lookup chain with the resolved object.

    The returned callback takes a type-expression string such as
    ``"Dict[str, List[int]]"`` and returns the resolved Python object. It
    returns ``None`` when the expression cannot be fully resolved here, so the
    caller can fall back to the C++ resolver.

    You should not use this directly, it should only be used from the other
    createResolutionCallbackFrom* functions.
    """

    def lookupInModule(qualified_name, module):
        # Resolve a dotted name ("a.b.c") one attribute at a time.
        if "." in qualified_name:
            base, remaining_pieces = qualified_name.split(".", 1)
            return lookupInModule(remaining_pieces, getattr(module, base))
        return getattr(module, qualified_name)

    def parseNestedExpr(expr, module) -> Tuple[Any, int]:
        # Parse one (possibly subscripted) type expression at the start of
        # `expr`; return the resolved value and the number of characters consumed.
        i = 0
        while i < len(expr) and expr[i] not in (",", "[", "]"):
            i += 1

        # Special case logic for the empty Tuple as a subscript (used
        # in the type annotation `Tuple[()]`)
        if expr[:i] == "()":
            return (), i

        base = lookupInModule(expr[:i].strip(), module)
        # Environments report unknown names as None. Raise explicitly (not via
        # `assert`, which is stripped under `python -O`) so an unresolved name
        # can never be silently turned into `NoneType` by the subscript below.
        if base is None:
            raise ValueError(f"Unresolvable type {expr[:i]}")
        if i == len(expr) or expr[i] != "[":
            return base, i

        parts = []
        while expr[i] != "]":
            # Step over the opening "[" or a separating ",", then parse one argument.
            i += 1
            part, part_len = parseNestedExpr(expr[i:], module)
            parts.append(part)
            i += part_len
        if len(parts) > 1:
            return base[tuple(parts)], i + 1
        else:
            return base[parts[0]], i + 1

    def parseExpr(expr, module):
        try:
            value, len_parsed = parseNestedExpr(expr, module)
            # Explicit check rather than `assert` so that `python -O` does not
            # accept a partially parsed expression.
            if len_parsed != len(expr):
                # whole expression was not parsed, falling back to c++ parser
                return None
            return value
        except Exception:
            # The python resolver fails in several cases in known unit tests, and is
            # intended to fall back gracefully to the c++ resolver in general. For
            # example, python 2 style annotations which are frequent in our unit tests
            # often fail with types e.g. int not resolvable from the calling frame.
            return None

    return lambda expr: parseExpr(expr, lookup_base)
def createResolutionCallbackFromFrame(frames_up: int = 0):
    """
    Creates a function which, given a string variable name,
    returns the value of the variable in the scope of the caller of
    the function which called createResolutionCallbackFromFrame (by default).

    This is used to enable access in-scope Python variables inside
    TorchScript fragments.

    frames_up is number of additional frames to go up on the stack.
    The default value is 0, which correspond to the frame of the caller
    of createResolutionCallbackFromFrame. Also for example, if frames_up is set
    to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame
    will be taken.

    Names are searched in the chosen frame's locals, then its globals, then
    ``builtins``; a name found nowhere resolves to ``None``.

    For example, the following program prints 2::

        def bar():
            cb = createResolutionCallbackFromFrame(1)
            print(cb("foo"))


        def baz():
            foo = 2
            bar()


        baz()
    """
    frame = inspect.currentframe()
    # Walk out of this function's own frame plus `frames_up` more frames.
    hops_left = frames_up + 1
    while hops_left > 0:
        assert frame is not None
        frame = frame.f_back
        hops_left -= 1
    assert frame is not None
    f_locals = frame.f_locals
    f_globals = frame.f_globals

    class env:
        def __getattr__(self, key):
            # Locals shadow globals, which shadow builtins.
            for scope in (f_locals, f_globals):
                if key in scope:
                    return scope[key]
            if key in dir(builtins):
                return getattr(builtins, key)
            return None

    return createResolutionCallbackFromEnv(env())
def get_closure(fn):
    """
    Get a dictionary of closed over variables from a function.

    The result starts as a copy of ``fn.__globals__`` and is then overlaid with
    the values of ``fn``'s free (closure) variables, so closure values shadow
    globals of the same name.

    A closure cell that is still empty is skipped instead of letting
    ``cell_contents`` raise ``ValueError``. This happens when the enclosing scope
    declares the variable but has not assigned it yet.
    """
    captures = {}
    captures.update(fn.__globals__)

    # `__closure__` is None when there are no free variables.
    for captured_name, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ()):
        try:
            captures[captured_name] = cell.cell_contents
        except ValueError:
            # Empty cell: the name exists in the enclosing scope but is unbound.
            continue
    return captures
# [local resolution in python]
# Depending on where a variable is defined, and where it is used, we may
# or may not be able to recover its value when recursively compiling a
# script function. Remember in the general case, a module or function is
# first defined and then later scripted. This means we do not have a
# chance to capture the active frames when the function is defined. Hence any
# name resolution has to happen later on the created closure. The way
# python captures type annotations restricts what we can recover. The
# following example illustrates the different cases:
#
#         class MyGlobalClass:
#             ...
#         def my_local_scope():
#             @torch.jit.script
#             class MyClass:
#                 ...
#             @torch.jit.script
#             class MyClassUsedAsVar:
#                 ...
#             def eg(x: MyClass, y: MyGlobalClass):
#                 a_local_capture : Foo
#                 return MyClassUsedAsVar(x)
#
# MyGlobalClass is defined in the __globals__ dictionary of function
# 'eg', so it is always recoverable. my_local_scope introduces a new local
# variable scope in the function. Classes defined here are only visible as
# local variables. For the case of MyClassUsedAsVar, it is captured
# because it is used as a variable inside the body of the function, and we
# can resolve it using the captures returned from `get_closure`. However,
# the type annotations are not captured by the closure. In Python
# 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be available as
# annotations on `eg`. However, when annotations are stored as strings (as
# with `from __future__ import annotations`, once planned to become the
# default in a future Python version), those values are no longer present.
# Furthermore, since the body of `eg` does
# not reference those names, they do not appear in the list of closed over
# variables. In Python 2.x, type annotations are in comments, leading to a
# similar situation where their definitions are not available. We anticipate
# that most users will not run into this issue because their modules and
# functions will be defined at a global scope like MyGlobalClass. In cases
# where they are not, it is possible to work around issues by declaring the
# values global in the function.
# In Python 3.9 declaring a class as global will make it invisible to
# `inspect.getsource`, see https://bugs.python.org/issue42666 .
# This could be worked around by manually adding it to the `globals()` dictionary.
def createResolutionCallbackFromClosure(fn):
    """
    Create a resolutionCallback by introspecting the function instead of
    looking up the stack for the enclosing scope.

    Names are looked up first among the function's globals and closure
    variables (see ``get_closure``), then in ``typing``, then in ``builtins``.
    Anything else resolves to ``None``.
    """
    closure = get_closure(fn)
    fallback_namespaces = (typing, builtins)

    class closure_lookup:
        # A class rather than the plain `closure` dict because
        # createResolutionCallbackFromEnv resolves names with `getattr`.
        def __getattr__(self, key):
            if key in closure:
                return closure[key]
            for namespace in fallback_namespaces:
                if hasattr(namespace, key):
                    return getattr(namespace, key)
            return None

    return createResolutionCallbackFromEnv(closure_lookup())
def can_compile_class(cls) -> bool:
    """
    Return whether ``cls`` looks compilable by TorchScript.

    Returns False in these cases:

    - ``cls`` is marked as ignored.
    - ``cls`` subclasses ``torch.nn.Module``, ``tuple``, ``list`` or ``Exception``.
    - Any routine found in ``cls.__dict__`` lacks a ``__code__`` object. Such a
      type is probably a builtin / bound from C.
    """
    if is_ignored_fn(cls):
        return False

    # Built-in classes that are never compiled as TorchScript classes.
    if issubclass(cls, (torch.nn.Module, tuple, list, Exception)):
        return False

    routines = [
        getattr(cls, attr_name)
        for attr_name in cls.__dict__
        if inspect.isroutine(getattr(cls, attr_name, None))
    ]
    return all(hasattr(routine, "__code__") for routine in routines)
def get_callable_argument_names(fn) -> List[str]:
    """
    Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`.

    Parameters of every other kind (positional-only, ``*args``, keyword-only,
    ``**kwargs``) are skipped, since they do not map to individual values with
    a keyword as name. If ``inspect.signature`` cannot introspect ``fn``, an
    empty list is returned.

    This is used by `torch.jit.trace` to assign meaningful argument names to
    traced functions and modules.

    Args:
        fn: A callable.

    Returns:
        Argument names: List[str]
    """
    # inspect.signature may fail, give up in that case.
    try:
        params = inspect.signature(fn).parameters
    except Exception:
        return []

    return [
        name
        for name, param in params.items()
        if param.kind == param.POSITIONAL_OR_KEYWORD
    ]
def get_annotation_str(annotation):
    """
    Convert an AST node containing a type annotation to the string present in the source
    that represents the same annotation.

    The result is normalized rather than verbatim. Subscript/tuple elements are
    joined with "," and no spaces, e.g. ``Dict[str,int]``.

    Returns ``None`` if the node, or any node nested inside it, is of a kind not
    handled here (such annotations are presumably handled by ScriptTypeParser).
    Propagating ``None`` avoids crashing in ``str.join`` and avoids producing
    bogus strings such as ``"List[None]"``.
    """
    if isinstance(annotation, ast.Name):
        return annotation.id
    elif isinstance(annotation, ast.Attribute):
        value_str = get_annotation_str(annotation.value)
        if value_str is None:
            return None
        return ".".join([value_str, annotation.attr])
    elif isinstance(annotation, ast.Subscript):
        # In Python3.9+ subscript indicies are not wrapped in ast.Index
        subscript_slice = annotation.slice if IS_PY39_PLUS else annotation.slice.value  # type: ignore[attr-defined]
        value_str = get_annotation_str(annotation.value)
        slice_str = get_annotation_str(subscript_slice)
        if value_str is None or slice_str is None:
            return None
        return f"{value_str}[{slice_str}]"
    elif isinstance(annotation, ast.Tuple):
        elt_strs = [get_annotation_str(elt) for elt in annotation.elts]
        if any(elt_str is None for elt_str in elt_strs):
            return None
        return ",".join(elt_strs)
    elif isinstance(annotation, ast.Constant):
        # Since Python 3.8 `None`, booleans and string literals all parse as
        # ast.Constant. ast.NameConstant is a deprecated alias, removed in 3.14.
        return f"{annotation.value}"
    # If an AST node is not handled here, it's probably handled in ScriptTypeParser.
    return None
def get_type_hint_captures(fn):
    """
    Get a dictionary containing type resolution mappings necessary to resolve types
    for the literal annotations on 'fn'. These are not considered to be closed-over by fn
    and must be obtained separately (e.g. using this function).

    Args:
        fn: A callable.

    Returns:
        A Dict[str, Any] mapping the literal annotations used on fn (as
        reconstructed by ``get_annotation_str``) to the Python objects they
        refer to. The dict covers:

        - positional-only, regular and keyword-only parameters;
        - the return annotation.

        An empty dict is returned when the source of ``fn`` cannot be retrieved.

    Raises:
        RuntimeError: if the source does not parse to exactly one ``def`` statement.
    """
    # First, try to get the source of the function. We'll need to parse it to find the actual string names
    # that were used to annotate the types, since inspect.signature() will only return the class object that
    # the annotation refers to, not the string name. If we can't get the source, simply return an empty dict.
    # This may happen in cases where the function is synthesized dynamically at runtime.
    src = loader.get_source(fn)
    if src is None:
        try:
            src = inspect.getsource(fn)
        except OSError:
            return {}

    # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
    # types are strings. These are only understood by TorchScript in the context of a type annotation
    # that refers to a class in its own definition, but trying to include a mapping for this in the result
    # function would cause infinite recursion because the class is currently being compiled.
    # In addition, there is logic in ScriptTypeParser to handle this.
    signature = inspect.signature(fn)
    name_to_type = {
        name: parameter.annotation
        for name, parameter in signature.parameters.items()
        if parameter.annotation is not inspect.Parameter.empty
        and not isinstance(parameter.annotation, str)
    }

    # Then, get the literal type annotations from the function declaration
    # by source inspection. This accounts for the case in which aliases are used
    # to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
    # frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
    a = ast.parse(dedent(src))
    if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
        raise RuntimeError(f"Expected {fn} to be a function")
    f = a.body[0]

    # Prepare a dictionary of source annotation -> type, which will be the final result of this function,
    # by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping
    # them to the type object corresponding to the annotation via name_to_type using the parameter name.
    annotation_to_type = {}
    # Positional-only and keyword-only parameters appear in `signature.parameters`
    # too, so consider them alongside the regular ones.
    ast_params = [*f.args.posonlyargs, *f.args.args, *f.args.kwonlyargs]
    for arg in ast_params:
        # Get the source type annotation string for this argument if possible.
        arg_annotation_str = (
            get_annotation_str(arg.annotation) if arg.annotation else None
        )

        # If the argument has no annotation or get_annotation_str cannot convert it to a string,
        # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle
        # this in the latter case.
        if arg_annotation_str is None:
            continue

        # Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not
        # be present in name_to_type is that the annotation itself is a string and not a type object
        # (common for self-refential annotations in classes). Once again, let ScriptTypeParser handle this.
        arg_name = arg.arg
        if arg_name in name_to_type:
            annotation_to_type[arg_annotation_str] = name_to_type[arg_name]

    # If there is a valid return annotation, include it in annotation_to_type. As with argument annotations,
    # the literal annotation has to be convertible to a string by get_annotation_str, and the actual type
    # of the annotation cannot be a string.
    literal_return_annotation = get_annotation_str(f.returns)
    valid_literal_annotation = literal_return_annotation is not None
    return_annotation = signature.return_annotation
    valid_return_annotation_type = (
        return_annotation is not inspect.Parameter.empty
        and not isinstance(return_annotation, str)
    )
    if valid_literal_annotation and valid_return_annotation_type:
        annotation_to_type[literal_return_annotation] = return_annotation

    return annotation_to_type
def createResolutionCallbackForClassMethods(cls):
    """
    This looks at all the methods defined in a class and pulls their closed-over
    variables into a dictionary and uses that to resolve variables.

    The dictionary also includes each method's type-hint captures (see
    ``get_type_hint_captures``). Names not found there fall back to
    ``builtins``, and otherwise resolve to ``None``.
    """
    # cls is a type here, so `ismethod` is false since the methods on the type
    # aren't bound to anything, so Python treats them as regular functions
    routines = [
        getattr(cls, attr_name)
        for attr_name in cls.__dict__
        if inspect.isroutine(getattr(cls, attr_name))
    ]
    # Skip built-ins, as they do not have global scope nor type hints.
    # Needed to support `enum.Enum` derived classes in Python-3.11,
    # which adds a `_new_member_` property that is an alias to `__new__`.
    python_routines = [
        routine
        for routine in routines
        if hasattr(routine, "__globals__") and not inspect.isbuiltin(routine)
    ]

    captures = {}
    for routine in python_routines:
        captures.update(get_closure(routine))
        captures.update(get_type_hint_captures(routine))

    def lookup_in_class(key):
        return captures[key] if key in captures else getattr(builtins, key, None)

    return lookup_in_class
def boolean_dispatch(
    arg_name, arg_index, default, if_true, if_false, module_name, func_name
):
    """
    Dispatches to either of 2 script functions based on a boolean argument.
    In TorchScript, the boolean argument must be constant so that the correct
    function to use can be determined at compile time.

    The flag is read from ``kwargs[arg_name]`` if given by keyword, otherwise
    from ``args[arg_index]`` if given positionally, otherwise ``default`` is
    used. At most one of ``if_true``/``if_false`` may carry a docstring. That
    docstring is copied onto the other branch and onto the returned wrapper.
    The wrapper is recorded in ``boolean_dispatched``.

    Raises:
        RuntimeError: if both ``if_true`` and ``if_false`` have docstrings.
    """

    def fn(*args, **kwargs):
        if arg_name in kwargs:
            dispatch_flag = kwargs[arg_name]
        elif arg_index < len(args):
            dispatch_flag = args[arg_index]
        else:
            dispatch_flag = default
        chosen = if_true if dispatch_flag else if_false
        return chosen(*args, **kwargs)

    true_doc = if_true.__doc__
    false_doc = if_false.__doc__
    if true_doc is not None and false_doc is not None:
        raise RuntimeError("only one function can have a docstring")
    if true_doc is None and false_doc is not None:
        if_true.__doc__ = false_doc
        doc = false_doc
    elif false_doc is None and true_doc is not None:
        if_false.__doc__ = true_doc
        doc = true_doc
    else:
        # neither function has a docstring
        doc = None

    fn.__doc__ = doc
    if module_name is not None:
        fn.__module__ = module_name
    if func_name is not None:
        fn.__name__ = func_name

    boolean_dispatched[fn] = dict(
        if_true=if_true,
        if_false=if_false,
        index=arg_index,
        default=default,
        arg_name=arg_name,
    )
    return fn
class FunctionModifiers:
    """
    Used to denote the behavior of a function in TorchScript. See export() and
    ignore() for details.

    Each modifier is a distinct string constant. Decorators store it on a function
    as its ``_torchscript_modifier`` attribute, and checks compare it by identity
    (``is``).
    """

    # Ignored and replaced with raising an exception.
    UNUSED = "unused (ignored and replaced with raising of an exception)"
    # Left as a call into Python; cannot be saved.
    IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
    # Compiled even if nothing calls it.
    EXPORT = "export (compile this function even if nothing calls it)"
    # Compiled only when reached from an exported function / forward.
    DEFAULT = "default (compile if called from a exported function / forward)"
    COPY_TO_SCRIPT_WRAPPER = "if this method is not scripted, copy the python method onto the scripted model"
    # Fully ignored; the declaration itself may be unscriptable.
    _DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
def export(fn):
    """
    This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
    :class:`ScriptModule` and should be compiled.

    ``forward`` implicitly is assumed to be an entry point, so it does not need this decorator.
    Functions and methods called from ``forward`` are compiled as they are seen
    by the compiler, so they do not need this decorator either.

    The decorator only tags ``fn`` with ``FunctionModifiers.EXPORT`` and returns
    the same object.

    Example (using ``@torch.jit.export`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            def implicitly_compiled_method(self, x):
                return x + 99

            # `forward` is implicitly decorated with `@torch.jit.export`,
            # so adding it here would have no effect
            def forward(self, x):
                return x + 10

            @torch.jit.export
            def another_forward(self, x):
                # When the compiler sees this call, it will compile
                # `implicitly_compiled_method`
                return self.implicitly_compiled_method(x)

            def unused_method(self, x):
                return x - 20

        # `m` will contain compiled methods:
        #     `forward`
        #     `another_forward`
        #     `implicitly_compiled_method`
        # `unused_method` will not be compiled since it was not called from
        # any compiled methods and wasn't decorated with `@torch.jit.export`
        m = torch.jit.script(MyModule())
    """
    setattr(fn, "_torchscript_modifier", FunctionModifiers.EXPORT)  # noqa: B010
    return fn
def unused(fn):
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and replaced with the raising of an exception. This allows you
    to leave code in your model that is not yet TorchScript compatible and still
    export your model.

    ``fn`` may also be a ``property``. In that case its getter, and its setter if
    one exists, are tagged, and the property itself is returned.

    Example (using ``@torch.jit.unused`` on a method)::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            def __init__(self, use_memory_efficient):
                super().__init__()
                self.use_memory_efficient = use_memory_efficient

            @torch.jit.unused
            def memory_efficient(self, x):
                import pdb
                pdb.set_trace()
                return x + 10

            def forward(self, x):
                # Use not-yet-scriptable memory efficient mode
                if self.use_memory_efficient:
                    return self.memory_efficient(x)
                else:
                    return x + 10

        m = torch.jit.script(MyModule(use_memory_efficient=False))
        m.save("m.pt")

        m = torch.jit.script(MyModule(use_memory_efficient=True))
        # exception raised
        m(torch.rand(100))
    """
    if not isinstance(fn, property):
        fn._torchscript_modifier = FunctionModifiers.UNUSED
        return fn

    # A property is tagged through its accessor functions: the getter always,
    # the setter only when one is defined.
    accessors = [fn.fget]
    if fn.fset:
        accessors.append(fn.fset)
    for accessor in accessors:
        setattr(accessor, "_torchscript_modifier", FunctionModifiers.UNUSED)  # noqa: B010
    return fn
# No-op context manager on the Python side. It accepts and discards arbitrary
# keyword arguments; `__enter__` (inherited) returns the manager itself, and
# exceptions raised inside the `with` body are not suppressed.
class _IgnoreContextManager(contextlib.AbstractContextManager):
    def __init__(self, **kwargs):
        del kwargs  # intentionally unused

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # A falsy return value lets any in-flight exception propagate.
        return None
def ignore(drop=False, **kwargs):
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and left as a Python function. This allows you to leave code in
    your model that is not yet TorchScript compatible. If called from TorchScript,
    ignored functions will dispatch the call to the Python interpreter. Models with ignored
    functions cannot be exported; use :func:`@torch.jit.unused <torch.jit.unused>` instead.

    Args:
        drop: Either the decorated function itself (bare ``@torch.jit.ignore``)
            or a bool. ``True`` is deprecated and tags the function as
            ``FunctionModifiers.UNUSED`` instead of ``IGNORE``.
        **kwargs: Only ``drop_on_export`` is read. It is a deprecated spelling of
            ``drop=True`` and takes precedence when truthy. Other keys are ignored.

    Returns:
        The tagged function for bare use, otherwise a decorator.

    Raises:
        RuntimeError: if ``drop`` is neither callable nor a bool.

    Example (using ``@torch.jit.ignore`` on a method)::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            @torch.jit.ignore
            def debugger(self, x):
                import pdb
                pdb.set_trace()

            def forward(self, x):
                x += 10
                # The compiler would normally try to compile `debugger`,
                # but since it is `@ignore`d, it will be left as a call
                # to Python
                self.debugger(x)
                return x

        m = torch.jit.script(MyModule())

        # Error! The call `debugger` cannot be saved since it calls into Python
        m.save("m.pt")

    Example (using ``@torch.jit.ignore(drop=True)`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            @torch.jit.ignore(drop=True)
            def training_method(self, x):
                import pdb
                pdb.set_trace()

            def forward(self, x):
                if self.training:
                    self.training_method(x)
                return x

        m = torch.jit.script(MyModule())

        # This is OK since `training_method` is not saved, the call is replaced
        # with a `raise`.
        m.save("m.pt")

    .. testcleanup::

        import os
        os.remove('m.pt')
    """

    if callable(drop):
        # used without any args, so drop is actually a function
        #   @torch.jit.ignore
        #   def fn(...):
        fn = drop
        fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    if not isinstance(drop, bool):
        raise RuntimeError(
            "Argument to @torch.jit.ignore must be a bool or "
            f"a function but got {drop}"
        )

    # for backwards compat
    drop_on_export = kwargs.pop("drop_on_export", None)
    if drop_on_export:
        warnings.warn(
            "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now.",
            category=FutureWarning,
        )

        drop = drop_on_export
    elif drop:
        warnings.warn(
            "ignore(True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now.",
            category=FutureWarning,
        )

    def decorator(fn):
        if drop:
            fn._torchscript_modifier = FunctionModifiers.UNUSED
        else:
            fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    return decorator
def _drop(fn):
    """Tag ``fn`` with ``FunctionModifiers._DROP`` (fully ignored) and return it unchanged."""
    setattr(fn, "_torchscript_modifier", FunctionModifiers._DROP)  # noqa: B010
    return fn
def _copy_to_script_wrapper(fn):
    """Tag ``fn`` with ``FunctionModifiers.COPY_TO_SCRIPT_WRAPPER`` and return it unchanged."""
    setattr(fn, "_torchscript_modifier", FunctionModifiers.COPY_TO_SCRIPT_WRAPPER)  # noqa: B010
    return fn
def module_has_exports(mod):
    """Return True if any callable attribute listed by ``dir(mod)`` is tagged ``FunctionModifiers.EXPORT``."""
    for attr_name in dir(mod):
        if not hasattr(mod, attr_name):
            continue
        attr = getattr(mod, attr_name)
        if callable(attr) and get_torchscript_modifier(attr) is FunctionModifiers.EXPORT:
            return True
    return False
# WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you
# rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to
# allow JIT'd code to still be covered.
def should_drop(fn) -> bool:
    """Return True if ``fn`` is tagged UNUSED or _DROP; False otherwise (including non-callables)."""
    modifier = get_torchscript_modifier(fn)
    return modifier is not None and (
        modifier is FunctionModifiers.UNUSED or modifier is FunctionModifiers._DROP
    )
def is_ignored_fn(fn) -> bool:
    """Return True if ``fn`` is tagged UNUSED, IGNORE or _DROP (compared by identity)."""
    modifier = get_torchscript_modifier(fn)
    ignoring_modifiers = (
        FunctionModifiers.UNUSED,
        FunctionModifiers.IGNORE,
        FunctionModifiers._DROP,
    )
    return any(modifier is candidate for candidate in ignoring_modifiers)
def _is_drop_fn(fn) -> bool:
    """Return True only when ``fn`` is tagged ``FunctionModifiers._DROP``."""
    return get_torchscript_modifier(fn) is FunctionModifiers._DROP
def is_static_fn(cls, fn) -> bool:
    """Return True if the attribute named ``fn`` on ``cls`` is a ``staticmethod``.

    ``inspect.getattr_static`` is used so no descriptor is invoked; a missing
    attribute yields False.
    """
    raw_attr = inspect.getattr_static(cls, fn, default=None)
    return isinstance(raw_attr, staticmethod)
def get_static_fn(cls, fn):
    """Return the plain function wrapped by the static attribute named ``fn`` on ``cls``.

    Raises AttributeError if the attribute is missing or has no ``__func__``.
    """
    static_wrapper = inspect.getattr_static(cls, fn)
    return static_wrapper.__func__
def get_torchscript_modifier(fn):
    """Return the TorchScript modifier recorded on ``fn``.

    Returns ``None`` for non-callables. Otherwise returns the
    ``_torchscript_modifier`` attribute, or ``FunctionModifiers.DEFAULT`` if
    none was set.
    """
    if not callable(fn):
        return None
    # Objects exposing `__func__` (e.g. bound methods) keep the tag on the underlying function.
    target = getattr(fn, "__func__", fn)
    return getattr(target, "_torchscript_modifier", FunctionModifiers.DEFAULT)
def copy_torchscript_modifier(orig, new) -> None:
    """Copy the TorchScript modifier of ``orig`` onto ``new``; does nothing if ``orig`` is not callable."""
    modifier = get_torchscript_modifier(orig)
    if modifier is not None:
        new._torchscript_modifier = modifier
# overloading registration
# overloads get registered in this file, and compiled in torch/jit/__init__.py
# so that they can be imported in nn/functional.py without an import cycle

# Registry of free-function overload declarations:
# qualified_name => list[overload_functions]
_overloaded_fns: Dict[str, List[Callable]] = dict()  # noqa: T484
# Usage example appended to overload-related error messages (by
# get_overload_no_implementation_error_message and _check_overload_body).
_OVERLOAD_EXAMPLE = """
Example usage of overload function:
@torch.jit._overload
def my_function(x: type0) -> type0: # decl 1
pass
@torch.jit._overload
def my_function(x: type1) -> type1: # decl 2
pass
def my_function(x): # implementation
if isinstance(x, type0):
return x
elif isinstance(x, type1):
return x
"""
def get_overload_no_implementation_error_message(kind, obj):
    """Build the error text for an overloaded ``kind`` (e.g. function/method) lacking an implementation.

    The message contains:

    - the qualified name of ``obj``;
    - its file and line;
    - its source lines;
    - ``_OVERLOAD_EXAMPLE``.
    """
    sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
    pieces = [
        f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
        f"sure a definition is provided and defined after all overload declarations.\n"
        f'File "{filename}", line {file_lineno}:\n',
        "".join(sourcelines),
        "\n",
        _OVERLOAD_EXAMPLE,
    ]
    return "".join(pieces)
def _check_overload_body(func):
    """
    Verify that an ``@torch.jit._overload`` declaration has a placeholder body.

    The body must be exactly one ``pass`` statement or one ``...`` expression.
    If the source of ``func`` cannot be retrieved, only a warning is emitted.

    Raises:
        RuntimeError: if the body is anything else.
    """
    try:
        parsed_def = parse_def(func)
    except OSError:
        # Parsing the function definition can raise an OSError if source is unavailable.
        # Since this is just an initial check, just raise a warning if this is the case.
        warnings.warn(
            f"Unable to retrieve source for @torch.jit._overload function: {func}."
        )
        return

    body = parsed_def.ast.body[0].body

    def is_pass(x):
        return isinstance(x, ast.Pass)

    def is_ellipsis(x):
        # `...` parses as ast.Constant(value=Ellipsis) since Python 3.8. The
        # ast.Ellipsis class is a deprecated alias that was removed in Python 3.14.
        return (
            isinstance(x, ast.Expr)
            and isinstance(x.value, ast.Constant)
            and x.value.value is Ellipsis
        )

    if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
        msg = (
            "Only `pass` statement or `...` can be the body of overload declaration:\n"
        )
        msg += "\n".join(parsed_def.source.split("\n")[:3])
        msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
        raise RuntimeError(msg)
def _overload(func):
    """Register ``func`` as one overload declaration under its qualified name; return it unchanged."""
    _check_overload_body(func)
    qual_name = _qualified_name(func)
    # Mutating the module-level registry in place; no rebinding, so no `global` needed.
    _overloaded_fns.setdefault(qual_name, []).append(func)
    return func
def _get_fn_overloads(qual_name):
    """Return the overload declarations registered under ``qual_name``, or ``None`` if there are none."""
    try:
        return _overloaded_fns[qual_name]
    except KeyError:
        return None
def _clear_fn_overloads(qual_name) -> None:
    """Forget every overload registered under ``qual_name``; raises KeyError if none are registered."""
    _overloaded_fns.pop(qual_name)
def get_class_name_lineno(method) -> Tuple[str, int]:
    """
    Return ``(co_name, co_firstlineno)`` of the code running two frames above this call.

    This is meant to be called directly from ``_overload_method``, which runs as a
    decorator inside a class body. There, the frame two levels up is the class body,
    so the result is the class name and the first line of its code. ``method`` is
    not used.
    """
    frame = inspect.currentframe()
    # Step out of this function's frame, then out of `_overload_method`'s frame.
    hops = 2
    while hops:
        assert frame is not None  # currentframe()/f_back are Optional[FrameType]
        frame = frame.f_back
        hops -= 1
    assert frame is not None
    code = frame.f_code
    return code.co_name, code.co_firstlineno
# At the point the decorator is applied to class methods the method
# has no reference to its owning class. _qualified_name would not include
# the class it is defined in, so any methods with the same name in the same file
# would have the same _qualified_name, even if they were defined in different
# classes. This problem only exists in python 2.
# We get around this problem by looking at the stack frame and identifying
# the class name, and throwing an error whenever overloads are used
# when modules of the same name are in the same file.
# Registry of method overload declarations:
# qualified_name => class name => list[overload_functions]
_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = dict()  # noqa: T484

# (qualified_name, class name) => class_fileno, i.e. the first line of the class
# body in which overloads for that key were first registered
_overloaded_method_class_fileno = dict()
def _overload_method(func):
    """
    Register ``func`` as an overload declaration of a method and return it unchanged.

    Declarations are grouped by qualified name and then by the name of the class
    body that applied the decorator.

    Raises:
        RuntimeError: if overloads with the same qualified name are declared in
            two classes with the same name whose bodies start on different lines.
    """
    _check_overload_body(func)
    qual_name = _qualified_name(func)
    class_name_map = _overloaded_methods.setdefault(qual_name, {})

    # Must be called directly from here: it walks exactly two frames up to reach
    # the class body applying this decorator.
    class_name, line_no = get_class_name_lineno(func)

    method_overloads = class_name_map.get(class_name)
    if method_overloads is None:
        method_overloads = class_name_map[class_name] = []
        _overloaded_method_class_fileno[(qual_name, class_name)] = line_no
    elif _overloaded_method_class_fileno[(qual_name, class_name)] != line_no:
        raise RuntimeError(
            "Cannot currently overload the same method name in two different"
            " classes with the same name in the same module"
        )

    method_overloads.append(func)
    return func
def _get_overloaded_methods(method, mod_class):
    """
    Return the overload declarations registered for ``method`` on ``mod_class``.

    Returns ``None`` in any of these cases:

    - ``method`` has no ``__name__``;
    - nothing is registered for its qualified name;
    - nothing is registered under ``mod_class.__name__``.

    Raises:
        Exception: if the source line of ``method`` lies outside the source span
            of ``mod_class``, i.e. the overloads presumably belong to another
            class of the same name redeclared in the same file.
    """
    # TODO: __name__ not set for submodules in recursive script
    if not hasattr(method, "__name__"):
        return None
    qual_name = _qualified_name(method)
    class_name_map = _overloaded_methods.get(qual_name, None)
    if class_name_map is None:
        return None
    overloads = class_name_map.get(mod_class.__name__, None)
    if overloads is None:
        return None

    method_line_no = get_source_lines_and_file(method)[1]
    # Fetch the class source once: it yields both the start line and the length.
    mod_class_lines, mod_class_fileno, _ = get_source_lines_and_file(mod_class)
    mod_end_fileno = mod_class_fileno + len(mod_class_lines)
    if not (mod_class_fileno <= method_line_no <= mod_end_fileno):
        raise Exception(
            "Overloads are not useable when a module is redeclared within the same file: "
            + str(method)
        )
    return overloads
def is_tuple(ann) -> bool:
    """Return True for a parameterized tuple annotation.

    Accepts ``Tuple[...]`` and, on Python 3.9+, ``tuple[...]``. A bare ``Tuple``
    is reported via ``raise_error_container_parameter_missing``.
    """
    if ann is Tuple:
        raise_error_container_parameter_missing("Tuple")

    # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
    if not hasattr(ann, "__module__"):
        return False

    module = ann.__module__
    origin = getattr(ann, "__origin__", None)
    if IS_PY39_PLUS and module == "builtins" and origin is tuple:
        return True
    if module != "typing":
        return False
    return origin is Tuple or origin is tuple
def is_list(ann) -> bool:
    """Return True for a parameterized list annotation.

    Accepts ``List[...]`` and, on Python 3.9+, ``list[...]``. A bare ``List``
    is reported via ``raise_error_container_parameter_missing``.
    """
    if ann is List:
        raise_error_container_parameter_missing("List")

    if not hasattr(ann, "__module__"):
        return False

    module = ann.__module__
    origin = getattr(ann, "__origin__", None)
    if IS_PY39_PLUS and module == "builtins" and origin is list:
        return True
    if module != "typing":
        return False
    return origin is List or origin is list
def is_dict(ann) -> bool:
    """Return True for a parameterized dict annotation.

    Accepts ``Dict[...]`` and, on Python 3.9+, ``dict[...]``. A bare ``Dict``
    is reported via ``raise_error_container_parameter_missing``.
    """
    if ann is Dict:
        raise_error_container_parameter_missing("Dict")

    if not hasattr(ann, "__module__"):
        return False

    module = ann.__module__
    origin = getattr(ann, "__origin__", None)
    if IS_PY39_PLUS and module == "builtins" and origin is dict:
        return True
    if module != "typing":
        return False
    return origin is Dict or origin is dict
def is_union(ann):
    """Return True for a ``typing.Union[...]`` annotation (including ``Optional[...]``).

    PEP 604 unions such as ``int | None`` are not ``typing`` objects and yield False.
    A bare ``Union`` is reported via ``raise_error_container_parameter_missing``.
    """
    if ann is Union:
        raise_error_container_parameter_missing("Union")

    if not hasattr(ann, "__module__") or ann.__module__ != "typing":
        return False
    return getattr(ann, "__origin__", None) is Union
def is_optional(ann):
    """Return True for an optional annotation.

    Matches either of:

    - an annotation whose ``__origin__`` is ``Optional`` itself;
    - a two-member ``typing.Union`` that includes ``None``/``NoneType``, such as
      ``Optional[int]``.

    A bare ``Optional`` is reported via ``raise_error_container_parameter_missing``.
    """
    if ann is Optional:
        raise_error_container_parameter_missing("Optional")

    origin_is_optional = (
        hasattr(ann, "__module__")
        and ann.__module__ == "typing"
        and (getattr(ann, "__origin__", None) is Optional)
    )
    if origin_is_optional:
        return True

    if not is_union(ann):
        return False
    union_args = ann.__args__
    return len(union_args) == 2 and (None in union_args or type(None) in union_args)
def is_future(ann) -> bool:
    """Return True for a parameterized ``Future[T]`` annotation.

    Raises:
        RuntimeError: if ``ann`` is the bare ``Future`` with no contained type.
    """
    if ann is not Future:
        return getattr(ann, "__origin__", None) is Future
    raise RuntimeError(
        "Attempted to use Future without a "
        "contained type. Please add a contained type, e.g. "
        "Future[int]"
    )
def is_await(ann) -> bool:
    """Return True for ``_Await`` itself or any parameterized ``_Await[T]``."""
    return ann is _Await or getattr(ann, "__origin__", None) is _Await
if not torch.distributed.rpc.is_available():

    def is_rref_instance(obj) -> bool:
        # If the RPC module doesn't exist then RRefs don't exist either.
        return False

else:
    # `is_rref` is only defined when RPC support is available.
    from torch._C._distributed_rpc import PyRRef
    from torch.distributed.rpc import RRef

    def is_rref(ann) -> bool:
        """Return True for a parameterized ``RRef[T]`` annotation.

        Raises:
            RuntimeError: if ``ann`` is the bare ``RRef`` with no contained type.
        """
        if ann is not RRef:
            return getattr(ann, "__origin__", None) is RRef
        raise RuntimeError(
            "Attempted to use RRef without a "
            "contained type. Please add a contained type, e.g. "
            "RRef[int]"
        )

    def is_rref_instance(obj) -> bool:
        """Return True if ``obj`` is a ``PyRRef`` instance."""
        return isinstance(obj, PyRRef)
def is_final(ann) -> bool:
    """Return True for ``Final`` / ``Final[T]`` from ``typing`` or ``typing_extensions``.

    Accesses ``ann.__module__`` directly, so objects without it raise AttributeError.
    NOTE(review): the ``isinstance(ann, type(Final))`` test matches any bare
    typing special form of the same class (e.g. ``Optional``), not only ``Final``.
    """
    if ann.__module__ not in {"typing", "typing_extensions"}:
        return False
    return getattr(ann, "__origin__", None) is Final or isinstance(ann, type(Final))
# allows BroadcastingList instance to be subscriptable
class BroadcastingListCls:
    """Placeholder whose instances accept any subscript, e.g. ``BroadcastingList2[int]``.

    Subscripting always evaluates to ``None``; the object only needs to be
    subscriptable in Python.
    """

    def __getitem__(self, types):
        return None
# mypy doesn't support parameters on types, so we have to explicitly type each
# list size
BroadcastingList1 = BroadcastingListCls()
# BroadcastingList2 ... BroadcastingList6 are aliases of the same instance.
# A comprehension is used so the size counter stays local; a module-level
# `for i in ...` loop would leave a stray global `i` behind in this module.
globals().update(
    {f"BroadcastingList{size}": BroadcastingList1 for size in range(2, 7)}
)
def is_scripting() -> bool:
    r"""
    Function that returns True when in compilation and False otherwise. This
    is useful especially with the @unused decorator to leave code in your
    model that is not yet TorchScript compatible.

    .. testcode::

        import torch

        @torch.jit.unused
        def unsupported_linear_op(x):
            return x

        def linear(x):
            if torch.jit.is_scripting():
                return torch.linear(x)
            else:
                return unsupported_linear_op(x)
    """
    # Run as ordinary Python this body always answers False; the True case
    # presumably comes from the TorchScript compiler (not visible here).
    scripting = False
    return scripting
# Computes a fully-qualified name (module hierarchy + class name) for an object.
def _qualified_name(obj, mangle_name=True) -> str:
    """Return the TorchScript qualified name ``<module>.<name>`` for ``obj``.

    Args:
        obj: a class, function, enum member, ScriptFunction or torchbind class.
        mangle_name: when True, the module part is rewritten to live under
            ``__torch__`` (``__main__`` becomes ``__torch__`` itself).

    Raises:
        RuntimeError: if no name can be determined, ``__module__`` is None,
            or the name contains a ``.``.
    """
    # Explicit override: tracing builds a local stand-in class for the traced
    # module and stores the original module's qualified name here.
    if hasattr(obj, "_jit_override_qualname"):
        return obj._jit_override_qualname

    # ScriptFunctions already know their qualified name.
    if isinstance(obj, torch._C.ScriptFunction):
        return obj.qualified_name

    # Enum members have no ``__name__``; they expose ``name`` instead.
    if not getattr(obj, "__name__", None):
        if not isinstance(obj, enum.Enum):
            raise RuntimeError("Could not get name of python class object")
        name = obj.name
    else:
        name = obj.__name__

    if name == "<lambda>":
        name = "_lambda"  # "<lambda>" is not a valid identifier

    module_name = obj.__module__

    # torchbind classes carry their own qualified name.
    if module_name == "torch._classes":
        return obj.qualified_name

    # Python permits a None ``__module__``; it cannot be qualified.
    if module_name is None:
        raise RuntimeError(
            f"Could not get qualified name for class '{name}': "
            "__module__ can't be None."
        )

    # torch.package and TorchScript mangle independently; neutralize the
    # package mangling markers so the two schemes don't interfere.
    if package_mangling.is_mangled(module_name):
        module_name = module_name.replace("<", "_").replace(">", "_")

    # PythonExceptionValue (python_sugared_value.h) passes mangle_name=False
    # because it does not need to mangle the Python class name.
    if mangle_name:
        # __main__ maps to "__torch__"; every other module gets a "__torch__."
        # prefix to avoid colliding with the names of user values.
        if module_name != "__main__":
            module_name = "__torch__." + module_name
        else:
            module_name = "__torch__"

    if "." in name:
        raise RuntimeError(
            f"Could not get qualified name for class '{name}': "
            f"'{name}' is not a valid identifier"
        )

    return module_name + "." + name
def _try_get_dispatched_fn(fn):
    """Return the ``boolean_dispatched`` entry registered for ``fn``, or None.

    ``boolean_dispatched`` is a ``weakref.WeakKeyDictionary``. Looking up an
    object that cannot be weakly referenced (e.g. a function implemented in C)
    or that is unhashable raises ``TypeError``. Such an object can never have
    been stored as a key, so it is reported as "not dispatched" instead.
    """
    if not callable(fn):
        return None
    try:
        return boolean_dispatched.get(fn)
    except TypeError:
        # weakref.ref(fn) or hashing it failed: fn cannot be a registered key.
        return None
def _get_named_tuple_properties(
obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None, rcb=None
):
if loc is None:
loc = fake_range()
assert issubclass(obj, tuple) and hasattr(obj, "_fields")
if hasattr(obj, "_field_defaults"):
defaults = [
obj._field_defaults[field]
for field in obj._fields
if field in obj._field_defaults
]
else:
defaults = []
# In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function
# Also, annotations from base class are not inherited so they need to be queried explicitly
if sys.version_info[:2] < (3, 10):
obj_annotations = getattr(obj, "__annotations__", {})
else:
obj_annotations = inspect.get_annotations(obj)
if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
obj_annotations = inspect.get_annotations(obj.__base__)
annotations = []
for field in obj._fields:
if field in obj_annotations:
field_type = obj_annotations[field]
# [Note: ForwardRef annotations in NamedTuple attributes]
# NamedTuple types are slightly different from normal types.
#
# Normally, annotations are evaluted like this (during jit.script):
# 1. Load strings of python code into c++ and parse.
# 2. Get annotations as strings
# 3. Use the PythonResolver's resolution callback (rcb) to convert
# the string into a python object
# 4. We call into annotations.py:ann_to_type to convert python obj
# from step 3 into a type that torchscript understands.
#
# NamedTuples are more complicated, because it has sub-types.
# Normally, once we have the NamedTuple type object from #3,
# we can just look at the annotation literal values and use
# ann_to_type directly on them.
#
# But sometimes, users will annotate with string literals, e.g.
# x: 'int'
# This also happens with PEP563 (from __forward__ import annotations)
#
# These annotations appear in the annotation dict as ForwardRef('int').
#
# Then, we need to convert the string into a python object. This
# requires having local context for custom objects or imported types.
# rcb() is what gives us this. So, we plumb rcb through the stack so
# it can be used in this context for the if block below.
#
# FAQ:
# - Why do we need this special handling for NamedTuple but string
# annotations work fine for normal types? Normally, we parse the
# string directly and then call rcb() directly from C++.
# - Why not use ForwardRef._evaluate? For that, we need globals()
# and locals() for the local context where the NamedTuple was defined.
# rcb is what lets us look up into these. So, basically rcb does the
# hard work for us.
if isinstance(field_type, ForwardRef) and rcb is not None:
rcb_type = rcb(field_type.__forward_arg__)
# rcb returns None if it can't find anything.
if rcb_type is None:
raise ValueError(
f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
f" Issue occurred at {loc.highlight()}"
)
field_type = rcb_type
the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
annotations.append(the_type)
else:
annotations.append(torch._C.TensorType.getInferred())
return type(obj).__name__, obj._fields, annotations, defaults
def _create_named_tuple(
    t, unqual_name: str, field_names: List[str], defaults: Tuple[Any, ...]
):
    """Instantiate a freshly created namedtuple class from the values in ``t``.

    Each call builds a new ``collections.namedtuple`` class called
    ``unqual_name`` with fields ``field_names`` and the given ``defaults``,
    then constructs it positionally from ``t``.
    """
    named_tuple_cls = collections.namedtuple(  # type: ignore[call-arg, misc]
        unqual_name, field_names, defaults=defaults
    )
    return named_tuple_cls(*t)
@contextlib.contextmanager
def _disable_emit_hooks():
    """Clear the JIT emit hooks for the duration of a ``with`` block.

    The hooks that were installed on entry are put back on exit, including
    when the body raises.
    """
    saved_hooks = torch._C._jit_get_emit_hooks()
    torch._C._jit_set_emit_hooks(None, None)
    try:
        yield
    finally:
        first_hook = saved_hooks[0]
        second_hook = saved_hooks[1]
        torch._C._jit_set_emit_hooks(first_hook, second_hook)
def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None:  # noqa: F811
    """NOTE(review): declared with ``def``, not ``class``.

    Calling this only defines the two local functions below and returns None.
    ``_DecoratorContextManager`` is an ordinary parameter here, not a base
    class, so this object is neither a context manager nor a decorator and
    does not touch the emit hooks. ``_disable_emit_hooks`` is the working
    context manager. It looks like this was meant to be a class; confirm
    before relying on it.
    """

    def __enter__(self) -> None:
        # Save the current emit hooks, then clear them.
        self.hooks = torch._C._jit_get_emit_hooks()
        torch._C._jit_set_emit_hooks(None, None)

    def __exit__(self, *args) -> None:
        # Restore the hooks saved by __enter__.
        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
def _is_exception(obj) -> bool:
    """Return True if ``obj`` is a class that derives from ``Exception``.

    Non-class objects, including exception instances, yield False instead of
    reaching ``issubclass``, which would reject them.
    """
    return inspect.isclass(obj) and issubclass(obj, Exception)
def raise_error_container_parameter_missing(target_type) -> None:
    """Raise for a container annotation used without its contained type(s).

    Args:
        target_type: display name of the container ("List", "Dict", ...),
            used in the error message.

    Raises:
        RuntimeError: always.
    """
    if target_type == "Dict":
        # Dict takes two type parameters, so it gets its own example.
        raise RuntimeError(
            "Attempted to use Dict without "
            "contained types. Please add contained type, e.g. "
            "Dict[int, int]"
        )
    raise RuntimeError(
        f"Attempted to use {target_type} without a "
        "contained type. Please add a contained type, e.g. "
        f"{target_type}[int]"
    )


def get_origin(target_type):
    """Return ``target_type.__origin__`` (e.g. ``list`` for ``List[int]``), or None."""
    return getattr(target_type, "__origin__", None)


def get_args(target_type):
    """Return ``target_type.__args__`` (e.g. ``(int,)`` for ``List[int]``), or None."""
    return getattr(target_type, "__args__", None)


def check_args_exist(target_type) -> None:
    """Raise ``RuntimeError`` if ``target_type`` is an unparameterized
    List/Tuple/Dict/Optional (or None)."""
    if target_type is List or target_type is list:
        raise_error_container_parameter_missing("List")
    elif target_type is Tuple or target_type is tuple:
        raise_error_container_parameter_missing("Tuple")
    elif target_type is Dict or target_type is dict:
        raise_error_container_parameter_missing("Dict")
    elif target_type is None or target_type is Optional:
        raise_error_container_parameter_missing("Optional")


def check_empty_containers(obj) -> None:
    """Warn that an empty list/dict/tuple matches any element type in eager mode."""
    if obj == [] or obj == {} or obj == ():
        warnings.warn(
            "The inner type of a container is lost when "
            "calling torch.jit.isinstance in eager mode. For "
            "example, List[int] would become list and "
            "therefore falsely return True for List[float] or"
            " List[str]."
        )


# Supports List/Dict/Tuple and Union/Optional types.
# TODO support future
def container_checker(obj, target_type) -> bool:
    """Check ``obj`` against a parameterized container annotation in eager mode.

    Handles ``List[T]``, ``Dict[K, V]``, ``Tuple[T1, ..., Tn]`` and
    ``Union[...]``/``Optional[T]``, recursing into nested container types.
    Any other origin (e.g. ``Set[int]``) is reported as not matching.

    Raises:
        RuntimeError: if ``target_type`` is an unparameterized container.
    """
    origin_type = get_origin(target_type)
    check_args_exist(target_type)
    if origin_type is list or origin_type is List:
        if not isinstance(obj, list):
            return False
        # Warn only once obj is known to be a list: comparing arbitrary objects
        # (e.g. tensors) against []/{}/() can raise or be ambiguous.
        check_empty_containers(obj)
        arg_type = get_args(target_type)[0]
        arg_origin = get_origin(arg_type)
        for el in obj:
            if arg_origin:  # nested container, e.g. List[List[str]]
                if not container_checker(el, arg_type):
                    return False
            elif not isinstance(el, arg_type):
                return False
        return True
    elif origin_type is Dict or origin_type is dict:
        if not isinstance(obj, dict):
            return False
        check_empty_containers(obj)
        key_type, val_type = get_args(target_type)
        val_origin = get_origin(val_type)
        for key, val in obj.items():
            if not isinstance(key, key_type):
                return False
            if val_origin:  # nested container value, e.g. Dict[str, List[int]]
                if not container_checker(val, val_type):
                    return False
            elif not isinstance(val, val_type):
                return False
        return True
    elif origin_type is Tuple or origin_type is tuple:
        if not isinstance(obj, tuple):
            return False
        check_empty_containers(obj)
        arg_types = get_args(target_type)
        if len(obj) != len(arg_types):
            return False
        for el, el_type in zip(obj, arg_types):
            if get_origin(el_type):
                if not container_checker(el, el_type):
                    return False
            elif not isinstance(el, el_type):
                return False
        return True
    elif origin_type is Union:  # also handles Optional
        inner_types = get_args(target_type)
        if obj is None:
            # None matches only a Union that admits it, e.g. Optional[T].
            return type(None) in inner_types
        for t in inner_types:
            if get_origin(t):
                # A failed nested-container member must not end the search;
                # a later Union member may still match.
                if container_checker(obj, t):
                    return True
            elif isinstance(obj, t):
                return True
        return False
    # Unsupported origin: report "no match" explicitly instead of returning None.
    return False


def _isinstance(obj, target_type) -> bool:
    """Eager-mode type check used for ``torch.jit.isinstance``.

    ``target_type`` may be a plain type, a parameterized container/Union
    annotation (delegated to ``container_checker``), or a tuple of those; a
    tuple matches if any element matches.

    Raises:
        RuntimeError: if ``target_type`` is a non-tuple container, or an
            unparameterized container annotation.
    """
    if isinstance(target_type, collections.abc.Container):
        if not isinstance(target_type, tuple):
            raise RuntimeError(
                "The second argument to "
                "`torch.jit.isinstance` must be a type "
                "or a tuple of types"
            )
        for t_type in target_type:
            if _isinstance(obj, t_type):
                return True
        return False

    origin_type = get_origin(target_type)
    if origin_type:
        return container_checker(obj, target_type)

    # Check to handle non-typed optional origin returns as none instead
    # of as optional in 3.7-3.8
    check_args_exist(target_type)

    # handle non-containers
    return isinstance(obj, target_type)
class _TensorExtractor(pickle.Pickler):
    """Pickler that records every tensor reachable from the pickled object.

    Tensors are appended to the caller-supplied ``tensors`` list and given a
    persistent id, so their contents are not serialized. A handful of object
    kinds known not to contain tensors are skipped the same way.
    """

    def __init__(self, *args, tensors: List[torch.Tensor], **kwargs):
        super().__init__(*args, **kwargs)
        # Caller-owned output list; filled in by persistent_id.
        self.tensors = tensors

    def persistent_id(self, obj):
        # A non-None return value makes pickle emit a persistent id instead of
        # serializing ``obj`` itself.
        if isinstance(obj, torch.Tensor):
            self.tensors.append(obj)
            return ""
        # Since we just want to extract tensors, we don't mind if an object is
        # unpicklable as long as it doesn't contain tensors, so we can skip it.
        # To play it safe, we only do so for common objects that we're sure
        # don't contain tensors. Feel free to add new types here. Note also
        # that even if a type isn't listed here this won't block users, since
        # they can just add a __getstate__ or __reduce__ method to their class.
        # Futures and RRefs don't technically contain a value; they just offer
        # the means to access a value.
        if (
            isinstance(obj, (LockType, CFuture))
            or is_rref_instance(obj)
            or isinstance(obj, (CAwait, torch.cuda.Event, threading.Thread))
        ):
            return ""
        return None
def _extract_tensors(obj):
    r"""
    This function is exclusively called from C++.
    See ``torch/csrc/jit/python/python_ivalue.h``.

    Pickle ``obj`` into a throwaway in-memory buffer with ``_TensorExtractor``
    and return the list of tensors collected along the way.
    """
    collected: List[torch.Tensor] = []
    _TensorExtractor(io.BytesIO(), protocol=-1, tensors=collected).dump(obj)
    return collected
# In Python 3.11+, typed enums (e.g. IntEnum) retain a number of base-class
# methods in the subclass that were previously dropped. To preserve the old
# behavior, explicitly drop them here.
# NOTE(review): `sys.version_info > (3, 10)` is also True on every 3.10.x
# release, because (3, 10, 0, ...) compares greater than the shorter tuple
# (3, 10). So this branch runs on 3.10 as well, not only on 3.11+. Confirm
# whether `>= (3, 11)` was intended.
if sys.version_info > (3, 10):
    _drop(enum.Enum.__new__)
    _drop(enum.Enum.__format__)
    _drop(enum.Enum.__repr__)
    _drop(enum.Enum.__str__)