pytorch/tools/autograd/gen_variable_type.py


# Generates VariableType.h/cpp
#
# **If any changes are made to the VariableType codegen, please also check
# whether updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
#
# VariableType is a subclass of at::Type that provides the binding code
# necessary to provide a differentiable version of ATen operators. There are a
# number of different things we could mean:
#
# - Given a non-differentiable forward implementation, we might
# directly associate it with a backward implementation to make
# it differentiable. This is the common case.
#
# - Some functions don't need a backwards implementation, because
# backpropagation will never propagate beyond them. There are a
# number of different reasons why this may be the case:
#
# - The function has no differentiable inputs
# - The function's output is not differentiable
# - The function has no data dependency on its input
#
# - Some functions don't need a backwards implementation because they
# are implemented as a composition of other (differentiable) ATen
# functions. These are dispatched directly to the Type superclass,
# which will in turn dispatch back to VariableType for its
# differentiable subcomponents.
#
import re
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from torchgen.api import cpp
from torchgen.api.autograd import (
DifferentiableInput,
dispatch_strategy,
ForwardDerivative,
gen_differentiable_outputs,
is_differentiable,
NativeFunctionWithDifferentiabilityInfo,
SavedAttribute,
)
from torchgen.api.types import (
ArrayRefCType,
BaseCppType,
BaseCType,
Binding,
DispatcherSignature,
intArrayRefT,
iTensorListRefT,
ListCType,
MutRefCType,
OptionalCType,
scalarT,
SpecialArgName,
stringT,
symIntArrayRefT,
TENSOR_LIST_LIKE_CTYPES,
tensorListT,
tensorT,
TupleCType,
VectorCType,
)
from torchgen.code_template import CodeTemplate
from torchgen.context import (
native_function_manager,
with_native_function,
with_native_function_and,
)
from torchgen.model import (
Argument,
BaseType,
ListType,
NativeFunction,
SchemaKind,
SelfArgument,
TensorOptionsArguments,
)
from torchgen.utils import FileManager, mapMaybe
from .context import with_native_function_with_differentiability_info_and_key
from .gen_inplace_or_view_type import (
ALL_VIEW_FUNCTIONS,
ASSIGN_RETURN_VALUE,
AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION,
gen_formals,
get_base_name,
get_view_info,
is_tensor_list_type,
is_tensor_type,
METHOD_DEFINITION,
modifies_arguments,
TMP_VAR,
unpack_args,
unpacked_name,
use_derived,
WRAPPER_REGISTRATION,
)
from .gen_trace_type import (
declare_returned_variables,
get_return_value,
MANUAL_AUTOGRAD_AND_TRACER,
MANUAL_BACKEND,
tie_return_values,
type_wrapper_name,
)
# We don't set or modify grad_fn on these methods. Generally, they return
# tensors that have requires_grad=False. In-place functions listed here will
# not examine or modify requires_grad or grad_fn.
# NB: this does NOT include overload name
DONT_REQUIRE_DERIVATIVE = {
# These only depend on the input Tensor's shape and device, not the data
"empty_like",
"ones_like",
"full_like",
"zeros_like",
"rand_like",
"randn_like",
"new_empty",
"new_empty_strided",
"new_full",
"new_zeros",
"new_ones",
# These are only implemented on integral types
"__and__",
"__iand__",
"__ilshift__",
"__ior__",
"__irshift__",
"__ixor__",
"__lshift__",
"__or__",
"__rshift__",
"__xor__",
# These work on integral data types, and hence don't require a derivative
"_sobol_engine_draw",
"_sobol_engine_ff",
"_sobol_engine_scramble_",
"_sobol_engine_initialize_state_",
# This is an unsafe method that is meant to be out of reach of autograd.
"_coalesced_",
# Quantize functions should not record gradients
"quantize_per_tensor",
"quantize_per_channel",
# Functions that return integers should not have outputs that require gradients
"argmax",
"argmin",
"argsort",
"searchsorted",
"bucketize",
# Functions that return booleans are not differentiable
"isnan",
"isposinf",
"isneginf",
"isinf",
"signbit",
"isin",
"allclose",
# Functions that return None are not differentiable
"record_stream",
# These functions are not differentiable
"logical_and",
"logical_xor",
"logical_not",
"logical_or",
# These functions return nested tensor sizes/strides as tensors, which are non-differentiable
"_nested_tensor_size",
"_nested_tensor_strides",
}
# C -> C and R -> C functions for which backward is correctly implemented and tested.
# (The C -> R functions, at the time of adding this, are still being audited and tested,
# but will not error out.)
GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
"fill",
"t",
"view",
"reshape",
"reshape_as",
"view_as",
"roll",
"clone",
"block_diag",
"diag_embed",
"repeat",
"expand",
"flip",
"fliplr",
"flipud",
"rot90",
"nanmean",
"nansum",
"transpose",
"permute",
"squeeze",
"unsqueeze",
"resize",
"resize_as",
"tril",
"triu",
"chunk",
"zero_",
"eq_",
"ne_",
"add",
"__radd__",
"sum",
"_conj",
"sin",
"cos",
"mul",
"sinc",
"sinh",
"cosh",
"__rmul__",
"sgn",
"asin",
"acos",
"sub",
"div",
"cat",
"view_as_complex",
"index_put",
"neg",
"complex",
"select",
"where",
"as_strided",
"as_strided_scatter",
"slice",
"constant_pad_nd",
"unbind",
"split",
"split_with_sizes",
"unsafe_split",
"split_with_sizes_backward",
"dot",
"vdot",
"cholesky",
"triangular_solve",
"mm",
"_unsafe_view",
"mv",
"outer",
"bmm",
"diagonal",
"alias",
"atan",
"log",
"log10",
"log1p",
"log2",
"logaddexp",
"logcumsumexp",
"reciprocal",
"tan",
"pow",
"rsqrt",
"tanh",
"tanh_backward",
"asinh",
"acosh",
"atanh",
"take",
"fill_",
"exp",
"exp2",
"expm1",
"nonzero",
"mean",
"std_mean",
"var_mean",
"inverse",
"solve",
"linalg_cholesky",
"addcmul",
"addcdiv",
"matrix_exp",
"linalg_matrix_exp",
"_linalg_eigh",
"cholesky_solve",
"linalg_qr",
"_linalg_svd",
"_fft_c2c",
"_fft_r2c",
"linalg_solve",
"sqrt",
"stack",
"gather",
"index_select",
"index_add_",
"linalg_inv",
"linalg_inv_ex",
"baddbmm",
"addbmm",
"addmm",
"addmv",
"addr",
"linalg_householder_product",
"ormqr",
"reflection_pad1d",
"reflection_pad2d",
"reflection_pad3d",
"linalg_cholesky_ex",
"linalg_eig",
"diagonal_copy",
"diagonal_scatter",
"select_backward",
"diagonal_backward",
"slice_backward",
"reflection_pad1d_backward",
"reflection_pad2d_backward",
"reflection_pad3d_backward",
"_sparse_sparse_matmul",
"replication_pad1d",
"replication_pad2d",
"replication_pad3d",
"put",
"put_",
"_to_copy",
"replication_pad1d_backward",
"replication_pad2d_backward",
"replication_pad3d_backward",
"diag",
"masked_scatter",
"masked_select",
"index_add",
"index_fill",
"trace",
"polar",
"cumsum",
"rsub",
"eig",
"lerp",
"linalg_vector_norm",
"cumprod",
"prod",
"index_copy",
"lu",
"unfold",
"unfold_backward",
"index",
"masked_fill",
"linalg_cross",
"lu_unpack",
"renorm",
"_conj_physical",
"linalg_lu_factor_ex",
"scatter",
"scatter_add",
"sigmoid",
"sigmoid_backward",
"sparse_mask",
"trapezoid",
"cumulative_trapezoid",
"conj_physical_",
"_neg_view",
"_reshape_alias",
"_reshape_copy",
"_linalg_det",
"lu_solve",
"linalg_solve_triangular",
"linalg_pinv",
"linalg_lstsq",
"unfold_copy",
"col2im",
"im2col",
"cholesky_inverse",
"to_sparse",
"sparse_sampled_addmm",
"linalg_lu",
"pixel_shuffle",
"pixel_unshuffle",
"linalg_lu_solve",
"_linalg_slogdet",
"_linalg_solve_ex",
}
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
"_to_dense",
"_coalesce",
"coalesce",
"values",
"_sparse_coo_tensor_with_dims_and_tensors",
"_sparse_addmm",
}
GRADIENT_IMPLEMENTED_FOR_COMPLEX.update(GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX)
# Some operators invalidate the grad_accumulator. Let's reset it.
RESET_GRAD_ACCUMULATOR = {"set_", "resize_"}
# NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
#
# We check the following properties:
# 1) A function should never change the input tensors' underlying c10::TensorImpl
# pointers or c10::Storage pointers, even if it modifies its input tensors (via
# inplace or out-variants)
# If the function does not modify its arguments, we also check the following properties
# pertaining to its output:
# 2) Its TensorImpl has a use_count of 1
# 3) If the function is a view function, it has the same StorageImpl as that of
# the input it is aliased with. Otherwise, its StorageImpl has a use_count of 1
#
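# As an illustration only (simplified; the real templates also gate on dispatch modes), for a
# single unpacked Tensor argument `self_` the generated kernel brackets the redispatch call
# with debug-only checks along these lines:
#
#   #ifndef NDEBUG
#   c10::optional<Storage> self__storage_saved =
#     self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
#   c10::intrusive_ptr<TensorImpl> self__impl_saved;
#   if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
#   #endif
#   // ... redispatch to the base kernel ...
#   #ifndef NDEBUG
#   if (self__storage_saved.has_value())
#     TORCH_INTERNAL_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
#   if (self__impl_saved)
#     TORCH_INTERNAL_ASSERT(self__impl_saved == self_.getIntrusivePtr());
#   #endif
#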
# The following code templates implement the checks for this invariant:
SAVE_TENSOR_STORAGE = CodeTemplate(
"""\
c10::optional<Storage> ${tensor_name}_storage_saved =
${tensor_name}.has_storage() ? c10::optional<Storage>(${tensor_name}.storage()) : c10::nullopt;
"""
)
# If tensor_name == out_tensor_name, used to enforce (1); otherwise used for (3)
ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate(
"""\
if (${tensor_name}_storage_saved.has_value() &&
!at::impl::dispatch_mode_enabled() &&
!at::impl::tensor_has_dispatch(${tensor_name}))
TORCH_INTERNAL_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${out_tensor_name}.storage()));
"""
)
SAVE_TENSORLIST_STORAGE = CodeTemplate(
"""\
std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
for (const Tensor& tensor : ${tensorlist_name})
${tensorlist_name}_storage_saved.push_back(
tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
"""
)
ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate(
"""\
for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name}))
TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(${tensorlist_name}[i].storage()));
}
"""
)
SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate(
"""\
std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
for (const c10::optional<Tensor>& tensor : ${tensorlist_name})
${tensorlist_name}_storage_saved.push_back(
tensor.has_value() && tensor->has_storage() ? c10::optional<Storage>(tensor->storage()) : c10::nullopt);
"""
)
ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate(
"""\
for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name}))
TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(
static_cast<c10::optional<Tensor>>(${tensorlist_name}[i])->storage()));
}
"""
)
SAVE_TENSOR_IMPL = CodeTemplate(
"""\
c10::intrusive_ptr<TensorImpl> ${tensor_name}_impl_saved;
if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr();
"""
)
ENFORCE_SAME_TENSOR_IMPL = CodeTemplate(
"""\
if (${tensor_name}_impl_saved && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name}))
TORCH_INTERNAL_ASSERT(${tensor_name}_impl_saved == ${tensor_name}.getIntrusivePtr());
"""
)
ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE = CodeTemplate(
"""\
if (!at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name}))
TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() <= 1, "function: ${fn_name}");
"""
)
ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE = CodeTemplate(
"""\
if (${tensor_name}.has_storage() && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) {
TORCH_INTERNAL_ASSERT(${tensor_name}.storage().use_count() == 1, "function: ${fn_name}");
}
"""
)
SAVE_TENSORLIST_IMPL = CodeTemplate(
"""\
std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
for (size_t i=0; i<${tensorlist_name}.size(); i++)
if (${tensorlist_name}[i].defined()) ${tensorlist_name}_impl_saved[i] = ${tensorlist_name}[i].getIntrusivePtr();
"""
)
ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate(
"""\
for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (${tensorlist_name}_impl_saved[i] && !at::impl::tensorlist_has_dispatch(${tensorlist_name}))
TORCH_INTERNAL_ASSERT(${tensorlist_name}_impl_saved[i] == ${tensorlist_name}[i].getIntrusivePtr());
}
"""
)
SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate(
"""\
std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
for (size_t i=0; i<${tensorlist_name}.size(); i++) {
c10::optional<Tensor> t = ${tensorlist_name}[i];
if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr();
}
"""
)
ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate(
"""\
for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
if (${tensorlist_name}_impl_saved[i])
TORCH_INTERNAL_ASSERT(
${tensorlist_name}_impl_saved[i] == static_cast<c10::optional<Tensor>>(${tensorlist_name}[i])->getIntrusivePtr());
}
"""
)
# The following list contains functions on which we don't enforce the invariant.
DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
# These functions are expected to change impl or storage of input tensors
"set_",
"_cudnn_rnn_flatten_weight",
}
DONT_ENFORCE_TENSOR_IMPL_USE_COUNT = {
# These non-inplace, non-out functions return tensors with use_count > 1
# Therefore, they MAY (but do not necessarily) return one of their inputs as-is
# See https://github.com/pytorch/pytorch/issues/60426 for more information
"_embedding_bag",
"_embedding_bag_forward_only",
"q_per_channel_scales",
"q_per_channel_zero_points",
"lu_unpack",
"_cudnn_rnn_backward",
# The ops below failed the StorageImpl use_count check, but we also skip the tensor_impl check
# just in case
"_cudnn_rnn",
"dequantize_self",
# lift() should never actually be called with a requires_grad=True tensor,
"lift",
"lift_fresh",
"lift_fresh_copy",
# Nested Tensors related functions
# _nested_tensor_size() should never actually be called with requires_grad=True tensor
"_nested_tensor_size",
"_nested_tensor_strides",
"_nested_tensor_storage_offsets",
}
DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = {
# These non-view functions return tensors with storage use_count != 1
"_slow_conv2d_forward",
"slow_conv3d_forward",
"channel_shuffle",
# If an input is returned as-is in the output, we cannot guarantee its storage_impl's
# use count is 1 either.
*DONT_ENFORCE_TENSOR_IMPL_USE_COUNT,
}
# END CHECKS FOR [ TensorImpl and Storage Pointer Sanity Checks ]
DECLARE_GRAD_FN = CodeTemplate(
"""\
std::shared_ptr<${op}> grad_fn;
"""
)
DECLARE_VECTOR_OF_GRAD_FN = CodeTemplate(
"""\
std::vector<std::shared_ptr<${op}>> grad_fns;
"""
)
SETUP_ANY_REQUIRES_GRAD = CodeTemplate(
"""\
[[maybe_unused]] auto _any_requires_grad = compute_requires_grad( ${args_with_derivatives} );
${extra_differentiability_conditions}
"""
)
SETUP_DERIVATIVE = CodeTemplate(
"""\
if (_any_requires_grad) {
${setup}
}
"""
)
SETUP_NONE_REQUIRES_GRAD = CodeTemplate(
"""\
if (compute_requires_grad( ${args_to_check} )) {
throw_error_out_requires_grad("${base_name}");
}
"""
)
ASSIGN_GRAD_FN = CodeTemplate(
"""\
grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode);
grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} ));
"""
)
# note(crcrpar): `compute_requires_grad` in the template below is supplied with arguments indexed with `i`
# while the `SETUP_ANY_REQUIRES_GRAD` above takes whole tensors and scalars.
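# For example (illustrative), when generating an in-place foreach kernel, the reference
# argument `self` becomes `self[i]` and a tensor-list argument such as `other` becomes
# `other[i]`, so each loop iteration computes requires_grad for the i-th element only.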
ASSIGN_VECTOR_OF_GRAD_FN = CodeTemplate(
"""\
for (const auto& i : c10::irange( ${irange} )) {
const auto ith_requires_grad = compute_requires_grad(${args_with_derivatives});
check_inplace(self[i], ith_requires_grad);
grad_fns.push_back([&]() -> std::shared_ptr<${op}> {
if (!ith_requires_grad) {
return nullptr;
} else {
auto grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode);
grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} ));
return grad_fn;
}
}());
}
"""
)
CALL_REDISPATCH = CodeTemplate(
"""\
at::redispatch::${api_name}(${unpacked_args})"""
)
# If the non-variable operation has return values, we use the `tmp` variable to hold the
# values temporarily and pass the values to the return variables outside of the
# `at::AutoDispatchBelowAutograd` guard block.
DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP = CodeTemplate(
"""\
auto ${tmp_var} = ([&]() {
if (${any_has_forward_grad}) {
static c10::OperatorName full_name("aten::${op_name}", "${op_overload}");
static c10::optional<c10::OperatorHandle> opt_op = c10::Dispatcher::singleton().findSchema(full_name);
return impl::run_jit_decomposition_with_args_for_jvp<${return_types}>("${op_name}", *opt_op, ks, ${arg_names});
} else {
${guard}
return ${base_type_call};
}
})();
"""
)
DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate(
"""\
auto ${tmp_var} = ([&]() {
${guard}
return ${base_type_call};
})();
"""
)
DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate(
"""\
{
${guard}
${base_type_call};
}
"""
)
SET_HISTORY = CodeTemplate(
"""\
if (grad_fn) {
${fn}_history(${differentiable_outputs}, grad_fn);
}
"""
)
LOOP_OVER_VECTOR_OF_GRAD_FNS = CodeTemplate(
"""\
if (!grad_fns.empty()) {
${preamble}
for (const auto& i : c10::irange(grad_fns.size())) {
auto grad_fn = grad_fns[i];
if (grad_fn != nullptr) {
${statements}
}
}
}
"""
)
CONDITIONAL = CodeTemplate(
"""\
if (${cond}) {
${statements}
}
"""
)
RUN_ONLY_IN_DEBUG_MODE = CodeTemplate(
"""\
#ifndef NDEBUG
${statements}
#endif
"""
)
FW_DERIVATIVE_CHECK_TEMPLATE = CodeTemplate(
"""\
isFwGradDefined(${req_inp})\
"""
)
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE = CodeTemplate(
"""\
isFwGradDefinedTensorList(${req_inp})\
"""
)
FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate(
"""\
auto ${inp_name}_t_raw = toNonOptFwGrad(${inp});
auto ${inp_name}_tensor = toNonOptTensor(${inp});
auto ${inp_name}_t = (${inp_name}_t_raw.defined() || !${inp_name}_tensor.defined())
? ${inp_name}_t_raw : at::${zeros_fn}(${inp_name}_tensor.sizes(), ${inp_name}_tensor.options());
"""
)
FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate(
"""\
auto ${inp_name}_p = toNonOptPrimal(${inp});
"""
)
FW_DERIVATIVE_SETTER_TENSOR = CodeTemplate(
"""\
if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}.defined()) {
// The hardcoded 0 here will need to be updated once we support multiple levels.
${out_arg}._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace});
}
"""
)
FW_DERIVATIVE_SETTER_TENSOR_FOREACH = CodeTemplate(
"""\
for (const auto& i : c10::irange(${out_arg}_new_fw_grad_opts.size())) {
auto& ${out_arg}_new_fw_grad_opt = ${out_arg}_new_fw_grad_opts[i];
if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}[i].defined()) {
// The hardcoded 0 here will need to be updated once we support multiple levels.
${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace});
}
}
"""
)
FW_DERIVATIVE_SETTER_MULTI_OUTPUT = CodeTemplate(
"""\
if (${all_res}_new_fw_grad_opt.has_value() && std::get<${idx}>(${all_res}_new_fw_grad_opt.value()).defined()
&& ${out_arg}.defined()) {
${out_arg}._set_fw_grad(std::get<${idx}>(${all_res}_new_fw_grad_opt.value()), /* level */ 0, /* is_inplace_op */ false);
}
"""
)
FW_DERIVATIVE_SETTER_TENSOR_LIST = CodeTemplate(
"""\
if (${out_arg}_new_fw_grad_opt.has_value()) {
auto ${out_arg}_new_fw_grad = ${out_arg}_new_fw_grad_opt.value();
TORCH_INTERNAL_ASSERT(${out_arg}.size() == ${out_arg}_new_fw_grad.size());
for (const auto i : c10::irange(${out_arg}.size())) {
if (${out_arg}_new_fw_grad[i].defined() && ${out_arg}[i].defined()) {
// The hardcoded 0 here will need to be updated once we support multiple levels.
${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad[i], /* level */ 0, /* is_inplace_op */ ${is_inplace});
}
}
}
"""
)
FW_DERIVATIVE_TEMPLATE = CodeTemplate(
"""\
${fw_grad_opt_definition}
if (${requires_fw_grad}) {
${unpacked_arguments}
${out_arg}_new_fw_grad_opt = ${formula};
}
"""
)
FW_DERIVATIVE_FOREACH_TEMPLATE = CodeTemplate(
"""\
${fw_grad_opt_definition}
for (const auto& i : c10::irange(${vector_of_optional_tensor}.size())) {
if (${any_has_forward_grad_for_current_index}) {
${unpacked_arguments}
${vector_of_optional_tensor}[i] = ${formula};
}
}
"""
)
FW_DERIVATIVE_FORBID_TEMPLATE = CodeTemplate(
"""\
TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}");
"""
)
FW_DERIVATIVE_FORBID_LIST_TEMPLATE = CodeTemplate(
"""\
for (const auto& _t: ${arg}) {
TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}");
}
"""
)
def gen_variable_type(
out: str,
native_yaml_path: str,
tags_yaml_path: str,
fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo],
template_path: str,
used_keys: Set[str],
) -> None:
"""VariableType.h and VariableType.cpp body
This is the at::Type subclass for differentiable tensors. The
implementation of each function dispatches to the base tensor type to
compute the output. The grad_fn is attached to differentiable functions.
"""
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
fm.write(
"VariableType.h",
lambda: {
"generated_comment": "@"
+ f"generated from {fm.template_dir_for_comments()}/VariableType.h"
},
)
# helper that generates a TORCH_LIBRARY_IMPL macro for each
# dispatch key that appears in derivatives.yaml
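# For example (illustrative), used_keys == {"Default"} produces roughly:
#   TORCH_LIBRARY_IMPL(aten, Autograd, m) {
#   ${wrapper_registrations_Default}
#   }
# and any other key K produces TORCH_LIBRARY_IMPL(aten, K, m) { ${wrapper_registrations_K} }.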
def wrapper_registrations(used_keys: Set[str]) -> str:
library_impl_macro_list: List[str] = []
for key in sorted(used_keys):
dispatch_key = key
if key == "Default":
dispatch_key = "Autograd"
library_impl_macro = (
f"TORCH_LIBRARY_IMPL(aten, {dispatch_key}, m) "
+ "{\n"
+ "${"
+ f"wrapper_registrations_{key}"
+ "}\n}"
)
library_impl_macro_list += [library_impl_macro]
return "\n\n".join(library_impl_macro_list)
# Generate a new template from VariableType.cpp which replaces ${wrapper_registrations}
# with per key TORCH_LIBRARY_IMPL macros for each key that appears in derivatives.yaml
fm1 = FileManager(
install_dir=out + "/templates", template_dir=template_path, dry_run=False
)
fm1.write(
"VariableType.cpp",
lambda: {
"type_derived_method_definitions": "\n\n".join(
[
"${" + f"type_derived_method_definitions_{key}" + "}"
for key in sorted(used_keys)
]
),
"wrapper_registrations": wrapper_registrations(used_keys),
},
)
# Generate final VariableType_*.cpp files from the generated template
fm2 = FileManager(install_dir=out, template_dir=out + "/templates", dry_run=False)
sharded_keys = set(
[f"type_derived_method_definitions_{key}" for key in sorted(used_keys)]
+ [f"wrapper_registrations_{key}" for key in sorted(used_keys)]
)
# NOTE: see Note [Sharded File] at the top of the VariableType.cpp
# template regarding sharding of the generated files.
fm2.write_sharded(
"VariableType.cpp",
[fn for fn in fns_with_diff_infos if use_derived(fn)],
key_fn=lambda fn: cpp.name(fn.func.func),
base_env={
"generated_comment": "@"
+ f"generated from {fm.template_dir_for_comments()}/VariableType.cpp",
},
env_callable=gen_variable_type_func,
num_shards=5,
sharded_keys=sharded_keys,
)
@with_native_function_and
def gen_wrapper_registration(f: NativeFunction, key: str = "Default") -> str:
return WRAPPER_REGISTRATION.substitute(
unqual_operator_name_with_overload=f.func.name,
type_wrapper_name=type_wrapper_name(f, key),
class_type="VariableType",
)
def gen_variable_type_func(
fn: NativeFunctionWithDifferentiabilityInfo,
) -> Dict[str, List[str]]:
f = fn.func
result = {}
with native_function_manager(f):
name = cpp.name(f.func)
formals = gen_formals(f)
if (
fn.info is None
and str(f.func.name.name) not in RESET_GRAD_ACCUMULATOR
and get_base_name(f) not in DONT_REQUIRE_DERIVATIVE
and len(gen_differentiable_outputs(fn)) > 0
and cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE
and type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT
and type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT
):
# NOTE: [ Registering AutogradNotImplemented boxed kernel ]
#
# When there is no derivatives.yaml entry, we register a generic boxed
# NotImplemented kernel to set grad_fn to be NotImplemented, so that forward
# proceeds as usual but an error is properly produced on backward.
# TODO: it would be nice to not have these special cases
#
# There are several cases where we still let codegen handle it:
# 1) ops that need to reset the grad accumulator (we let codegen handle this case
# because the list is (currently) only accessible in Python).
# 2) User explicitly specifies DONT_REQUIRE_DERIVATIVE. This basically makes
# autograd a fallthrough with NDEBUG checks. This can be useful for when all
# outputs are integral.
# 3) When there are no differentiable outputs. This is similar to (2).
# 4) There are certain ops where we skip certain NDEBUG checks. This is similar
# to (1).
type_definition = ""
wrapper_registration = AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION.substitute(
unqual_operator_name_with_overload=f.func.name
)
result["type_derived_method_definitions_Default"] = [type_definition]
result["wrapper_registrations_Default"] = [wrapper_registration]
else:
if not fn.info:
key = "Default"
type_definition = METHOD_DEFINITION.substitute(
return_type=cpp.returns_type(
f.func.returns, symint=True
).cpp_type(),
type_wrapper_name=type_wrapper_name(f, key),
type_definition_body=emit_body(fn, key),
formals=formals,
)
wrapper_registration = gen_wrapper_registration(f, key)
result[f"type_derived_method_definitions_{key}"] = [type_definition]
result[f"wrapper_registrations_{key}"] = [wrapper_registration]
else:
for key in fn.info.keys():
type_definition = METHOD_DEFINITION.substitute(
return_type=cpp.returns_type(
f.func.returns, symint=True
).cpp_type(),
type_wrapper_name=type_wrapper_name(f, key),
type_definition_body=emit_body(fn, key),
formals=formals,
)
wrapper_registration = gen_wrapper_registration(f, key)
result[f"type_derived_method_definitions_{key}"] = [type_definition]
result[f"wrapper_registrations_{key}"] = [wrapper_registration]
# See Note [Manual Backend kernels]
assert (name in MANUAL_BACKEND) == f.manual_kernel_registration
# If you want to register a kernel to Autograd, you must make the op abstract.
# In other words, this op must have a dispatch section in native_functions.yaml.
if name in MANUAL_AUTOGRAD_AND_TRACER or (
fn.info and any(info.has_derivatives for info in fn.info.values())
):
msg = (
f"There's a formula for {name}(or its functional variant) in derivatives.yaml. "
f"It's required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA "
f"or CompositeExplicitAutograd in native_functions.yaml. Please see "
f"https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword "
f"for instructions to choose the right dispatch keyword."
)
assert f.is_abstract, msg
return result
_foreach_ops_without_differentiability_info = {
# No reference backward available due to the lack of `{maximum, minimum}(tensor, scalar)`.
("_foreach_maximum", "Scalar"),
("_foreach_maximum", "ScalarList"),
("_foreach_minimum", "Scalar"),
("_foreach_minimum", "ScalarList"),
# No reference backward available as addcdiv/addcmul don't support Tensor as scaling factor.
("_foreach_addcdiv", "Tensor"),
("_foreach_addcmul", "Tensor"),
("_foreach_copy", ""),
}
_foreach_ops_with_different_arity = {
# These ops lack the `alpha` scaling factor that is applied to the right-hand-side argument.
("_foreach_add", "Scalar"),
("_foreach_add", "ScalarList"),
("_foreach_sub", "Scalar"),
("_foreach_sub", "ScalarList"),
}
@with_native_function_with_differentiability_info_and_key
def emit_body(
fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> List[str]:
assert dispatch_strategy(fn) == "use_derived"
f = fn.func
info = fn.info[key] if fn.info else None
fw_derivatives = fn.fw_derivatives.get(key, []) if fn.fw_derivatives else []
name = cpp.name(f.func)
inplace = f.func.kind() == SchemaKind.inplace
is_out_fn = f.func.kind() == SchemaKind.out
returns_void = len(f.func.returns) == 0
base_name = get_base_name(f)
view_info = get_view_info(f)
is_foreach = name.startswith("_foreach")
is_inplace_foreach = is_foreach and inplace
if is_inplace_foreach:
inplace_foreacharg2refarg: Dict[Argument, Argument] = {}
refargname2inplace_foreacharg: Dict[str, Argument] = {}
base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name)
if info is None:
assert (
base_name_and_overload_name
in _foreach_ops_without_differentiability_info
), f"{'.'.join(base_name_and_overload_name)} should have a differentiability info"
else:
assert (
len(f.func.arguments.flat_non_out)
== len(info.func.func.arguments.flat_non_out)
) or (base_name_and_overload_name in _foreach_ops_with_different_arity), (
f"{'.'.join(base_name_and_overload_name)} has {len(f.func.arguments.flat_non_out)} args "
f"but the reference has {len(info.func.func.arguments.flat_non_out)}"
)
for foreach_arg, ref_arg in zip(
f.func.arguments.flat_non_out, info.func.func.arguments.flat_non_out
):
foreach_arg_type = foreach_arg.type
if isinstance(foreach_arg_type, ListType):
foreach_arg_type = foreach_arg_type.elem
assert foreach_arg_type == ref_arg.type
inplace_foreacharg2refarg[foreach_arg] = ref_arg
refargname2inplace_foreacharg[ref_arg.name] = foreach_arg
def gen_differentiable_input(
arg: Union[Argument, SelfArgument, TensorOptionsArguments]
) -> Optional[DifferentiableInput]:
if isinstance(arg, TensorOptionsArguments):
return None
a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg
# TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove.
# NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are
# not handled properly as they are irrelevant for this codegen.
cpp_type = cpp.argument_type(a, binds=a.name, symint=True).cpp_type()
if not is_differentiable(a.name, a.type, info):
return None
return DifferentiableInput(
name=a.name,
type=a.type,
cpp_type=cpp_type,
)
@with_native_function
def gen_differentiable_inputs(f: NativeFunction) -> List[DifferentiableInput]:
arguments = list(f.func.arguments.non_out)
if is_inplace_foreach and info is not None:
for i, arg in enumerate(f.func.arguments.flat_non_out):
if arg in inplace_foreacharg2refarg:
# note(crcrpar): From what I understand, what matters is only the name.
# Thus, originally, I replaced the argument only when the names are different.
# TODO(crcrpar): Make it simpler.
mapped_arg = inplace_foreacharg2refarg[arg]
arguments[i] = Argument(
mapped_arg.name,
mapped_arg.type,
mapped_arg.default,
mapped_arg.annotation,
)
return list(mapMaybe(gen_differentiable_input, arguments))
def find_args_with_derivatives(
differentiable_inputs: List[DifferentiableInput],
) -> List[DifferentiableInput]:
"""Find arguments that have derivative definitions"""
if info is None or not info.has_derivatives:
return differentiable_inputs
names = {name for d in info.derivatives for name in d.var_names}
differentiable = [arg for arg in differentiable_inputs if arg.name in names]
if len(differentiable) != len(names):
missing = names - {arg.name for arg in differentiable}
raise RuntimeError(
f"Missing arguments for derivatives: {missing} in {info.name}"
)
return differentiable
differentiable_inputs = gen_differentiable_inputs(f)
args_with_derivatives = find_args_with_derivatives(differentiable_inputs)
differentiable_outputs = gen_differentiable_outputs(fn, key)
undifferentiable = (base_name in DONT_REQUIRE_DERIVATIVE) or (
name in DONT_REQUIRE_DERIVATIVE
)
requires_derivative = (
(not undifferentiable)
and (len(differentiable_inputs) > 0)
and (
(len(differentiable_outputs) > 0)
# note(crcrpar): In-place foreach functions are void functions.
or is_inplace_foreach
)
)
if (
info is not None
and info.has_derivatives
and not requires_derivative
# out= ops are allowed to have zero returns, which causes requires_derivative to be False;
# we shouldn't error out though (out= ops for autograd just redispatch)
and len(f.func.returns) > 0
):
raise RuntimeError(
f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative"
)
# note(crcrpar): In-place foreach functions do not support forward AD
if requires_derivative and len(fw_derivatives) > 0 and not is_inplace_foreach:
assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len(
differentiable_outputs
), (
"Expected the number of forward derivatives implemented to match the "
"number of differentiable outputs. NB: This only applies when at least "
"one forward derivative is implemented. Not implementing any forward "
"derivatives is also okay, and we would require inputs to the op to "
"not have associated tangents in that case."
)
try_jit_decomposition = (
requires_derivative
and len(fw_derivatives) == 0
and (not modifies_arguments(f))
and (not returns_void)
)
def emit_save_inputs() -> List[str]:
setup: List[str] = []
if info is None or not info.has_derivatives:
return setup
has_tensorlist_arg = any(
is_tensor_list_type(arg.type) for arg in args_with_derivatives
)
# We don't want to save tensors if we know that they will never be used
# when computing the derivative, so we add guards to those statements
def guard_for(arg: SavedAttribute) -> Optional[str]:
assert info is not None
# It's hard to determine the edge offset if we have TensorLists
# NOTE(crcrpar): in-place foreach functions' arguments include tensorlist
# but their derivatives don't use it, so let them bypass this check.
if has_tensorlist_arg and (not is_inplace_foreach):
return None
# Empirical evaluation of the cases where we insert those guards in
# backward shows that they are somewhat useless. E.g. there's no need
# to guard on some values captured from forward, because they had to
# require_grad if the backward function even gets executed. I don't
# have any good ideas for detecting those cases, so I simply disabled the
# checks.
if "backward" in info.name:
return None
# If there's a single derivative we could compute, we already have
# a requires_grad check that is sufficient
if len(args_with_derivatives) <= 1:
return None
# We really only care about trimming down the number of tensors we save
if arg.nctype.type != BaseCType(tensorT):
return None
# We want to emit simple guards, so we only allow that if checking one
# input is enough to determine whether we need that value
used_in = [d for d in info.derivatives if arg in d.saved_inputs]
assert len(used_in) > 0
if len(used_in) != 1:
return None
derivative = used_in[0]
# Case with multioutput formulas
# TODO: process all derivative formulas!!!
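# Illustrative example (names are hypothetical, not taken from derivatives.yaml): a formula
# containing `wrap_opt_if(weight, grad_input_mask[1] || grad_input_mask[2])` yields the guard
# `grad_fn->should_compute_output(1) || grad_fn->should_compute_output(2)` for the saved
# `weight`, so it is only stashed when one of those outputs is actually needed.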
if len(derivative.var_names) != 1:
wrap_opt_if_start = derivative.formula.find(
f"wrap_opt_if({arg.nctype.name}"
)
if wrap_opt_if_start == -1:
return None
wrap_opt_if_match = re.match(
rf"wrap_opt_if\({arg.nctype.name},(.*?)\)",
derivative.formula[wrap_opt_if_start:],
)
assert wrap_opt_if_match is not None
# Condition is between 'wrap_opt_if(var_name,' and ')'.
condition_slice = slice(len(rf"wrap_opt_if\({arg.nctype.name},"), -1)
wrap_opt_if_condition = wrap_opt_if_match.group(0)[
condition_slice
].strip()
# replace 'grad_input_mask[num]' with 'grad_fn->should_compute_output(num)'
wrap_opt_if_condition = re.sub(
r"grad_input_mask\[(\d+)\]",
r"grad_fn->should_compute_output(\1)",
wrap_opt_if_condition,
)
return f"{wrap_opt_if_condition}"
# Figure out the offset of the edge that uses this variable
derivative_var_name = derivative.var_names[0]
for edge_off, a in enumerate(args_with_derivatives):
if a.name == derivative_var_name:
break
else:
raise AssertionError()
return f"grad_fn->should_compute_output({edge_off})"
if is_inplace_foreach:
save_input_stmts = save_variables(info.all_saved_inputs, False, guard_for)
if save_input_stmts:
setup.append(
LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute(
preamble="", statements=save_input_stmts
)
)
else:
setup.extend(save_variables(info.all_saved_inputs, False, guard_for))
for arg in args_with_derivatives:
if is_tensor_list_type(arg.type):
setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();")
return setup
def setup_derivative(differentiable_inputs: List[DifferentiableInput]) -> List[str]:
body: List[str] = []
if is_out_fn:
# For out functions, ensure that no input or output requires grad
body.append(DECLARE_GRAD_FN.substitute(op="Node"))
body.append(
SETUP_NONE_REQUIRES_GRAD.substitute(
base_name=base_name,
args_to_check=[arg.name for arg in differentiable_inputs],
)
)
body.append(
SETUP_NONE_REQUIRES_GRAD.substitute(
base_name=base_name,
args_to_check=[arg.name for arg in differentiable_outputs],
)
)
return body
op = info.op if info is not None and info.has_derivatives else "NotImplemented"
setup = []
if not is_inplace_foreach:
setup.extend(
ASSIGN_GRAD_FN.substitute(
op=op,
op_ctor=""
if info is not None and info.has_derivatives
else f'"{cpp.name(f.func)}"',
args_with_derivatives=[arg.name for arg in args_with_derivatives],
).split("\n")
)
else:
# note(crcrpar): Assuming in-place foreach function's self_arg is always TensorList.
list_like_arg = "self"
args = [arg.name for arg in args_with_derivatives]
for i, arg in enumerate(args):
if is_inplace_foreach and info is not None:
if arg in refargname2inplace_foreacharg:
foreach_arg = refargname2inplace_foreacharg[arg]
args[i] = foreach_arg.name + (
"[i]" if isinstance(foreach_arg.type, ListType) else ""
)
else:
if arg == list_like_arg:
args[i] = arg + "[i]"
setup.extend(
ASSIGN_VECTOR_OF_GRAD_FN.substitute(
op=op,
op_ctor=""
if info is not None and info.has_derivatives
else f'"{cpp.name(f.func)}"',
args_with_derivatives=args,
irange=f"{list_like_arg}.size()",
).split("\n")
)
setup.extend(emit_save_inputs())
body.extend(
emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives)
)
declare_grad_fn_template = (
DECLARE_GRAD_FN if not is_inplace_foreach else DECLARE_VECTOR_OF_GRAD_FN
)
body.append(declare_grad_fn_template.substitute(op=op))
body.append(SETUP_DERIVATIVE.substitute(setup=setup))
return body
def emit_check_if_in_complex_autograd_allowlist() -> List[str]:
body: List[str] = []
if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
return body
for arg in differentiable_outputs:
name = arg.name
# TODO: should be `arg.type.is_tensor_like()`?
if arg.cpp_type == "at::Tensor" or arg.cpp_type in TENSOR_LIST_LIKE_CTYPES:
body.append(f'throw_error_for_complex_autograd({name}, "{base_name}");')
return body
def emit_check_no_requires_grad(
tensor_args: List[DifferentiableInput],
args_with_derivatives: List[DifferentiableInput],
) -> List[str]:
"""Checks that arguments without derivatives don't require grad"""
body: List[str] = []
for arg in tensor_args:
if arg in args_with_derivatives:
continue
arg_name = arg.name
if info and arg_name in info.non_differentiable_arg_names:
continue
if arg_name == "output":
# Double-backwards definitions sometimes take in 'input' and
# 'output', but only define the derivative for input.
continue
body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");')
return body
def emit_original_self_definition() -> List[str]:
body: List[str] = []
if inplace:
if is_inplace_foreach:
body.append(
"std::vector<c10::optional<at::Tensor>> original_selfs(self.size());"
)
else:
body.append("c10::optional<at::Tensor> original_self;")
all_forward_grad_cond = []
for derivative in fw_derivatives:
if derivative.required_original_self_value:
all_forward_grad_cond.append(
get_any_has_forward_grad_name(derivative.var_names)
)
if all_forward_grad_cond:
if not is_inplace_foreach:
body.append(f'if ({" || ".join(all_forward_grad_cond)}) {{')
body.append(" original_self = self.clone();")
body.append("}")
else:
current_all_forward_grad_cond = [
f"{cond}[i]" for cond in all_forward_grad_cond
]
body.append("for (const auto& i : c10::irange(self.size())) {")
body.append(
f" if ({' || '.join(current_all_forward_grad_cond)}) {{"
)
body.append(" original_selfs[i] = self[i].clone();")
body.append(" }")
body.append("}")
return body
def save_variables(
saved_variables: Sequence[SavedAttribute],
is_output: bool,
guard_for: Callable[[SavedAttribute], Optional[str]] = lambda name: None,
) -> Sequence[str]:
# assign the saved variables to the generated grad_fn
stmts: List[str] = []
for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)):
name = (
arg.nctype.name.name
if isinstance(arg.nctype.name, SpecialArgName)
else arg.nctype.name
)
foreacharg: Optional[Argument] = None
is_foreacharg_list_type: bool = False
type = arg.nctype.type
expr = arg.expr
stmts_prepend = None
if is_inplace_foreach and info is not None:
# todo(crcrpar): See if we can add some check e.g. `assert foreacharg is not None`.
# for now the example assert would fail.
name_to_query = name.split("_scalar_type")[0]
if name_to_query in refargname2inplace_foreacharg:
foreacharg = refargname2inplace_foreacharg[name_to_query]
is_foreacharg_list_type = isinstance(foreacharg.type, ListType)
if foreacharg is not None:
name_in_expr = (
f"{foreacharg.name}{'[i]' if is_foreacharg_list_type else ''}"
)
src_name = name
if "_scalar_type" in src_name:
split_src_name = src_name.split("_scalar_type")
assert len(split_src_name) == 2
src_name = split_src_name[0]
expr = expr.replace(src_name, name_in_expr)
if (
type == BaseCType(tensorT)
or type == OptionalCType(BaseCType(tensorT))
or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
or (is_output and type == BaseCType(scalarT))
):
# note(crcrpar): Here `expr` is generated from scratch, `arg.expr` is ignored.
var = name
name += "_"
if var == "self" and inplace:
original_self_var = (
"original_self"
if not is_inplace_foreach
else "original_selfs[i]"
)
self_var = var if not is_inplace_foreach else var + "[i]"
stmts_prepend = f"if (!{original_self_var}.has_value()) {original_self_var} = {self_var}.clone()"
var = f"{original_self_var}.value()"
assert not is_output
if inplace and is_output:
assert name == "result_"
var = (
"self[i]"
if is_inplace_foreach or is_foreacharg_list_type
else "self"
)
is_inplace_view = f"{var}.is_view()"
expr = f"SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})"
else:
expr = f"SavedVariable({var}, {str(is_output).lower()})"
if foreacharg is not None and "original_selfs" not in expr:
expr = expr.replace(src_name, name_in_expr)
elif (
type == BaseCType(tensorListT)
or type == ListCType(OptionalCType(BaseCType(tensorT)))
or type == BaseCType(iTensorListRefT)
or type == VectorCType(BaseCType(tensorT))
):
# See Note [nuanced return type of out-of-place foreach functions]
if type == VectorCType(BaseCType(tensorT)):
assert is_foreach and is_output
expr = f"make_saved_variable_list({name}, {str(is_foreach and is_output).lower()})"
name += "_"
elif type == BaseCType(intArrayRefT):
expr = expr + ".vec()"
elif type == BaseCType(symIntArrayRefT):
expr = expr + ".vec()"
elif type == BaseCType(stringT):
expr = f"std::string({expr})"
elif type == OptionalCType(BaseCType(stringT)):
expr = f"{expr}.has_value() ? c10::optional<std::string>(std::string({expr}.value())) : c10::nullopt"
elif type == ArrayRefCType(
elem=BaseCType(type=BaseCppType(ns="at", name="Scalar"))
):
expr = expr + ".vec()"
guard = guard_for(arg)
if guard is None:
if stmts_prepend:
stmts.append(f"{stmts_prepend};")
stmts.append(f"grad_fn->{name} = {expr};")
else:
stmts.append(f"if ({guard}) {{")
if stmts_prepend:
stmts.append(f" {stmts_prepend};")
stmts.append(f" grad_fn->{name} = {expr};")
stmts.append("}")
return stmts
# Generates a Dispatcher::redispatch() call into the dispatcher. We do this mainly for performance reasons:
# - Pre-compute the full DispatchKeySet. This saves the dispatcher from having to read from TLS.
# - redispatch() avoids a redundant call to RecordFunction, which was already called right before
# we entered this autograd kernel.
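# For example (illustrative; names simplified), a binary op would emit roughly
#   at::redispatch::mul(ks & c10::after_autograd_keyset, self_, other_);
# where `ks` is the DispatchKeySet plumbed into this kernel and `self_`/`other_` are the
# unpacked arguments.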
def emit_dispatch_call(
f: NativeFunction, input_base: str, unpacked_args: Sequence[str]
) -> str:
"""Dispatch call via function in a namespace or method on Tensor."""
dispatcher_sig = DispatcherSignature.from_schema(f.func)
dispatcher_exprs = dispatcher_sig.exprs()
# code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
# Ops also always have a function variant of the redispatch API.
# See Note [Plumbing Keys Through The Dispatcher] for details.
dispatch_key_set = "ks & c10::after_autograd_keyset"
call = CALL_REDISPATCH.substitute(
api_name=cpp.name(
f.func,
faithful_name_for_out_overloads=True,
symint_overload=f.func.has_symint(),
),
unpacked_args=[dispatch_key_set] + list(unpacked_args),
)
return call
def wrap_output(
f: NativeFunction, unpacked_bindings: List[Binding], var: str
) -> str:
call = ""
rhs_value: Optional[str] = None
if not any(r.type.is_tensor_like() for r in f.func.returns):
rhs_value = var
else:
rhs_value = f"std::move({var})"
assert rhs_value is not None
call += ASSIGN_RETURN_VALUE.substitute(
return_values=tie_return_values(f), rhs_value=rhs_value
)
return call
def check_tensorimpl_and_storage(
call: str, unpacked_bindings: List[Binding]
) -> str:
# See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
stmts_before_call: List[str] = []
stmts_after_call: List[str] = []
if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
return call
# Check properties of inputs (enforce (1))
for unpacked_binding in unpacked_bindings:
arg = unpacked_binding.name
noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
if noref_cpp_type == BaseCType(tensorListT) or noref_cpp_type == BaseCType(
iTensorListRefT
):
stmts_before_call += [
SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg),
]
stmts_after_call += [
ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg),
]
elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
stmts_before_call += [
SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg),
]
stmts_after_call += [
ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(
tensorlist_name=arg
),
ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(
tensorlist_name=arg
),
]
elif noref_cpp_type == BaseCType(tensorT):
stmts_before_call += [
SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
SAVE_TENSOR_IMPL.substitute(tensor_name=arg),
]
stmts_after_call += [
ENFORCE_SAME_TENSOR_STORAGE.substitute(
tensor_name=arg, out_tensor_name=arg
),
ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg),
]
assert (stmts_before_call and stmts_after_call) or (
not stmts_before_call and not stmts_after_call
)
# Check properties of outputs (enforce (2), (3))
if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out):
base_name = f.func.name.name.base # TODO: should be str(f.func.name.name)?
aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
if aliased_arg_name is not None:
aliased_arg_name = unpacked_name(aliased_arg_name)
for i, (ret, ret_name) in enumerate(
zip(f.func.returns, cpp.return_names(f))
):
noref_cpp_type = cpp.return_type(ret, symint=True).remove_const_ref()
if noref_cpp_type == BaseCType(tensorT):
if aliased_arg_name is not None:
assert (
i == 0
), "Expect non-CompositeImplicitAutograd view function {base} to return single output"
stmts_after_call += [
ENFORCE_SAME_TENSOR_STORAGE.substitute(
tensor_name=aliased_arg_name, out_tensor_name=ret_name
)
]
else:
if (
type_wrapper_name(f)
not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT
):
stmts_after_call += [
ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute(
tensor_name=ret_name, fn_name=type_wrapper_name(f)
)
]
if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
stmts_after_call += [
ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute(
tensor_name=ret_name, fn_name=type_wrapper_name(f)
)
]
# Currently we don't have any functions that return the following types, but
# we should update the checks once we do
elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
raise AssertionError(
f"Please add use_count checks for {noref_cpp_type}"
)
elif noref_cpp_type == BaseCType(tensorListT):
raise AssertionError(
f"Please add use_count checks for {noref_cpp_type}"
)
if stmts_before_call and stmts_after_call:
call = (
RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call)
+ call
+ RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call)
)
return call
def emit_call(
f: NativeFunction, unpacked_bindings: List[Binding], try_jit_decomposition: bool
) -> str:
# We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
# (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
# the baseType operations still dispatch to non-Variable type, even if the arguments passed
# in are now Variables.
# See NOTE [ Treating Variables as non-Variables in type dispatch ] for details.
unpacked_args = [b.name for b in unpacked_bindings]
base_type_call = emit_dispatch_call(f, "self_", unpacked_args)
if get_view_info(f) is not None or modifies_arguments(f):
guard = "at::AutoDispatchBelowAutograd guard;"
else:
guard = "at::AutoDispatchBelowADInplaceOrView guard;"
any_has_forward_grad = (
get_any_has_fw_grad_cond(derivative=None)
if requires_derivative
else "false"
)
return_types = ", ".join(
[cpp.return_type(a, symint=True).cpp_type() for a in f.func.returns]
)
if len(f.func.returns) > 1:
return_types = f"std::tuple<{return_types}>"
arg_names = [
a.name
for a in cpp.arguments(
f.func.arguments,
faithful=True,
symint=True,
method=False,
cpp_no_default_args=set(),
)
]
if not modifies_arguments(f) and not returns_void:
if try_jit_decomposition:
call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP.substitute(
base_type_call=base_type_call,
tmp_var=TMP_VAR,
guard=guard,
any_has_forward_grad=any_has_forward_grad,
op_name=cpp.name(f.func),
op_overload=f.func.name.overload_name,
return_types=return_types,
arg_names=arg_names,
)
else:
call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute(
base_type_call=base_type_call,
tmp_var=TMP_VAR,
guard=guard,
)
call += wrap_output(f, unpacked_bindings, TMP_VAR)
else:
assert not try_jit_decomposition
call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
base_type_call=base_type_call, guard=guard
)
call = check_tensorimpl_and_storage(call, unpacked_bindings)
return call
def emit_history() -> str:
fn = "rebase" if modifies_arguments(f) and view_info is None else "set"
output_names = [r.name for r in differentiable_outputs]
# TODO: flatten allocates a std::vector, which could be expensive
outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(
outs=output_names if not is_inplace_foreach else "self"
)
if not is_inplace_foreach:
return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)
else:
return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute(
preamble=(
f"auto differentiable_outputs = {outs};\n"
f"TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());"
),
statements=f"{fn}_history(differentiable_outputs[i], grad_fns[i]);",
)
def emit_save_outputs() -> str:
if is_out_fn:
# out functions don't currently support differentiation
return ""
if info is not None and info.has_derivatives:
stmts = save_variables(info.all_saved_outputs, True)
if len(stmts) == 0:
return ""
if not is_inplace_foreach:
return CONDITIONAL.substitute(cond="grad_fn", statements=stmts)
else:
return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute(
preamble="", statements=stmts
)
return ""
def emit_any_requires_grad() -> List[str]:
extra_condition = ""
if info and info.output_differentiability_conditions:
assert len(info.output_differentiability_conditions) == 1
extra_condition = f"_any_requires_grad &= ({info.output_differentiability_conditions[0]});"
names_of_args_with_derivatives = [arg.name for arg in args_with_derivatives]
if is_inplace_foreach and info is not None:
for i, arg in enumerate(names_of_args_with_derivatives):
for f_arg, r_arg in inplace_foreacharg2refarg.items():
if arg == r_arg.name:
names_of_args_with_derivatives[i] = f_arg.name
return [
SETUP_ANY_REQUIRES_GRAD.substitute(
args_with_derivatives=names_of_args_with_derivatives,
extra_differentiability_conditions=extra_condition,
)
]
def get_any_has_forward_grad_name(var_names: Tuple[str, ...]) -> str:
if len(var_names) == 1:
return f"_any_has_forward_grad_{var_names[0]}"
else:
return f'_any_has_forward_grad_{"_".join(var_names)}'
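# e.g. (illustrative) ("result",) -> "_any_has_forward_grad_result",
# ("L", "U") -> "_any_has_forward_grad_L_U"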
def emit_any_has_forward_grad() -> List[str]:
content: List[str] = []
if not is_foreach:
for derivative in fw_derivatives:
requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
if info and info.output_differentiability_conditions:
assert len(info.output_differentiability_conditions) == 1
requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && {requires_fw_grad}"
content.append(
f"[[maybe_unused]] auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};"
)
else:
for derivative in fw_derivatives:
bool_vector_name = get_any_has_forward_grad_name(derivative.var_names)
cur_derivative_conditions = [
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(
req_inp=(
inp.name
if not inplace
else refargname2inplace_foreacharg[inp.name].name
)
+ (
"[i]"
if is_tensor_list_type(
inp.type
if not inplace
else refargname2inplace_foreacharg[inp.name].type
)
else ""
),
)
for inp in differentiable_inputs
if derivative.required_inputs_fw_grad is not None
and inp.name in derivative.required_inputs_fw_grad
]
content.append(f"std::vector<bool> {bool_vector_name}(self.size());")
content.append("for (const auto& i : c10::irange(self.size())) {")
content.append(
f" {bool_vector_name}[i] = {' || '.join(cur_derivative_conditions)};"
)
content.append("}")
return content
def emit_check_inplace() -> List[str]:
if not inplace:
return []
return [
f"check_inplace({arg.name}, _any_requires_grad);"
for arg in differentiable_outputs
]
def emit_fw_derivatives() -> List[str]:
content: List[str] = []
fw_grad_setters: List[str] = []
for derivative in fw_derivatives:
res = derivative.var_names
if f.func.name.name.inplace:
assert (
len(res) == 1
), "Expected number of outputs to be 1 if function is inplace"
# TODO update this when inplace namings are unified
res = ("self",)
assert derivative.required_inputs_fw_grad is not None
unpacked_arguments = ""
for inp in differentiable_inputs:
inp_name = inp.name
is_input_tensorlist = is_foreach and is_tensor_list_type(
inp.type
if not inplace
else refargname2inplace_foreacharg[inp.name].type
)
input_suffix = "[i]" if is_input_tensorlist else ""
if is_inplace_foreach:
if inp.name in refargname2inplace_foreacharg:
inp_name = refargname2inplace_foreacharg[inp.name].name
zeros_fn = (
"zeros"
if inplace and inp.name == "self"
else "_efficientzerotensor"
)
if inp.name in derivative.required_inputs_fw_grad:
unpacked_arguments += (
FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
inp_name=inp.name,
inp=inp_name + input_suffix,
zeros_fn=zeros_fn,
)
)
if inp.name in (derivative.required_inputs_primal or []):
unpacked_arguments += (
FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
inp_name=inp.name,
inp=inp_name + input_suffix,
)
)
if derivative.required_original_self_value:
input_suffix = "s[i]" if is_inplace_foreach else ""
unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
inp_name="original_self",
inp="original_self" + input_suffix,
zeros_fn=zeros_fn,
)
unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
inp_name="original_self",
inp="original_self" + input_suffix,
)
elif inplace and derivative.is_reusing_outplace_formula:
# The gradient wasn't already cloned, do it if grad mode is enabled
unpacked_arguments += (
"self_t = GradMode::is_enabled() ? self_t.clone() : self_t;"
)
if inplace:
is_inplace_str = "true"
else:
is_inplace_str = "false"
requires_fw_grad = get_any_has_forward_grad_name(derivative.var_names)
if all(
(isinstance(var_type, BaseType) and var_type.is_tensor_like())
for var_type in derivative.var_types
):
# Is there a way to get from BaseType to BaseCType
if len(derivative.var_types) == 1:
opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
if not is_foreach:
fw_grad_setters.append(
FW_DERIVATIVE_SETTER_TENSOR.substitute(
out_arg=res[0], is_inplace=is_inplace_str
)
)
else:
assert res[0] == ("result" if not inplace else "self")
fw_grad_setters.append(
FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
out_arg=res[0], is_inplace=is_inplace_str
)
)
requires_fw_grad += f" && ({derivative.var_names[0]}.defined())"
else:
tuple_type = TupleCType(
[BaseCType(tensorT)] * len(derivative.var_types)
)
opt_res_grad_type = OptionalCType(tuple_type).cpp_type()
for idx, single_res in enumerate(res):
fw_grad_setters.append(
FW_DERIVATIVE_SETTER_MULTI_OUTPUT.substitute(
idx=idx, all_res="_".join(res), out_arg=single_res
)
)
elif (
isinstance(derivative.var_types[0], ListType)
and derivative.var_types[0].is_tensor_like()
):
assert (
len(derivative.var_types) == 1
), "Expected number of outputs to be 1 if function returns ListType"
if not is_foreach:
opt_res_grad_type = OptionalCType(
VectorCType(BaseCType(tensorT))
).cpp_type()
fw_grad_setters.append(
FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(
out_arg=res[0], is_inplace=is_inplace_str
)
)
else:
# TODO(crcrpar): Should this (= the foreach specific logic) be refactored somehow?
# Only out-place foreach functions that have entries in `tools/autograd/derivatives.yaml`
# can reach here.
opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
fw_grad_setters.append(
FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
out_arg=res[0], is_inplace=is_inplace_str
)
)
else:
raise RuntimeError("Unsupported output type for forward derivative")
if not is_foreach:
fw_grad_opt_definition = f"{opt_res_grad_type} {'_'.join(res)}_new_fw_grad_opt = c10::nullopt;"
# View ops create a fw_grad that is already a view of the base's fw_grad, so just use that
content.append(
FW_DERIVATIVE_TEMPLATE.substitute(
fw_grad_opt_definition=fw_grad_opt_definition,
requires_fw_grad=requires_fw_grad,
formula=derivative.formula,
out_arg="_".join(res),
unpacked_arguments=unpacked_arguments,
)
)
else:
# note(crcrpar): Assuming `self` is TensorList.
fw_grad_opt_definition = (
f"std::vector<{opt_res_grad_type}> {'_'.join(res)}_new_fw_grad_opts"
"(self.size(), c10::nullopt);"
)
foreach_forward_grad_formula = derivative.formula
_foreach_arg: Union[Argument, DifferentiableInput]
if inplace:
for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
# note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
if not (
is_tensor_type(_foreach_arg.type)
or is_tensor_list_type(_foreach_arg.type)
):
pattern = _foreach_arg.name
if isinstance(_foreach_arg.type, ListType):
pattern += "[i]"
foreach_forward_grad_formula = (
foreach_forward_grad_formula.replace(
_ref_arg.name, pattern
)
)
else:
if (
"result" in foreach_forward_grad_formula
and "result[i]" not in foreach_forward_grad_formula
):
foreach_forward_grad_formula = (
foreach_forward_grad_formula.replace("result", "result[i]")
)
content.append(
FW_DERIVATIVE_FOREACH_TEMPLATE.substitute(
fw_grad_opt_definition=fw_grad_opt_definition,
vector_of_optional_tensor=f"{'_'.join(res)}_new_fw_grad_opts",
any_has_forward_grad_for_current_index=" || ".join(
get_any_has_forward_grad_name(derivative.var_names) + "[i]"
for derivative in fw_derivatives
),
formula=foreach_forward_grad_formula,
unpacked_arguments=unpacked_arguments,
)
)
# Set all the grads at the end to avoid: https://github.com/pytorch/pytorch/issues/67367
content.append("\n".join(fw_grad_setters))
return content
def get_any_has_fw_grad_cond(derivative: Optional[ForwardDerivative]) -> str:
#
# Produces a condition string (e.g., "isFwGradDefined(grad_output) || isFwGradDefined(output)")
#
if derivative is None:
# (1) If a derivative is NOT provided, cond will check fw_grad of ALL differentiable inputs
# - Used in the out_fn case when we want to forbid fw derivatives
# - Used in the case where the fw_derivative is not defined, but we want
# to check if there is a decomposition registered for jvp
to_check: List[str] = []
for inp in list(
mapMaybe(
gen_differentiable_input,
f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator]
)
):
if is_tensor_type(inp.type):
to_check.append(
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
)
elif is_tensor_list_type(inp.type):
to_check.append(
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE.substitute(
req_inp=inp.name
)
)
else:
raise RuntimeError(
f'Unsupported input type for "{name}" when forbidding forward AD usage.'
)
return f'({" || ".join(to_check)})'
else:
# (2) If derivative is provided, use that information to determine which inputs
# to check fw_grad for
assert derivative.required_inputs_fw_grad is not None
if len(derivative.required_inputs_fw_grad) == 0:
# Handle functions like stack
# For these, we don't unpack anything and always call the user function
if not (
len(differentiable_inputs) == 1
and is_tensor_list_type(differentiable_inputs[0].type)
):
raise RuntimeError(
f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
"forward AD formula does not use any input tangent) even though a forward gradient "
"formula has been defined for it. This case should only happen for function that "
"take a single TensorList as input. All other cases are not supported right now."
)
any_has_fw_grad = "true"
else:
any_has_fw_grad = " || ".join(
[
(
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE
if is_tensor_list_type(inp.type)
else FW_DERIVATIVE_CHECK_TEMPLATE
).substitute(req_inp=inp.name)
for inp in differentiable_inputs
if inp.name in derivative.required_inputs_fw_grad
]
)
any_has_fw_grad = f"({any_has_fw_grad})"
return any_has_fw_grad
def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str:
if is_out_fn:
msg = "because it is an out= function"
else:
msg = (
"because it has not been implemented yet.\\nPlease file an issue "
"to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
"so that we can prioritize its implementation."
)
cond = get_any_has_fw_grad_cond(derivative=None)
return (
FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=msg)
if cond != ""
else ""
)
body: List[str] = []
unpack_args_stats, unpacked_bindings = unpack_args(f)
body.extend(unpack_args_stats)
if requires_derivative:
body.extend(emit_any_requires_grad())
body.extend(emit_any_has_forward_grad())
body.extend(emit_check_inplace())
body.extend(emit_original_self_definition())
body.extend(setup_derivative(differentiable_inputs))
body.append(declare_returned_variables(f))
body.append(emit_call(f, unpacked_bindings, try_jit_decomposition))
if requires_derivative:
# set_flags has to appear after version_counter, because rebase_history
# requires that the counter is incremented before it is called
body.append(emit_history())
body.extend(emit_check_if_in_complex_autograd_allowlist())
if is_out_fn:
body.append(emit_forbid_fw_derivatives(is_out_fn=True))
else:
if requires_derivative and not try_jit_decomposition:
if len(fw_derivatives) > 0:
body.extend(emit_fw_derivatives())
else:
body.append(emit_forbid_fw_derivatives())
if requires_derivative:
# Save only after the forward AD has been set up
body.append(emit_save_outputs())
if str(f.func.name.name) in RESET_GRAD_ACCUMULATOR:
# `inplace` implies that there is exactly one output named `self`,
# so we can keep the generated code simple. If you need to
# `reset_grad_accumulator` in an operator that's not `inplace`, you can
# remove this assert but the code generation will get more elaborate
assert inplace
body.append("reset_grad_accumulator(self);")
if not returns_void:
body.append(f"return {get_return_value(f)};")
return body