# Generates VariableType.h/cpp
|
|
#
|
|
# **If any changes are being made to the VariableType codegen please also check
|
|
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
|
|
#
|
|
# VariableType is a subclass of at::Type that provides the binding code
|
|
# necessary to provide a differentiable version of ATen operators. There are a
|
|
# number of different things we could mean:
|
|
#
|
|
# - Given a non-differentiable forward implementation, we might
|
|
# directly associate it with a backward implementation to make
|
|
# it differentiable. This is the common case.
|
|
#
|
|
# - Some functions don't need a backwards implementation, because
|
|
# backpropagation will never propagate beyond them. There are a
|
|
# number of different reasons why this may be the case:
|
|
#
|
|
# - The function has no differentiable inputs
|
|
# - The function's output is not differentiable
|
|
# - The function has no data dependency on its input
|
|
#
|
|
# - Some function don't need a backwards implementation because they
|
|
# are implemented as a composition of other (differentiable) ATen
|
|
# functions. These are dispatched directly to the Type superclass,
|
|
# which will in turn dispatch back to VariableType for its
|
|
# differentiable subcomponents.
|
|
#
|
|
import re
|
|
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
|
|
|
|
from torchgen.api import cpp
|
|
from torchgen.api.autograd import (
|
|
DifferentiableInput,
|
|
dispatch_strategy,
|
|
ForwardDerivative,
|
|
gen_differentiable_outputs,
|
|
is_differentiable,
|
|
NativeFunctionWithDifferentiabilityInfo,
|
|
SavedAttribute,
|
|
)
|
|
|
|
from torchgen.api.types import (
|
|
ArrayRefCType,
|
|
BaseCppType,
|
|
BaseCType,
|
|
Binding,
|
|
DispatcherSignature,
|
|
intArrayRefT,
|
|
iTensorListRefT,
|
|
ListCType,
|
|
MutRefCType,
|
|
OptionalCType,
|
|
scalarT,
|
|
SpecialArgName,
|
|
stringT,
|
|
symIntArrayRefT,
|
|
TENSOR_LIST_LIKE_CTYPES,
|
|
tensorListT,
|
|
tensorT,
|
|
TupleCType,
|
|
VectorCType,
|
|
)
|
|
from torchgen.code_template import CodeTemplate
|
|
from torchgen.context import (
|
|
native_function_manager,
|
|
with_native_function,
|
|
with_native_function_and,
|
|
)
|
|
from torchgen.model import (
|
|
Argument,
|
|
BaseType,
|
|
ListType,
|
|
NativeFunction,
|
|
SchemaKind,
|
|
SelfArgument,
|
|
TensorOptionsArguments,
|
|
)
|
|
from torchgen.utils import FileManager, mapMaybe
|
|
|
|
from .context import with_native_function_with_differentiability_info_and_key
|
|
from .gen_inplace_or_view_type import (
|
|
ALL_VIEW_FUNCTIONS,
|
|
ASSIGN_RETURN_VALUE,
|
|
AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION,
|
|
gen_formals,
|
|
get_base_name,
|
|
get_view_info,
|
|
is_tensor_list_type,
|
|
is_tensor_type,
|
|
METHOD_DEFINITION,
|
|
modifies_arguments,
|
|
TMP_VAR,
|
|
unpack_args,
|
|
unpacked_name,
|
|
use_derived,
|
|
WRAPPER_REGISTRATION,
|
|
)
|
|
from .gen_trace_type import (
|
|
declare_returned_variables,
|
|
get_return_value,
|
|
MANUAL_AUTOGRAD_AND_TRACER,
|
|
MANUAL_BACKEND,
|
|
tie_return_values,
|
|
type_wrapper_name,
|
|
)
|
|
|
|
# We don't set or modify grad_fn on these methods. Generally, they return
|
|
# tensors that have requires_grad=False. In-place functions listed here will
|
|
# not examine or modify requires_grad or grad_fn.
|
|
# NB: this does NOT include overload name
|
|
DONT_REQUIRE_DERIVATIVE = {
|
|
# These only depend on the input Tensor's shape and device, not the data
|
|
"empty_like",
|
|
"ones_like",
|
|
"full_like",
|
|
"zeros_like",
|
|
"rand_like",
|
|
"randn_like",
|
|
"new_empty",
|
|
"new_empty_strided",
|
|
"new_full",
|
|
"new_zeros",
|
|
"new_ones",
|
|
# These are only implemented on integral types
|
|
"__and__",
|
|
"__iand__",
|
|
"__ilshift__",
|
|
"__ior__",
|
|
"__irshift__",
|
|
"__ixor__",
|
|
"__lshift__",
|
|
"__or__",
|
|
"__rshift__",
|
|
"__xor__",
|
|
# These work on integral data types, and hence don't require derivative
|
|
"_sobol_engine_draw",
|
|
"_sobol_engine_ff",
|
|
"_sobol_engine_scramble_",
|
|
"_sobol_engine_initialize_state_",
|
|
# This is an unsafe method that is meant to be out of reach of autograd.
|
|
"_coalesced_",
|
|
# Quantize functions should not record gradients
|
|
"quantize_per_tensor",
|
|
"quantize_per_channel",
|
|
# Functions that return integers should not have output that require gradients
|
|
"argmax",
|
|
"argmin",
|
|
"argsort",
|
|
"searchsorted",
|
|
"bucketize",
|
|
# Functions that return booleans are not differentiable
|
|
"isnan",
|
|
"isposinf",
|
|
"isneginf",
|
|
"isinf",
|
|
"signbit",
|
|
"isin",
|
|
"allclose",
|
|
# Functions return none are not differentiable
|
|
"record_stream",
|
|
# These functions are not differentiable
|
|
"logical_and",
|
|
"logical_xor",
|
|
"logical_not",
|
|
"logical_or",
|
|
# This function returns nested_tensor shape as a tensor that is non-differentiable
|
|
"_nested_tensor_size",
|
|
"_nested_tensor_strides",
|
|
}
|
|
|
|
# The C -> R functions at the time of adding this are still being audited and tested
|
|
# but will not error out.
|
|
# C -> C, R -> C functions for which backward is correctly implemented and tested
|
|
GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
|
|
"fill",
|
|
"t",
|
|
"view",
|
|
"reshape",
|
|
"reshape_as",
|
|
"view_as",
|
|
"roll",
|
|
"clone",
|
|
"block_diag",
|
|
"diag_embed",
|
|
"repeat",
|
|
"expand",
|
|
"flip",
|
|
"fliplr",
|
|
"flipud",
|
|
"rot90",
|
|
"nanmean",
|
|
"nansum",
|
|
"transpose",
|
|
"permute",
|
|
"squeeze",
|
|
"unsqueeze",
|
|
"resize",
|
|
"resize_as",
|
|
"tril",
|
|
"triu",
|
|
"chunk",
|
|
"zero_",
|
|
"eq_",
|
|
"ne_",
|
|
"add",
|
|
"__radd__",
|
|
"sum",
|
|
"_conj",
|
|
"sin",
|
|
"cos",
|
|
"mul",
|
|
"sinc",
|
|
"sinh",
|
|
"cosh",
|
|
"__rmul__",
|
|
"sgn",
|
|
"asin",
|
|
"acos",
|
|
"sub",
|
|
"div",
|
|
"cat",
|
|
"view_as_complex",
|
|
"index_put",
|
|
"neg",
|
|
"complex",
|
|
"select",
|
|
"where",
|
|
"as_strided",
|
|
"as_strided_scatter",
|
|
"slice",
|
|
"constant_pad_nd",
|
|
"unbind",
|
|
"split",
|
|
"split_with_sizes",
|
|
"unsafe_split",
|
|
"split_with_sizes_backward",
|
|
"dot",
|
|
"vdot",
|
|
"cholesky",
|
|
"triangular_solve",
|
|
"mm",
|
|
"_unsafe_view",
|
|
"mv",
|
|
"outer",
|
|
"bmm",
|
|
"diagonal",
|
|
"alias",
|
|
"atan",
|
|
"log",
|
|
"log10",
|
|
"log1p",
|
|
"log2",
|
|
"logaddexp",
|
|
"logcumsumexp",
|
|
"reciprocal",
|
|
"tan",
|
|
"pow",
|
|
"rsqrt",
|
|
"tanh",
|
|
"tanh_backward",
|
|
"asinh",
|
|
"acosh",
|
|
"atanh",
|
|
"take",
|
|
"fill_",
|
|
"exp",
|
|
"exp2",
|
|
"expm1",
|
|
"nonzero",
|
|
"mean",
|
|
"std_mean",
|
|
"var_mean",
|
|
"inverse",
|
|
"solve",
|
|
"linalg_cholesky",
|
|
"addcmul",
|
|
"addcdiv",
|
|
"matrix_exp",
|
|
"linalg_matrix_exp",
|
|
"_linalg_eigh",
|
|
"cholesky_solve",
|
|
"linalg_qr",
|
|
"_linalg_svd",
|
|
"_fft_c2c",
|
|
"_fft_r2c",
|
|
"linalg_solve",
|
|
"sqrt",
|
|
"stack",
|
|
"gather",
|
|
"index_select",
|
|
"index_add_",
|
|
"linalg_inv",
|
|
"linalg_inv_ex",
|
|
"baddbmm",
|
|
"addbmm",
|
|
"addmm",
|
|
"addmv",
|
|
"addr",
|
|
"linalg_householder_product",
|
|
"ormqr",
|
|
"reflection_pad1d",
|
|
"reflection_pad2d",
|
|
"reflection_pad3d",
|
|
"linalg_cholesky_ex",
|
|
"linalg_eig",
|
|
"diagonal_copy",
|
|
"diagonal_scatter",
|
|
"select_backward",
|
|
"diagonal_backward",
|
|
"slice_backward",
|
|
"reflection_pad1d_backward",
|
|
"reflection_pad2d_backward",
|
|
"reflection_pad3d_backward",
|
|
"_sparse_sparse_matmul",
|
|
"replication_pad1d",
|
|
"replication_pad2d",
|
|
"replication_pad3d",
|
|
"put",
|
|
"put_",
|
|
"_to_copy",
|
|
"replication_pad1d_backward",
|
|
"replication_pad2d_backward",
|
|
"replication_pad3d_backward",
|
|
"diag",
|
|
"masked_scatter",
|
|
"masked_select",
|
|
"index_add",
|
|
"index_fill",
|
|
"trace",
|
|
"polar",
|
|
"cumsum",
|
|
"rsub",
|
|
"eig",
|
|
"lerp",
|
|
"linalg_vector_norm",
|
|
"cumprod",
|
|
"prod",
|
|
"index_copy",
|
|
"lu",
|
|
"unfold",
|
|
"unfold_backward",
|
|
"index",
|
|
"masked_fill",
|
|
"linalg_cross",
|
|
"lu_unpack",
|
|
"renorm",
|
|
"_conj_physical",
|
|
"linalg_lu_factor_ex",
|
|
"scatter",
|
|
"scatter_add",
|
|
"sigmoid",
|
|
"sigmoid_backward",
|
|
"sparse_mask",
|
|
"trapezoid",
|
|
"cumulative_trapezoid",
|
|
"conj_physical_",
|
|
"_neg_view",
|
|
"_reshape_alias",
|
|
"_reshape_copy",
|
|
"_linalg_det",
|
|
"lu_solve",
|
|
"linalg_solve_triangular",
|
|
"linalg_pinv",
|
|
"linalg_lstsq",
|
|
"unfold_copy",
|
|
"col2im",
|
|
"im2col",
|
|
"cholesky_inverse",
|
|
"to_sparse",
|
|
"sparse_sampled_addmm",
|
|
"linalg_lu",
|
|
"pixel_shuffle",
|
|
"pixel_unshuffle",
|
|
"linalg_lu_solve",
|
|
"_linalg_slogdet",
|
|
"_linalg_solve_ex",
|
|
}
|
|
|
|
# Sparse ops whose backward is implemented and tested for complex dtypes.
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
    "_to_dense",
    "_coalesce",
    "coalesce",
    "values",
    "_sparse_coo_tensor_with_dims_and_tensors",
    "_sparse_addmm",
}

# The sparse complex-supporting ops are folded into the general
# complex-supporting set, so downstream checks only consult one set.
GRADIENT_IMPLEMENTED_FOR_COMPLEX.update(GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX)
|
|
|
|
# Some operators invalidate the grad_accumulator. Let's reset it.
|
|
RESET_GRAD_ACCUMULATOR = {"set_", "resize_"}
|
|
|
|
# NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
|
|
#
|
|
# We check the following properties:
|
|
# 1) A function should never change the input tensors' underlying c10::TensorImpl
|
|
# pointers or c10::Storage pointers, even if it modifies its input tensors (via
|
|
# inplace or out-variants)
|
|
# If the function does not modify its arguments, we also check the following properties
|
|
# pertaining to its output:
|
|
# 2) Its TensorImpl has use_count of 1
|
|
# 3) If the function is a view function, it has the same StorageImpl as that of
|
|
# the input it is aliased with. Otherwise, its StorageImpl has use_count of 1
|
|
#
|
|
# The following code templates implement the checks for this invariant:
|
|
SAVE_TENSOR_STORAGE = CodeTemplate(
|
|
"""\
|
|
c10::optional<Storage> ${tensor_name}_storage_saved =
|
|
${tensor_name}.has_storage() ? c10::optional<Storage>(${tensor_name}.storage()) : c10::nullopt;
|
|
"""
|
|
)
|
|
|
|
|
|
# If tensor_name == out_tensor_name, used to enforce (1), otherwise used for (2)
|
|
ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate(
|
|
"""\
|
|
if (${tensor_name}_storage_saved.has_value() &&
|
|
!at::impl::dispatch_mode_enabled() &&
|
|
!at::impl::tensor_has_dispatch(${tensor_name}))
|
|
TORCH_INTERNAL_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${out_tensor_name}.storage()));
|
|
"""
|
|
)
|
|
|
|
SAVE_TENSORLIST_STORAGE = CodeTemplate(
|
|
"""\
|
|
std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
|
|
for (const Tensor& tensor : ${tensorlist_name})
|
|
${tensorlist_name}_storage_saved.push_back(
|
|
tensor.has_storage() ? c10::optional<Storage>(tensor.storage()) : c10::nullopt);
|
|
"""
|
|
)
|
|
|
|
ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate(
|
|
"""\
|
|
for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
|
|
if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name}))
|
|
TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(${tensorlist_name}[i].storage()));
|
|
}
|
|
"""
|
|
)
|
|
|
|
SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate(
|
|
"""\
|
|
std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
|
|
for (const c10::optional<Tensor>& tensor : ${tensorlist_name})
|
|
${tensorlist_name}_storage_saved.push_back(
|
|
tensor.has_value() && tensor->has_storage() ? c10::optional<Storage>(tensor->storage()) : c10::nullopt);
|
|
"""
|
|
)
|
|
|
|
ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate(
|
|
"""\
|
|
for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
|
|
if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name}))
|
|
TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(
|
|
static_cast<c10::optional<Tensor>>(${tensorlist_name}[i])->storage()));
|
|
}
|
|
"""
|
|
)
|
|
|
|
SAVE_TENSOR_IMPL = CodeTemplate(
|
|
"""\
|
|
c10::intrusive_ptr<TensorImpl> ${tensor_name}_impl_saved;
|
|
if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr();
|
|
"""
|
|
)
|
|
|
|
ENFORCE_SAME_TENSOR_IMPL = CodeTemplate(
|
|
"""\
|
|
if (${tensor_name}_impl_saved && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name}))
|
|
TORCH_INTERNAL_ASSERT(${tensor_name}_impl_saved == ${tensor_name}.getIntrusivePtr());
|
|
"""
|
|
)
|
|
|
|
ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE = CodeTemplate(
|
|
"""\
|
|
if (!at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name}))
|
|
TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() <= 1, "function: ${fn_name}");
|
|
"""
|
|
)
|
|
|
|
ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE = CodeTemplate(
|
|
"""\
|
|
if (${tensor_name}.has_storage() && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) {
|
|
TORCH_INTERNAL_ASSERT(${tensor_name}.storage().use_count() == 1, "function: ${fn_name}");
|
|
}
|
|
"""
|
|
)
|
|
|
|
SAVE_TENSORLIST_IMPL = CodeTemplate(
|
|
"""\
|
|
std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
|
|
for (size_t i=0; i<${tensorlist_name}.size(); i++)
|
|
if (${tensorlist_name}[i].defined()) ${tensorlist_name}_impl_saved[i] = ${tensorlist_name}[i].getIntrusivePtr();
|
|
"""
|
|
)
|
|
|
|
ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate(
|
|
"""\
|
|
for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
|
|
if (${tensorlist_name}_impl_saved[i] && !at::impl::tensorlist_has_dispatch(${tensorlist_name}))
|
|
TORCH_INTERNAL_ASSERT(${tensorlist_name}_impl_saved[i] == ${tensorlist_name}[i].getIntrusivePtr());
|
|
}
|
|
"""
|
|
)
|
|
|
|
SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate(
|
|
"""\
|
|
std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
|
|
for (size_t i=0; i<${tensorlist_name}.size(); i++) {
|
|
c10::optional<Tensor> t = ${tensorlist_name}[i];
|
|
if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr();
|
|
}
|
|
"""
|
|
)
|
|
|
|
ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate(
|
|
"""\
|
|
for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
|
|
if (${tensorlist_name}_impl_saved[i])
|
|
TORCH_INTERNAL_ASSERT(
|
|
${tensorlist_name}_impl_saved[i] == static_cast<c10::optional<Tensor>>(${tensorlist_name}[i])->getIntrusivePtr());
|
|
}
|
|
"""
|
|
)
|
|
|
|
# The following list contains functions that we don't enforce the invariant on.
|
|
DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
|
|
# These functions are expected to change impl or storage of input tensors
|
|
"set_",
|
|
"_cudnn_rnn_flatten_weight",
|
|
}
|
|
DONT_ENFORCE_TENSOR_IMPL_USE_COUNT = {
|
|
# These non-inplace, non-out functions return tensors with use_count > 1
|
|
# Therefore, they MAY (but not necessarily) return one of its inputs as-is
|
|
# See https://github.com/pytorch/pytorch/issues/60426 for more information
|
|
"_embedding_bag",
|
|
"_embedding_bag_forward_only",
|
|
"q_per_channel_scales",
|
|
"q_per_channel_zero_points",
|
|
"lu_unpack",
|
|
"_cudnn_rnn_backward",
|
|
# The below failed StorageImpl use_count check but we skip tensor_impl check
|
|
# just in case
|
|
"_cudnn_rnn",
|
|
"dequantize_self",
|
|
# lift() should never actually be called with a requires_grad=True tensor,
|
|
"lift",
|
|
"lift_fresh",
|
|
"lift_fresh_copy",
|
|
# Nested Tensors related functions
|
|
# _nested_tensor_size() should never actually be called with requires_grad=True tensor
|
|
"_nested_tensor_size",
|
|
"_nested_tensor_strides",
|
|
"_nested_tensor_storage_offsets",
|
|
}
|
|
|
|
DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = {
|
|
# These non-view functions return tensors with storage use_count != 1
|
|
"_slow_conv2d_forward",
|
|
"slow_conv3d_forward",
|
|
"channel_shuffle",
|
|
# If an input is returned as-is in output, we cannot guarantee its storage_impl
|
|
# use count to be 1 either.
|
|
*DONT_ENFORCE_TENSOR_IMPL_USE_COUNT,
|
|
}
|
|
# END CHECKS FOR [ TensorImpl and Storage Pointer Sanity Checks ]
|
|
|
|
DECLARE_GRAD_FN = CodeTemplate(
|
|
"""\
|
|
std::shared_ptr<${op}> grad_fn;
|
|
"""
|
|
)
|
|
|
|
DECLARE_VECTOR_OF_GRAD_FN = CodeTemplate(
|
|
"""\
|
|
std::vector<std::shared_ptr<${op}>> grad_fns;
|
|
"""
|
|
)
|
|
|
|
SETUP_ANY_REQUIRES_GRAD = CodeTemplate(
|
|
"""\
|
|
[[maybe_unused]] auto _any_requires_grad = compute_requires_grad( ${args_with_derivatives} );
|
|
${extra_differentiability_conditions}
|
|
"""
|
|
)
|
|
|
|
SETUP_DERIVATIVE = CodeTemplate(
|
|
"""\
|
|
if (_any_requires_grad) {
|
|
${setup}
|
|
}
|
|
"""
|
|
)
|
|
|
|
SETUP_NONE_REQUIRES_GRAD = CodeTemplate(
|
|
"""\
|
|
if (compute_requires_grad( ${args_to_check} )) {
|
|
throw_error_out_requires_grad("${base_name}");
|
|
}
|
|
"""
|
|
)
|
|
|
|
ASSIGN_GRAD_FN = CodeTemplate(
|
|
"""\
|
|
grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode);
|
|
grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} ));
|
|
"""
|
|
)
|
|
|
|
# note(crcrpar): `compute_requires_grad` in the template below is supplied with arguments indexed with `i`
|
|
# while the `SETUP_ANY_REQUIRES_GRAD` above takes whole tensors and scalars.
|
|
ASSIGN_VECTOR_OF_GRAD_FN = CodeTemplate(
|
|
"""\
|
|
for (const auto& i : c10::irange( ${irange} )) {
|
|
const auto ith_requires_grad = compute_requires_grad(${args_with_derivatives});
|
|
check_inplace(self[i], ith_requires_grad);
|
|
grad_fns.push_back([&]() -> std::shared_ptr<${op}> {
|
|
if (!ith_requires_grad) {
|
|
return nullptr;
|
|
} else {
|
|
auto grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode);
|
|
grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} ));
|
|
return grad_fn;
|
|
}
|
|
}());
|
|
}
|
|
"""
|
|
)
|
|
|
|
CALL_REDISPATCH = CodeTemplate(
|
|
"""\
|
|
at::redispatch::${api_name}(${unpacked_args})"""
|
|
)
|
|
# If the non-variable operation has return values, we use the `tmp` variable to hold the
|
|
# values temporarily and pass the values to the return variables outside of the
|
|
# `at::AutoDispatchBelowAutograd` guard block.
|
|
DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP = CodeTemplate(
|
|
"""\
|
|
auto ${tmp_var} = ([&]() {
|
|
if (${any_has_forward_grad}) {
|
|
static c10::OperatorName full_name("aten::${op_name}", "${op_overload}");
|
|
static c10::optional<c10::OperatorHandle> opt_op = c10::Dispatcher::singleton().findSchema(full_name);
|
|
return impl::run_jit_decomposition_with_args_for_jvp<${return_types}>("${op_name}", *opt_op, ks, ${arg_names});
|
|
} else {
|
|
${guard}
|
|
return ${base_type_call};
|
|
}
|
|
})();
|
|
"""
|
|
)
|
|
|
|
DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate(
|
|
"""\
|
|
auto ${tmp_var} = ([&]() {
|
|
${guard}
|
|
return ${base_type_call};
|
|
})();
|
|
"""
|
|
)
|
|
|
|
DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate(
|
|
"""\
|
|
{
|
|
${guard}
|
|
${base_type_call};
|
|
}
|
|
"""
|
|
)
|
|
|
|
SET_HISTORY = CodeTemplate(
|
|
"""\
|
|
if (grad_fn) {
|
|
${fn}_history(${differentiable_outputs}, grad_fn);
|
|
}
|
|
"""
|
|
)
|
|
|
|
LOOP_OVER_VECTOR_OF_GRAD_FNS = CodeTemplate(
|
|
"""\
|
|
if (!grad_fns.empty()) {
|
|
${preamble}
|
|
for (const auto& i : c10::irange(grad_fns.size())) {
|
|
auto grad_fn = grad_fns[i];
|
|
if (grad_fn != nullptr) {
|
|
${statements}
|
|
}
|
|
}
|
|
}
|
|
"""
|
|
)
|
|
|
|
CONDITIONAL = CodeTemplate(
|
|
"""\
|
|
if (${cond}) {
|
|
${statements}
|
|
}
|
|
"""
|
|
)
|
|
|
|
RUN_ONLY_IN_DEBUG_MODE = CodeTemplate(
|
|
"""\
|
|
#ifndef NDEBUG
|
|
${statements}
|
|
#endif
|
|
"""
|
|
)
|
|
|
|
FW_DERIVATIVE_CHECK_TEMPLATE = CodeTemplate(
|
|
"""\
|
|
isFwGradDefined(${req_inp})\
|
|
"""
|
|
)
|
|
|
|
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE = CodeTemplate(
|
|
"""\
|
|
isFwGradDefinedTensorList(${req_inp})\
|
|
"""
|
|
)
|
|
|
|
FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate(
|
|
"""\
|
|
auto ${inp_name}_t_raw = toNonOptFwGrad(${inp});
|
|
auto ${inp_name}_tensor = toNonOptTensor(${inp});
|
|
auto ${inp_name}_t = (${inp_name}_t_raw.defined() || !${inp_name}_tensor.defined())
|
|
? ${inp_name}_t_raw : at::${zeros_fn}(${inp_name}_tensor.sizes(), ${inp_name}_tensor.options());
|
|
"""
|
|
)
|
|
|
|
FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate(
|
|
"""\
|
|
auto ${inp_name}_p = toNonOptPrimal(${inp});
|
|
"""
|
|
)
|
|
|
|
FW_DERIVATIVE_SETTER_TENSOR = CodeTemplate(
|
|
"""\
|
|
if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}.defined()) {
|
|
// The hardcoded 0 here will need to be updated once we support multiple levels.
|
|
${out_arg}._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace});
|
|
}
|
|
"""
|
|
)
|
|
|
|
FW_DERIVATIVE_SETTER_TENSOR_FOREACH = CodeTemplate(
|
|
"""\
|
|
for (const auto& i : c10::irange(${out_arg}_new_fw_grad_opts.size())) {
|
|
auto& ${out_arg}_new_fw_grad_opt = ${out_arg}_new_fw_grad_opts[i];
|
|
if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}[i].defined()) {
|
|
// The hardcoded 0 here will need to be updated once we support multiple levels.
|
|
${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace});
|
|
}
|
|
}
|
|
"""
|
|
)
|
|
|
|
FW_DERIVATIVE_SETTER_MULTI_OUTPUT = CodeTemplate(
|
|
"""\
|
|
if (${all_res}_new_fw_grad_opt.has_value() && std::get<${idx}>(${all_res}_new_fw_grad_opt.value()).defined()
|
|
&& ${out_arg}.defined()) {
|
|
${out_arg}._set_fw_grad(std::get<${idx}>(${all_res}_new_fw_grad_opt.value()), /* level */ 0, /* is_inplace_op */ false);
|
|
}
|
|
"""
|
|
)
|
|
|
|
FW_DERIVATIVE_SETTER_TENSOR_LIST = CodeTemplate(
|
|
"""\
|
|
if (${out_arg}_new_fw_grad_opt.has_value()) {
|
|
auto ${out_arg}_new_fw_grad = ${out_arg}_new_fw_grad_opt.value();
|
|
TORCH_INTERNAL_ASSERT(${out_arg}.size() == ${out_arg}_new_fw_grad.size());
|
|
for (const auto i : c10::irange(${out_arg}.size())) {
|
|
if (${out_arg}_new_fw_grad[i].defined() && ${out_arg}[i].defined()) {
|
|
// The hardcoded 0 here will need to be updated once we support multiple levels.
|
|
${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad[i], /* level */ 0, /* is_inplace_op */ ${is_inplace});
|
|
}
|
|
}
|
|
}
|
|
"""
|
|
)
|
|
|
|
FW_DERIVATIVE_TEMPLATE = CodeTemplate(
|
|
"""\
|
|
${fw_grad_opt_definition}
|
|
if (${requires_fw_grad}) {
|
|
${unpacked_arguments}
|
|
${out_arg}_new_fw_grad_opt = ${formula};
|
|
}
|
|
"""
|
|
)
|
|
|
|
FW_DERIVATIVE_FOREACH_TEMPLATE = CodeTemplate(
|
|
"""\
|
|
${fw_grad_opt_definition}
|
|
for (const auto& i : c10::irange(${vector_of_optional_tensor}.size())) {
|
|
if (${any_has_forward_grad_for_current_index}) {
|
|
${unpacked_arguments}
|
|
${vector_of_optional_tensor}[i] = ${formula};
|
|
}
|
|
}
|
|
"""
|
|
)
|
|
|
|
FW_DERIVATIVE_FORBID_TEMPLATE = CodeTemplate(
|
|
"""\
|
|
TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}");
|
|
"""
|
|
)
|
|
|
|
FW_DERIVATIVE_FORBID_LIST_TEMPLATE = CodeTemplate(
|
|
"""\
|
|
for (const auto& _t: ${arg}) {
|
|
TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}");
|
|
}
|
|
"""
|
|
)
|
|
|
|
|
|
def gen_variable_type(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
    used_keys: Set[str],
) -> None:
    """VariableType.h and VariableType.cpp body

    This is the at::Type subclass for differentiable tensors. The
    implementation of each function dispatches to the base tensor type to
    compute the output. The grad_fn is attached to differentiable functions.
    """
    # Key order is fixed up-front so every emitted artifact is deterministic.
    ordered_keys = sorted(used_keys)

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write(
        "VariableType.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/VariableType.h"
        },
    )

    # Render one TORCH_LIBRARY_IMPL macro per dispatch key that appears in
    # derivatives.yaml; the "Default" key registers under Autograd.
    def wrapper_registrations(used_keys: Set[str]) -> str:
        macros: List[str] = []
        for key in sorted(used_keys):
            dispatch_key = "Autograd" if key == "Default" else key
            macros.append(
                f"TORCH_LIBRARY_IMPL(aten, {dispatch_key}, m) "
                + "{\n"
                + "${"
                + f"wrapper_registrations_{key}"
                + "}\n}"
            )
        return "\n\n".join(macros)

    # Generate a new template from VariableType.cpp which replaces
    # ${wrapper_registrations} with per-key TORCH_LIBRARY_IMPL macros for each
    # key that appears in derivatives.yaml.
    fm1 = FileManager(
        install_dir=out + "/templates", template_dir=template_path, dry_run=False
    )
    fm1.write(
        "VariableType.cpp",
        lambda: {
            "type_derived_method_definitions": "\n\n".join(
                "${" + f"type_derived_method_definitions_{key}" + "}"
                for key in ordered_keys
            ),
            "wrapper_registrations": wrapper_registrations(used_keys),
        },
    )

    # Generate the final VariableType_*.cpp shards from the generated template.
    fm2 = FileManager(install_dir=out, template_dir=out + "/templates", dry_run=False)

    sharded_keys = {
        f"{prefix}_{key}"
        for prefix in ("type_derived_method_definitions", "wrapper_registrations")
        for key in ordered_keys
    }
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm2.write_sharded(
        "VariableType.cpp",
        [fn for fn in fns_with_diff_infos if use_derived(fn)],
        key_fn=lambda fn: cpp.name(fn.func.func),
        base_env={
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/VariableType.cpp",
        },
        env_callable=gen_variable_type_func,
        num_shards=5,
        sharded_keys=sharded_keys,
    )
|
|
|
|
|
|
@with_native_function_and
def gen_wrapper_registration(f: NativeFunction, key: str = "Default") -> str:
    """Render the m.impl(...) registration for `f` under dispatch key `key`."""
    env = {
        "unqual_operator_name_with_overload": f.func.name,
        "type_wrapper_name": type_wrapper_name(f, key),
        "class_type": "VariableType",
    }
    return WRAPPER_REGISTRATION.substitute(env)
|
|
|
|
|
|
def gen_variable_type_func(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> Dict[str, List[str]]:
    """Build the VariableType method definition(s) and wrapper registration(s)
    for one native function.

    Returns a dict keyed by ``type_derived_method_definitions_{key}`` /
    ``wrapper_registrations_{key}`` (one pair per dispatch key found in
    derivatives.yaml, or just "Default"), as consumed by the sharded
    VariableType.cpp template.
    """
    f = fn.func
    result: Dict[str, List[str]] = {}
    with native_function_manager(f):
        name = cpp.name(f.func)
        formals = gen_formals(f)

        if (
            fn.info is None
            and str(f.func.name.name) not in RESET_GRAD_ACCUMULATOR
            and get_base_name(f) not in DONT_REQUIRE_DERIVATIVE
            and len(gen_differentiable_outputs(fn)) > 0
            and cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE
            and type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT
            and type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT
        ):
            # NOTE: [ Registering AutogradNotImplemented boxed kernel ]
            #
            # When there is no derivatives.yaml entry, we register a generic boxed
            # NotImplemented kernel to set grad_fn to be NotImplemented, so that forward
            # proceeds as usual but an error is properly produced on backward.
            # TODO: it would be nice to not have these special cases
            #
            # There are several cases where still let codegen handle it:
            # 1) ops that need to reset grad accumulator (we let codegen handle this case
            # because) the list is (currently) only accessible in Python.
            # 2) User explicitly specifies DONT_REQUIRE_DERIVATIVE. This basically makes
            # autograd a fallthrough with NDEBUG checks. This can be useful for when all
            # outputs are integral.
            # 3) When there are no differentiable outputs. This is similar to (2).
            # 4) There are certain ops where we skip certain NDEBUG checks. this is similar
            # to (1).
            type_definition = ""
            wrapper_registration = AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION.substitute(
                unqual_operator_name_with_overload=f.func.name
            )
            result["type_derived_method_definitions_Default"] = [type_definition]
            result["wrapper_registrations_Default"] = [wrapper_registration]
        else:
            # Without differentiability info we emit a single definition under
            # the "Default" key; otherwise one per dispatch key. (Previously
            # these two cases duplicated the same substitution block verbatim.)
            for key in fn.info.keys() if fn.info else ("Default",):
                type_definition = METHOD_DEFINITION.substitute(
                    return_type=cpp.returns_type(
                        f.func.returns, symint=True
                    ).cpp_type(),
                    type_wrapper_name=type_wrapper_name(f, key),
                    type_definition_body=emit_body(fn, key),
                    formals=formals,
                )
                wrapper_registration = gen_wrapper_registration(f, key)
                result[f"type_derived_method_definitions_{key}"] = [type_definition]
                result[f"wrapper_registrations_{key}"] = [wrapper_registration]
        # See Note [Manual Backend kernels]
        assert (name in MANUAL_BACKEND) == f.manual_kernel_registration
        # If you want to register a kernel to Autograd, you must make the op abstract.
        # In other words, this op must have dispatch section in native_functions.yaml.
        if name in MANUAL_AUTOGRAD_AND_TRACER or (
            fn.info and any(info.has_derivatives for info in fn.info.values())
        ):
            msg = (
                f"There's a formula for {name}(or its functional variant) in derivatives.yaml. "
                f"It's required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA "
                f"or CompositeExplicitAutograd in native_functions.yaml. Please see "
                f"https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword "
                f"for instructions to choose the right dispatch keyword."
            )
            assert f.is_abstract, msg

    return result
|
|
|
|
|
|
# (base name, overload name) pairs of foreach ops that are known to have no
# differentiability info in derivatives.yaml; `emit_body` asserts that any
# in-place foreach op lacking info is on this list.
_foreach_ops_without_differentiability_info = {
    # No reference backward available due to the lack of `{maximum, minimum}(tensor, scalar)`.
    ("_foreach_maximum", "Scalar"),
    ("_foreach_maximum", "ScalarList"),
    ("_foreach_minimum", "Scalar"),
    ("_foreach_minimum", "ScalarList"),
    # No reference backward available as addcdiv/addcmul don't support Tensor as scaling factor.
    ("_foreach_addcdiv", "Tensor"),
    ("_foreach_addcmul", "Tensor"),
    ("_foreach_copy", ""),
}
|
|
|
|
# (base name, overload name) pairs of foreach ops whose argument count may
# legitimately differ from their non-foreach reference function; `emit_body`
# exempts these from its arity-equality assertion.
_foreach_ops_with_different_arity = {
    # These ops lack `alpha` of scaling factor to applied to the right hand side argument.
    ("_foreach_add", "Scalar"),
    ("_foreach_add", "ScalarList"),
    ("_foreach_sub", "Scalar"),
    ("_foreach_sub", "ScalarList"),
}
|
|
|
|
|
|
@with_native_function_with_differentiability_info_and_key
|
|
def emit_body(
|
|
fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
|
|
) -> List[str]:
|
|
    assert dispatch_strategy(fn) == "use_derived"
    f = fn.func
    # Differentiability info and forward derivatives registered for this dispatch key.
    info = fn.info[key] if fn.info else None
    fw_derivatives = fn.fw_derivatives.get(key, []) if fn.fw_derivatives else []

    name = cpp.name(f.func)
    inplace = f.func.kind() == SchemaKind.inplace
    is_out_fn = f.func.kind() == SchemaKind.out
    returns_void = len(f.func.returns) == 0
    base_name = get_base_name(f)
    view_info = get_view_info(f)

    is_foreach = name.startswith("_foreach")
    is_inplace_foreach = is_foreach and inplace
    if is_inplace_foreach:
        # In-place foreach ops reuse the derivative formula of a non-foreach
        # reference op; build maps between the foreach arguments and the
        # reference arguments (and back, keyed by reference name) so later
        # code can rewrite names in saved-variable expressions.
        inplace_foreacharg2refarg: Dict[Argument, Argument] = {}
        refargname2inplace_foreacharg: Dict[str, Argument] = {}
        base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name)
        if info is None:
            # NOTE(review): this assert enforces membership in the known-exempt
            # set, so the failure message reads backwards ("should have a
            # differentiability info") — confirm intended wording upstream.
            assert (
                base_name_and_overload_name
                in _foreach_ops_without_differentiability_info
            ), f"{'.'.join(base_name_and_overload_name)} should have a differentiability info"
        else:
            # Arity must match the reference op unless explicitly exempted
            # (e.g. foreach add/sub, which lack `alpha`).
            assert (
                len(f.func.arguments.flat_non_out)
                == len(info.func.func.arguments.flat_non_out)
            ) or (base_name_and_overload_name in _foreach_ops_with_different_arity), (
                f"{'.'.join(base_name_and_overload_name)} has {len(f.func.arguments.flat_non_out)} args "
                f"but the reference has {len(info.func.func.arguments.flat_non_out)}"
            )
            for foreach_arg, ref_arg in zip(
                f.func.arguments.flat_non_out, info.func.func.arguments.flat_non_out
            ):
                foreach_arg_type = foreach_arg.type
                # A foreach Tensor[] argument corresponds to a scalar/Tensor
                # argument of the reference op, so compare element types.
                if isinstance(foreach_arg_type, ListType):
                    foreach_arg_type = foreach_arg_type.elem
                assert foreach_arg_type == ref_arg.type
                inplace_foreacharg2refarg[foreach_arg] = ref_arg
                refargname2inplace_foreacharg[ref_arg.name] = foreach_arg
|
|
|
|
def gen_differentiable_input(
|
|
arg: Union[Argument, SelfArgument, TensorOptionsArguments]
|
|
) -> Optional[DifferentiableInput]:
|
|
if isinstance(arg, TensorOptionsArguments):
|
|
return None
|
|
a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg
|
|
|
|
# TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove.
|
|
# NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are
|
|
# not handled properly as they are irrelevant for this codegen.
|
|
cpp_type = cpp.argument_type(a, binds=a.name, symint=True).cpp_type()
|
|
|
|
if not is_differentiable(a.name, a.type, info):
|
|
return None
|
|
return DifferentiableInput(
|
|
name=a.name,
|
|
type=a.type,
|
|
cpp_type=cpp_type,
|
|
)
|
|
|
|
    @with_native_function
    def gen_differentiable_inputs(f: NativeFunction) -> List[DifferentiableInput]:
        """Collect the differentiable inputs of `f`, substituting in-place
        foreach arguments with their reference-op counterparts so that names
        line up with the reference derivative formulas."""
        arguments = list(f.func.arguments.non_out)
        if is_inplace_foreach and info is not None:
            # NOTE(review): `arguments` (from non_out) is indexed with positions
            # taken from flat_non_out; this assumes the two sequences align for
            # foreach ops — confirm no SelfArgument/TensorOptions grouping here.
            for i, arg in enumerate(f.func.arguments.flat_non_out):
                if arg in inplace_foreacharg2refarg:
                    # note(crcrpar): From what I understand, what matters is only the name.
                    # Thus originally I only replace argument only when the names are different.
                    # TODO(crcrpar): Make it simpler.
                    mapped_arg = inplace_foreacharg2refarg[arg]
                    arguments[i] = Argument(
                        mapped_arg.name,
                        mapped_arg.type,
                        mapped_arg.default,
                        mapped_arg.annotation,
                    )
        return list(mapMaybe(gen_differentiable_input, arguments))
|
|
|
|
def find_args_with_derivatives(
|
|
differentiable_inputs: List[DifferentiableInput],
|
|
) -> List[DifferentiableInput]:
|
|
"""Find arguments that have derivative definitions"""
|
|
if info is None or not info.has_derivatives:
|
|
return differentiable_inputs
|
|
names = {name for d in info.derivatives for name in d.var_names}
|
|
differentiable = [arg for arg in differentiable_inputs if arg.name in names]
|
|
if len(differentiable) != len(names):
|
|
missing = names - {arg.name for arg in differentiable}
|
|
raise RuntimeError(
|
|
f"Missing arguments for derivatives: {missing} in {info.name}"
|
|
)
|
|
return differentiable
|
|
|
|
    differentiable_inputs = gen_differentiable_inputs(f)
    args_with_derivatives = find_args_with_derivatives(differentiable_inputs)
    differentiable_outputs = gen_differentiable_outputs(fn, key)

    # An op is opted out of autograd either via its base name or full name.
    undifferentiable = (base_name in DONT_REQUIRE_DERIVATIVE) or (
        name in DONT_REQUIRE_DERIVATIVE
    )

    requires_derivative = (
        (not undifferentiable)
        and (len(differentiable_inputs) > 0)
        and (
            (len(differentiable_outputs) > 0)
            # note(crcrpar): In-place foreach functions are a void function.
            or is_inplace_foreach
        )
    )

    if (
        info is not None
        and info.has_derivatives
        and not requires_derivative
        # out= ops are allowed to have zero returns which cause requires_derivative to be False
        # we shouldn't error out though (out= ops for autograd just redispatch)
        and len(f.func.returns) > 0
    ):
        raise RuntimeError(
            f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative"
        )

    # note(crcrpar): In-place foreach functions do not support forward AD
    if requires_derivative and len(fw_derivatives) > 0 and not is_inplace_foreach:
        assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len(
            differentiable_outputs
        ), (
            "Expected the number of forward derivatives implemented to match the "
            "number of differentiable outputs. NB: This only applies when at least "
            "one forward derivative is implemented. Not implementing any forward "
            "derivatives is also okay, and we would require inputs to the op to "
            "not have associated tangents in that case."
        )

    # Fall back to a JIT decomposition for forward AD only when there is no
    # hand-written forward derivative and the op is functional with returns.
    try_jit_decomposition = (
        requires_derivative
        and len(fw_derivatives) == 0
        and (not modifies_arguments(f))
        and (not returns_void)
    )
|
|
|
|
    def emit_save_inputs() -> List[str]:
        """Emit statements that stash input values onto the generated grad_fn,
        guarded (where possible) so unused tensors are not saved."""
        setup: List[str] = []
        if info is None or not info.has_derivatives:
            return setup

        has_tensorlist_arg = any(
            is_tensor_list_type(arg.type) for arg in args_with_derivatives
        )

        # We don't want to save tensors if we know that they will never be used
        # when computing the derivative, so we add guards to those statements
        def guard_for(arg: SavedAttribute) -> Optional[str]:
            assert info is not None

            # It's hard to determine the edge offset if we have TensorLists
            # NOTE(crcrpar): in-place foreach functions' arguments include tensorlist
            # but their derivatives don't use it, so let them bypass this check.
            if has_tensorlist_arg and (not is_inplace_foreach):
                return None

            # Empirical evaluation of the cases where we insert those guards in
            # backward show that they are somewhat useless. E.g. there's no need
            # to guard on some values captured from forward, because they had to
            # require_grad if the backward function even gets executed. I don't
            # have any good ideas for detecting those cases, so I simply disabled the
            # checks.
            if "backward" in info.name:
                return None

            # If there's a single derivative we could compute, we already have
            # a requires_grad check that is sufficient
            if len(args_with_derivatives) <= 1:
                return None

            # We really only care about trimming down the amount of tensors we save
            if arg.nctype.type != BaseCType(tensorT):
                return None

            # We want to emit simple guards, so we only allow that if checking one
            # input is enough to determine whether we need that value
            used_in = [d for d in info.derivatives if arg in d.saved_inputs]
            assert len(used_in) > 0
            if len(used_in) != 1:
                return None
            derivative = used_in[0]

            # Case with multioutput formulas
            # TODO: process all derivative formulas!!!
            if len(derivative.var_names) != 1:
                wrap_opt_if_start = derivative.formula.find(
                    f"wrap_opt_if({arg.nctype.name}"
                )
                if wrap_opt_if_start == -1:
                    return None

                wrap_opt_if_match = re.match(
                    rf"wrap_opt_if\({arg.nctype.name},(.*?)\)",
                    derivative.formula[wrap_opt_if_start:],
                )
                assert wrap_opt_if_match is not None

                # Condition is between 'wrap_opt_if(var_name,' and ')'.
                # NOTE(review): len() of the raw pattern counts the `\` escape,
                # so the slice starts one char into the matched text; formulas
                # appear to write a space after the comma and .strip() absorbs
                # the discrepancy — confirm before touching this arithmetic.
                condition_slice = slice(len(rf"wrap_opt_if\({arg.nctype.name},"), -1)
                wrap_opt_if_condition = wrap_opt_if_match.group(0)[
                    condition_slice
                ].strip()
                # replace 'grad_input_mask[num]' with 'grad_fn->should_compute_output(num)'
                wrap_opt_if_condition = re.sub(
                    r"grad_input_mask\[(\d+)\]",
                    r"grad_fn->should_compute_output(\1)",
                    wrap_opt_if_condition,
                )
                return f"{wrap_opt_if_condition}"

            # Figure out the offset of the edge that uses this variable
            derivative_var_name = derivative.var_names[0]
            for edge_off, a in enumerate(args_with_derivatives):
                if a.name == derivative_var_name:
                    break
            else:
                raise AssertionError()
            return f"grad_fn->should_compute_output({edge_off})"

        if is_inplace_foreach:
            # For in-place foreach ops the saves are emitted inside a loop over
            # the vector of grad_fns.
            save_input_stmts = save_variables(info.all_saved_inputs, False, guard_for)
            if save_input_stmts:
                setup.append(
                    LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute(
                        preamble="", statements=save_input_stmts
                    )
                )
        else:
            setup.extend(save_variables(info.all_saved_inputs, False, guard_for))
            # Record TensorList sizes so backward can reconstruct list shapes.
            for arg in args_with_derivatives:
                if is_tensor_list_type(arg.type):
                    setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();")
        return setup
|
|
|
|
    def setup_derivative(differentiable_inputs: List[DifferentiableInput]) -> List[str]:
        """Emit the grad_fn declaration and setup block for this op."""
        body: List[str] = []
        if is_out_fn:
            # For out functions, ensure that no input or output requires grad
            body.append(DECLARE_GRAD_FN.substitute(op="Node"))
            body.append(
                SETUP_NONE_REQUIRES_GRAD.substitute(
                    base_name=base_name,
                    args_to_check=[arg.name for arg in differentiable_inputs],
                )
            )
            body.append(
                SETUP_NONE_REQUIRES_GRAD.substitute(
                    base_name=base_name,
                    args_to_check=[arg.name for arg in differentiable_outputs],
                )
            )
            return body

        # Ops without a formula get the NotImplemented node (its ctor takes the op name).
        op = info.op if info is not None and info.has_derivatives else "NotImplemented"
        setup: List[str] = []
        if not is_inplace_foreach:
            setup.extend(
                ASSIGN_GRAD_FN.substitute(
                    op=op,
                    op_ctor=""
                    if info is not None and info.has_derivatives
                    else f'"{cpp.name(f.func)}"',
                    args_with_derivatives=[arg.name for arg in args_with_derivatives],
                ).split("\n")
            )
        else:
            # note(crcrpar): Assuming in-place foreach function's self_arg is always TensorList.
            list_like_arg = "self"
            args = [arg.name for arg in args_with_derivatives]
            for i, arg in enumerate(args):
                if is_inplace_foreach and info is not None:
                    # Rewrite reference-formula names to the foreach argument,
                    # indexing into list arguments inside the emitted loop.
                    if arg in refargname2inplace_foreacharg:
                        foreach_arg = refargname2inplace_foreacharg[arg]
                        args[i] = foreach_arg.name + (
                            "[i]" if isinstance(foreach_arg.type, ListType) else ""
                        )
                else:
                    # No info: only `self` needs per-element indexing.
                    if arg == list_like_arg:
                        args[i] = arg + "[i]"
            setup.extend(
                ASSIGN_VECTOR_OF_GRAD_FN.substitute(
                    op=op,
                    op_ctor=""
                    if info is not None and info.has_derivatives
                    else f'"{cpp.name(f.func)}"',
                    args_with_derivatives=args,
                    irange=f"{list_like_arg}.size()",
                ).split("\n")
            )
        setup.extend(emit_save_inputs())

        body.extend(
            emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives)
        )
        declare_grad_fn_template = (
            DECLARE_GRAD_FN if not is_inplace_foreach else DECLARE_VECTOR_OF_GRAD_FN
        )
        body.append(declare_grad_fn_template.substitute(op=op))
        body.append(SETUP_DERIVATIVE.substitute(setup=setup))
        return body
|
|
|
|
def emit_check_if_in_complex_autograd_allowlist() -> List[str]:
|
|
body: List[str] = []
|
|
if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
|
|
return body
|
|
for arg in differentiable_outputs:
|
|
name = arg.name
|
|
# TODO: should be `arg.type.is_tensor_like()`?
|
|
if arg.cpp_type == "at::Tensor" or arg.cpp_type in TENSOR_LIST_LIKE_CTYPES:
|
|
body.append(f'throw_error_for_complex_autograd({name}, "{base_name}");')
|
|
return body
|
|
|
|
def emit_check_no_requires_grad(
|
|
tensor_args: List[DifferentiableInput],
|
|
args_with_derivatives: List[DifferentiableInput],
|
|
) -> List[str]:
|
|
"""Checks that arguments without derivatives don't require grad"""
|
|
body: List[str] = []
|
|
for arg in tensor_args:
|
|
if arg in args_with_derivatives:
|
|
continue
|
|
arg_name = arg.name
|
|
if info and arg_name in info.non_differentiable_arg_names:
|
|
continue
|
|
if arg_name == "output":
|
|
# Double-backwards definitions sometimes take in 'input' and
|
|
# 'output', but only define the derivative for input.
|
|
continue
|
|
body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");')
|
|
return body
|
|
|
|
    def emit_original_self_definition() -> List[str]:
        """For in-place ops, emit the `original_self` storage and, when any
        forward derivative needs the pre-mutation value, clone `self` into it."""
        body: List[str] = []
        if inplace:
            if is_inplace_foreach:
                body.append(
                    "std::vector<c10::optional<at::Tensor>> original_selfs(self.size());"
                )
            else:
                body.append("c10::optional<at::Tensor> original_self;")

            # Collect the flags of forward derivatives that need original_self.
            all_forward_grad_cond = []
            for derivative in fw_derivatives:
                if derivative.required_original_self_value:
                    all_forward_grad_cond.append(
                        get_any_has_forward_grad_name(derivative.var_names)
                    )

            if all_forward_grad_cond:
                if not is_inplace_foreach:
                    body.append(f'if ({" || ".join(all_forward_grad_cond)}) {{')
                    body.append(" original_self = self.clone();")
                    body.append("}")
                else:
                    # Per-element flags: each condition is a vector<bool> indexed by i.
                    current_all_forward_grad_cond = [
                        f"{cond}[i]" for cond in all_forward_grad_cond
                    ]
                    body.append("for (const auto& i : c10::irange(self.size())) {")
                    body.append(
                        f" if ({' || '.join(current_all_forward_grad_cond)}) {{"
                    )
                    body.append(" original_selfs[i] = self[i].clone();")
                    body.append(" }")
                    body.append("}")

        return body
|
|
|
|
    def save_variables(
        saved_variables: Sequence[SavedAttribute],
        is_output: bool,
        # NOTE(review): the default lambda's parameter is named `name` but it
        # receives a SavedAttribute; harmless since it is unused.
        guard_for: Callable[[SavedAttribute], Optional[str]] = lambda name: None,
    ) -> Sequence[str]:
        """Emit `grad_fn->x_ = ...;` statements for each saved variable,
        optionally wrapped in a should-compute-output guard from `guard_for`."""
        # assign the saved variables to the generated grad_fn
        stmts: List[str] = []
        for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)):
            name = (
                arg.nctype.name.name
                if isinstance(arg.nctype.name, SpecialArgName)
                else arg.nctype.name
            )
            foreacharg: Optional[Argument] = None
            is_foreacharg_list_type: bool = False
            type = arg.nctype.type
            expr = arg.expr
            stmts_prepend = None
            if is_inplace_foreach and info is not None:
                # todo(crcrpar): See if we can add some check e.g. `assert foreacharg is not None`.
                # for now the example assert would fail.
                # Strip a trailing "_scalar_type..." suffix to recover the
                # reference argument name before mapping it to the foreach arg.
                name_to_query = name.split("_scalar_type")[0]
                if name_to_query in refargname2inplace_foreacharg:
                    foreacharg = refargname2inplace_foreacharg[name_to_query]
                    is_foreacharg_list_type = isinstance(foreacharg.type, ListType)
                if foreacharg is not None:
                    # Inside the per-grad_fn loop, list arguments are indexed by i.
                    name_in_expr = (
                        f"{foreacharg.name}{'[i]' if is_foreacharg_list_type else ''}"
                    )
                    src_name = name
                    if "_scalar_type" in src_name:
                        split_src_name = src_name.split("_scalar_type")
                        assert len(split_src_name) == 2
                        src_name = split_src_name[0]
                    expr = expr.replace(src_name, name_in_expr)
            if (
                type == BaseCType(tensorT)
                or type == OptionalCType(BaseCType(tensorT))
                or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
                or (is_output and type == BaseCType(scalarT))
            ):
                # note(crcrpar): Here `expr` is generated from scratch, `arg.expr` is ignored.
                var = name
                name += "_"
                if var == "self" and inplace:
                    # Saving `self` of an in-place op: save the pre-mutation
                    # clone (original_self), creating it lazily.
                    original_self_var = (
                        "original_self"
                        if not is_inplace_foreach
                        else "original_selfs[i]"
                    )
                    self_var = var if not is_inplace_foreach else var + "[i]"
                    stmts_prepend = f"if (!{original_self_var}.has_value()) {original_self_var} = {self_var}.clone()"
                    var = f"{original_self_var}.value()"
                    assert not is_output
                if inplace and is_output:
                    # In-place ops' "result" is just `self` after mutation.
                    assert name == "result_"
                    var = (
                        "self[i]"
                        if is_inplace_foreach or is_foreacharg_list_type
                        else "self"
                    )
                    is_inplace_view = f"{var}.is_view()"
                    expr = f"SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})"
                else:
                    expr = f"SavedVariable({var}, {str(is_output).lower()})"
                    if foreacharg is not None and "original_selfs" not in expr:
                        expr = expr.replace(src_name, name_in_expr)
            elif (
                type == BaseCType(tensorListT)
                or type == ListCType(OptionalCType(BaseCType(tensorT)))
                or type == BaseCType(iTensorListRefT)
                or type == VectorCType(BaseCType(tensorT))
            ):
                # See Note [nuanced return type of out-of-place foreach functions]
                if type == VectorCType(BaseCType(tensorT)):
                    assert is_foreach and is_output
                expr = f"make_saved_variable_list({name}, {str(is_foreach and is_output).lower()})"
                name += "_"
            elif type == BaseCType(intArrayRefT):
                expr = expr + ".vec()"
            elif type == BaseCType(symIntArrayRefT):
                expr = expr + ".vec()"
            elif type == BaseCType(stringT):
                # c10::string_view must be materialized into an owning std::string.
                expr = f"std::string({expr})"
            elif type == OptionalCType(BaseCType(stringT)):
                expr = f"{expr}.has_value() ? c10::optional<std::string>(std::string({expr}.value())) : c10::nullopt"
            elif type == ArrayRefCType(
                elem=BaseCType(type=BaseCppType(ns="at", name="Scalar"))
            ):
                expr = expr + ".vec()"

            guard = guard_for(arg)
            if guard is None:
                if stmts_prepend:
                    stmts.append(f"{stmts_prepend};")
                stmts.append(f"grad_fn->{name} = {expr};")
            else:
                stmts.append(f"if ({guard}) {{")
                if stmts_prepend:
                    stmts.append(f"  {stmts_prepend};")
                stmts.append(f"  grad_fn->{name} = {expr};")
                stmts.append("}")
        return stmts
|
|
|
|
# Generates a Dispatcher::redispatch() call into the dispatcher. We do this mainly for performance reasons:
|
|
# - Pre-compute the full DispatchKeySet. This saves the dispatcher from having to read from TLS.
|
|
# - redispatch() avoids a redundant call to RecordFunction, which was already called right before
|
|
# we entered this autograd kernel.
|
|
def emit_dispatch_call(
|
|
f: NativeFunction, input_base: str, unpacked_args: Sequence[str]
|
|
) -> str:
|
|
"""Dispatch call via function in a namespace or method on Tensor."""
|
|
dispatcher_sig = DispatcherSignature.from_schema(f.func)
|
|
dispatcher_exprs = dispatcher_sig.exprs()
|
|
|
|
# code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
|
|
# Ops also always have a function variant of the redispatch API.
|
|
# See Note [Plumbing Keys Through The Dispatcher] for details.
|
|
dispatch_key_set = "ks & c10::after_autograd_keyset"
|
|
call = CALL_REDISPATCH.substitute(
|
|
api_name=cpp.name(
|
|
f.func,
|
|
faithful_name_for_out_overloads=True,
|
|
symint_overload=f.func.has_symint(),
|
|
),
|
|
unpacked_args=[dispatch_key_set] + list(unpacked_args),
|
|
)
|
|
return call
|
|
|
|
def wrap_output(
|
|
f: NativeFunction, unpacked_bindings: List[Binding], var: str
|
|
) -> str:
|
|
call = ""
|
|
rhs_value: Optional[str] = None
|
|
if not any(r.type.is_tensor_like() for r in f.func.returns):
|
|
rhs_value = var
|
|
else:
|
|
rhs_value = f"std::move({var})"
|
|
assert rhs_value is not None
|
|
call += ASSIGN_RETURN_VALUE.substitute(
|
|
return_values=tie_return_values(f), rhs_value=rhs_value
|
|
)
|
|
return call
|
|
|
|
def check_tensorimpl_and_storage(
|
|
call: str, unpacked_bindings: List[Binding]
|
|
) -> str:
|
|
# See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
|
|
stmts_before_call: List[str] = []
|
|
stmts_after_call: List[str] = []
|
|
|
|
if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
|
|
return call
|
|
|
|
# Check properties of inputs (enforce (1))
|
|
for unpacked_binding in unpacked_bindings:
|
|
arg = unpacked_binding.name
|
|
noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
|
|
if noref_cpp_type == BaseCType(tensorListT) or noref_cpp_type == BaseCType(
|
|
iTensorListRefT
|
|
):
|
|
stmts_before_call += [
|
|
SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
|
|
SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg),
|
|
]
|
|
stmts_after_call += [
|
|
ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
|
|
ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg),
|
|
]
|
|
elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
|
|
stmts_before_call += [
|
|
SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
|
|
SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg),
|
|
]
|
|
stmts_after_call += [
|
|
ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(
|
|
tensorlist_name=arg
|
|
),
|
|
ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(
|
|
tensorlist_name=arg
|
|
),
|
|
]
|
|
elif noref_cpp_type == BaseCType(tensorT):
|
|
stmts_before_call += [
|
|
SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
|
|
SAVE_TENSOR_IMPL.substitute(tensor_name=arg),
|
|
]
|
|
stmts_after_call += [
|
|
ENFORCE_SAME_TENSOR_STORAGE.substitute(
|
|
tensor_name=arg, out_tensor_name=arg
|
|
),
|
|
ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg),
|
|
]
|
|
|
|
assert (stmts_before_call and stmts_after_call) or (
|
|
not stmts_before_call and not stmts_after_call
|
|
)
|
|
|
|
# Check properties of outputs (enforce (2), (3))
|
|
if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out):
|
|
base_name = f.func.name.name.base # TODO: should be str(f.func.name.name)?
|
|
aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
|
|
if aliased_arg_name is not None:
|
|
aliased_arg_name = unpacked_name(aliased_arg_name)
|
|
for i, (ret, ret_name) in enumerate(
|
|
zip(f.func.returns, cpp.return_names(f))
|
|
):
|
|
noref_cpp_type = cpp.return_type(ret, symint=True).remove_const_ref()
|
|
if noref_cpp_type == BaseCType(tensorT):
|
|
if aliased_arg_name is not None:
|
|
assert (
|
|
i == 0
|
|
), "Expect non-CompositeImplicitAutograd view function {base} to return single output"
|
|
stmts_after_call += [
|
|
ENFORCE_SAME_TENSOR_STORAGE.substitute(
|
|
tensor_name=aliased_arg_name, out_tensor_name=ret_name
|
|
)
|
|
]
|
|
else:
|
|
if (
|
|
type_wrapper_name(f)
|
|
not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT
|
|
):
|
|
stmts_after_call += [
|
|
ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute(
|
|
tensor_name=ret_name, fn_name=type_wrapper_name(f)
|
|
)
|
|
]
|
|
|
|
if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
|
|
stmts_after_call += [
|
|
ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute(
|
|
tensor_name=ret_name, fn_name=type_wrapper_name(f)
|
|
)
|
|
]
|
|
|
|
# Currently we don't have any functions that return the following types, but
|
|
# we should update the checks once we do
|
|
elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
|
|
raise AssertionError(
|
|
f"Please add use_count checks for {noref_cpp_type}"
|
|
)
|
|
elif noref_cpp_type == BaseCType(tensorListT):
|
|
raise AssertionError(
|
|
f"Please add use_count checks for {noref_cpp_type}"
|
|
)
|
|
|
|
if stmts_before_call and stmts_after_call:
|
|
call = (
|
|
RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call)
|
|
+ call
|
|
+ RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call)
|
|
)
|
|
return call
|
|
|
|
    def emit_call(
        f: NativeFunction, unpacked_bindings: List[Binding], try_jit_decomposition: bool
    ) -> str:
        """Emit the guarded redispatch call (plus result wrapping and the
        TensorImpl/Storage sanity checks)."""
        # We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
        # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
        # the baseType operations still dispatch to non-Variable type, even if the arguments passed
        # in are now Variables.
        # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details.
        unpacked_args = [b.name for b in unpacked_bindings]
        base_type_call = emit_dispatch_call(f, "self_", unpacked_args)

        if get_view_info(f) is not None or modifies_arguments(f):
            guard = "at::AutoDispatchBelowAutograd guard;"
        else:
            guard = "at::AutoDispatchBelowADInplaceOrView guard;"

        any_has_forward_grad = (
            get_any_has_fw_grad_cond(derivative=None)
            if requires_derivative
            else "false"
        )
        return_types = ", ".join(
            [cpp.return_type(a, symint=True).cpp_type() for a in f.func.returns]
        )
        if len(f.func.returns) > 1:
            return_types = f"std::tuple<{return_types}>"

        arg_names = [
            a.name
            for a in cpp.arguments(
                f.func.arguments,
                faithful=True,
                symint=True,
                method=False,
                cpp_no_default_args=set(),
            )
        ]

        if not modifies_arguments(f) and not returns_void:
            if try_jit_decomposition:
                # Functional op with no hand-written forward derivative: call
                # through the JVP decomposition helper when tangents are present.
                call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP.substitute(
                    base_type_call=base_type_call,
                    tmp_var=TMP_VAR,
                    guard=guard,
                    any_has_forward_grad=any_has_forward_grad,
                    op_name=cpp.name(f.func),
                    op_overload=f.func.name.overload_name,
                    return_types=return_types,
                    arg_names=arg_names,
                )
            else:
                call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute(
                    base_type_call=base_type_call,
                    tmp_var=TMP_VAR,
                    guard=guard,
                )

            call += wrap_output(f, unpacked_bindings, TMP_VAR)
        else:
            # In-place / void ops have nothing to wrap and never decompose.
            assert not try_jit_decomposition
            call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
                base_type_call=base_type_call, guard=guard
            )
        call = check_tensorimpl_and_storage(call, unpacked_bindings)
        return call
|
|
|
|
    def emit_history() -> str:
        """Emit the statement connecting outputs to grad_fn (set_history /
        rebase_history)."""
        # NOTE(review): local `fn` shadows the enclosing
        # NativeFunctionWithDifferentiabilityInfo parameter of the same name.
        # In-place ops that are not views rebase history; everything else sets it.
        fn = "rebase" if modifies_arguments(f) and view_info is None else "set"
        output_names = [r.name for r in differentiable_outputs]
        # TODO: flatten allocates a std::vector, which could be expensive
        outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(
            outs=output_names if not is_inplace_foreach else "self"
        )
        if not is_inplace_foreach:
            return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)
        else:
            return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute(
                preamble=(
                    f"auto differentiable_outputs = {outs};\n"
                    f"TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());"
                ),
                statements=f"{fn}_history(differentiable_outputs[i], grad_fns[i]);",
            )
|
|
|
|
def emit_save_outputs() -> str:
|
|
if is_out_fn:
|
|
# out functions don't currently support differentiation
|
|
return ""
|
|
if info is not None and info.has_derivatives:
|
|
stmts = save_variables(info.all_saved_outputs, True)
|
|
if len(stmts) == 0:
|
|
return ""
|
|
if not is_inplace_foreach:
|
|
return CONDITIONAL.substitute(cond="grad_fn", statements=stmts)
|
|
else:
|
|
return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute(
|
|
preamble="", statements=stmts
|
|
)
|
|
return ""
|
|
|
|
    def emit_any_requires_grad() -> List[str]:
        """Emit the `_any_requires_grad` computation over the args that have
        derivative formulas."""
        extra_condition = ""
        if info and info.output_differentiability_conditions:
            assert len(info.output_differentiability_conditions) == 1
            extra_condition = f"_any_requires_grad &= ({info.output_differentiability_conditions[0]});"
        names_of_args_with_derivatives = [arg.name for arg in args_with_derivatives]
        if is_inplace_foreach and info is not None:
            # Formulas use reference-op names; translate them back to the
            # actual foreach argument names.
            for i, arg in enumerate(names_of_args_with_derivatives):
                for f_arg, r_arg in inplace_foreacharg2refarg.items():
                    if arg == r_arg.name:
                        names_of_args_with_derivatives[i] = f_arg.name
        return [
            SETUP_ANY_REQUIRES_GRAD.substitute(
                args_with_derivatives=names_of_args_with_derivatives,
                extra_differentiability_conditions=extra_condition,
            )
        ]
|
|
|
|
def get_any_has_forward_grad_name(var_names: Tuple[str, ...]) -> str:
|
|
if len(var_names) == 1:
|
|
return f"_any_has_forward_grad_{var_names[0]}"
|
|
else:
|
|
return f'_any_has_forward_grad_{"_".join(var_names)}'
|
|
|
|
    def emit_any_has_forward_grad() -> List[str]:
        """Emit the per-derivative `_any_has_forward_grad_*` flags (plain bools
        for regular ops, vector<bool> indexed per list element for foreach)."""
        content: List[str] = []
        if not is_foreach:
            for derivative in fw_derivatives:
                requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
                if info and info.output_differentiability_conditions:
                    assert len(info.output_differentiability_conditions) == 1
                    requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && {requires_fw_grad}"
                content.append(
                    f"[[maybe_unused]] auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};"
                )
        else:
            for derivative in fw_derivatives:
                bool_vector_name = get_any_has_forward_grad_name(derivative.var_names)
                # One check per required input; in-place foreach formulas use
                # reference-op names, so map back to the foreach argument and
                # index list-typed arguments by i.
                cur_derivative_conditions = [
                    FW_DERIVATIVE_CHECK_TEMPLATE.substitute(
                        req_inp=(
                            inp.name
                            if not inplace
                            else refargname2inplace_foreacharg[inp.name].name
                        )
                        + (
                            "[i]"
                            if is_tensor_list_type(
                                inp.type
                                if not inplace
                                else refargname2inplace_foreacharg[inp.name].type
                            )
                            else ""
                        ),
                    )
                    for inp in differentiable_inputs
                    if derivative.required_inputs_fw_grad is not None
                    and inp.name in derivative.required_inputs_fw_grad
                ]
                content.append(f"std::vector<bool> {bool_vector_name}(self.size());")
                content.append("for (const auto& i : c10::irange(self.size())) {")
                content.append(
                    f"  {bool_vector_name}[i] = {' || '.join(cur_derivative_conditions)};"
                )
                content.append("}")
        return content
|
|
|
|
def emit_check_inplace() -> List[str]:
|
|
if not inplace:
|
|
return []
|
|
return [
|
|
f"check_inplace({arg.name}, _any_requires_grad);"
|
|
for arg in differentiable_outputs
|
|
]
|
|
|
|
def emit_fw_derivatives() -> List[str]:
    """Emit the C++ code that computes and sets forward-mode gradients.

    For each forward derivative this produces:
      * the "unpacked arguments" prelude (tangent/primal locals for every
        required differentiable input),
      * the guarded computation of the new forward grad (single tensor,
        tuple of tensors, vector of tensors, or the foreach per-index
        variant), and
      * a deferred `_set_fw_grad`-style setter, appended at the very end
        (see the issue link below for why setters must run last).
    """
    content: List[str] = []
    fw_grad_setters: List[str] = []
    for derivative in fw_derivatives:
        res = derivative.var_names
        if f.func.name.name.inplace:
            assert (
                len(res) == 1
            ), "Expected number of outputs to be 1 if function is inplace"
            # TODO update this when inplace namings are unified
            res = ("self",)

        assert derivative.required_inputs_fw_grad is not None

        # Build the per-input locals (tangents and primals) used by the
        # forward-derivative formula.
        unpacked_arguments = ""
        for inp in differentiable_inputs:
            inp_name = inp.name
            is_input_tensorlist = is_foreach and is_tensor_list_type(
                inp.type
                if not inplace
                else refargname2inplace_foreacharg[inp.name].type
            )
            input_suffix = "[i]" if is_input_tensorlist else ""
            if is_inplace_foreach:
                # Reference-op argument names map onto the foreach op's names.
                if inp.name in refargname2inplace_foreacharg:
                    inp_name = refargname2inplace_foreacharg[inp.name].name
            # For the in-place `self` argument a dense zeros tensor is used as
            # the fallback tangent; otherwise an efficient zero tensor.
            zeros_fn = (
                "zeros"
                if inplace and inp.name == "self"
                else "_efficientzerotensor"
            )
            if inp.name in derivative.required_inputs_fw_grad:
                unpacked_arguments += (
                    FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
                        inp_name=inp.name,
                        inp=inp_name + input_suffix,
                        zeros_fn=zeros_fn,
                    )
                )
            if inp.name in (derivative.required_inputs_primal or []):
                unpacked_arguments += (
                    FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
                        inp_name=inp.name,
                        inp=inp_name + input_suffix,
                    )
                )
        if derivative.required_original_self_value:
            # NOTE(review): `zeros_fn` here carries its value from the last
            # iteration of the loop above — presumably intentional; confirm.
            input_suffix = "s[i]" if is_inplace_foreach else ""
            unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
                inp_name="original_self",
                inp="original_self" + input_suffix,
                zeros_fn=zeros_fn,
            )
            unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
                inp_name="original_self",
                inp="original_self" + input_suffix,
            )
        elif inplace and derivative.is_reusing_outplace_formula:
            # The gradient wasn't already cloned, do it if grad mode is enabled
            unpacked_arguments += (
                "self_t = GradMode::is_enabled() ? self_t.clone() : self_t;"
            )

        if inplace:
            is_inplace_str = "true"
        else:
            is_inplace_str = "false"

        requires_fw_grad = get_any_has_forward_grad_name(derivative.var_names)

        # Pick the C++ optional type for the new forward grad and the
        # matching setter template, based on the output's type shape.
        if all(
            (isinstance(var_type, BaseType) and var_type.is_tensor_like())
            for var_type in derivative.var_types
        ):
            # Is there a way to get from BaseType to BaseCType
            if len(derivative.var_types) == 1:
                # Single Tensor output.
                opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
                if not is_foreach:
                    fw_grad_setters.append(
                        FW_DERIVATIVE_SETTER_TENSOR.substitute(
                            out_arg=res[0], is_inplace=is_inplace_str
                        )
                    )
                else:
                    assert res[0] == ("result" if not inplace else "self")
                    fw_grad_setters.append(
                        FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
                            out_arg=res[0], is_inplace=is_inplace_str
                        )
                    )
                # Only compute the fw grad when the output is actually defined.
                requires_fw_grad += f" && ({derivative.var_names[0]}.defined())"
            else:
                # Multiple Tensor outputs -> one optional tuple of Tensors.
                tuple_type = TupleCType(
                    [BaseCType(tensorT)] * len(derivative.var_types)
                )
                opt_res_grad_type = OptionalCType(tuple_type).cpp_type()
                for idx, single_res in enumerate(res):
                    fw_grad_setters.append(
                        FW_DERIVATIVE_SETTER_MULTI_OUTPUT.substitute(
                            idx=idx, all_res="_".join(res), out_arg=single_res
                        )
                    )
        elif (
            isinstance(derivative.var_types[0], ListType)
            and derivative.var_types[0].is_tensor_like()
        ):
            assert (
                len(derivative.var_types) == 1
            ), "Expected number of outputs to be 1 if function returns ListType"
            if not is_foreach:
                opt_res_grad_type = OptionalCType(
                    VectorCType(BaseCType(tensorT))
                ).cpp_type()
                fw_grad_setters.append(
                    FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(
                        out_arg=res[0], is_inplace=is_inplace_str
                    )
                )
            else:
                # TODO(crcrpar): Should this (= the foreach specific logic) be refactored somehow?
                # Only out-place foreach functions that have entries in `tools/autograd/derivatives.yaml`
                # can reach here.
                opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
                fw_grad_setters.append(
                    FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
                        out_arg=res[0], is_inplace=is_inplace_str
                    )
                )
        else:
            raise RuntimeError("Unsupported output type for forward derivative")

        if not is_foreach:
            fw_grad_opt_definition = f"{opt_res_grad_type} {'_'.join(res)}_new_fw_grad_opt = c10::nullopt;"
            # View ops create fw_grad that already is a view of the base's fw_grad so just use that
            content.append(
                FW_DERIVATIVE_TEMPLATE.substitute(
                    fw_grad_opt_definition=fw_grad_opt_definition,
                    requires_fw_grad=requires_fw_grad,
                    formula=derivative.formula,
                    out_arg="_".join(res),
                    unpacked_arguments=unpacked_arguments,
                )
            )
        else:
            # note(crcrpar): Assuming `self` is TensorList.
            fw_grad_opt_definition = (
                f"std::vector<{opt_res_grad_type}> {'_'.join(res)}_new_fw_grad_opts"
                "(self.size(), c10::nullopt);"
            )
            foreach_forward_grad_formula = derivative.formula
            _foreach_arg: Union[Argument, DifferentiableInput]
            if inplace:
                for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
                    # note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
                    if not (
                        is_tensor_type(_foreach_arg.type)
                        or is_tensor_list_type(_foreach_arg.type)
                    ):
                        pattern = _foreach_arg.name
                        if isinstance(_foreach_arg.type, ListType):
                            pattern += "[i]"
                        foreach_forward_grad_formula = (
                            foreach_forward_grad_formula.replace(
                                _ref_arg.name, pattern
                            )
                        )
            else:
                # Index the result inside the generated per-element loop,
                # unless the formula already does.
                if (
                    "result" in foreach_forward_grad_formula
                    and "result[i]" not in foreach_forward_grad_formula
                ):
                    foreach_forward_grad_formula = (
                        foreach_forward_grad_formula.replace("result", "result[i]")
                    )

            content.append(
                FW_DERIVATIVE_FOREACH_TEMPLATE.substitute(
                    fw_grad_opt_definition=fw_grad_opt_definition,
                    vector_of_optional_tensor=f"{'_'.join(res)}_new_fw_grad_opts",
                    any_has_forward_grad_for_current_index=" || ".join(
                        get_any_has_forward_grad_name(derivative.var_names) + "[i]"
                        for derivative in fw_derivatives
                    ),
                    formula=foreach_forward_grad_formula,
                    unpacked_arguments=unpacked_arguments,
                )
            )

    # Set all the grads at the end to avoid: https://github.com/pytorch/pytorch/issues/67367
    content.append("\n".join(fw_grad_setters))
    return content
|
|
|
|
def get_any_has_fw_grad_cond(derivative: Optional[ForwardDerivative]) -> str:
    """Build the C++ condition that is true when any relevant input has a
    forward gradient defined.

    Produces a condition string (e.g. "isFwGradDefined(grad_output) || isFwGradDefined(output)").
    With ``derivative=None`` every differentiable input is checked; with a
    concrete derivative only its ``required_inputs_fw_grad`` are checked.
    """
    if derivative is None:
        # (1) If a derivative is NOT provided, cond will check fw_grad of ALL differentiable inputs
        # - Used in the out_fn case when we want to forbid fw derivatives
        # - Used in the case where the fw_derivative is not defined, but we want
        #   to check if there is a decomposition registered for jvp
        to_check: List[str] = []
        for inp in list(
            mapMaybe(
                gen_differentiable_input,
                f.func.arguments.non_out + list(f.func.arguments.out),  # type: ignore[operator]
            )
        ):
            if is_tensor_type(inp.type):
                to_check.append(
                    FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
                )
            elif is_tensor_list_type(inp.type):
                to_check.append(
                    FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE.substitute(
                        req_inp=inp.name
                    )
                )
            else:
                # Non-tensor differentiable inputs are unexpected here.
                raise RuntimeError(
                    f'Unsupported input type for "{name}" when forbidding forward AD usage.'
                )
        return f'({" || ".join(to_check)})'
    else:
        # (2) If derivative is provided, use that information to determine which inputs
        #     to check fw_grad for
        assert derivative.required_inputs_fw_grad is not None

        if len(derivative.required_inputs_fw_grad) == 0:
            # Handle functions like stack
            # For these, we don't unpack anything and always call the user function
            if not (
                len(differentiable_inputs) == 1
                and is_tensor_list_type(differentiable_inputs[0].type)
            ):
                raise RuntimeError(
                    f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
                    "forward AD formula does not use any input tangent) even though a forward gradient "
                    "formula has been defined for it. This case should only happen for function that "
                    "take a single TensorList as input. All other cases are not supported right now."
                )
            any_has_fw_grad = "true"
        else:
            # OR together the per-input "is fw grad defined" checks, using
            # the TensorList variant where the input is a list of tensors.
            any_has_fw_grad = " || ".join(
                [
                    (
                        FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE
                        if is_tensor_list_type(inp.type)
                        else FW_DERIVATIVE_CHECK_TEMPLATE
                    ).substitute(req_inp=inp.name)
                    for inp in differentiable_inputs
                    if inp.name in derivative.required_inputs_fw_grad
                ]
            )
            any_has_fw_grad = f"({any_has_fw_grad})"

        return any_has_fw_grad
|
|
|
|
def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str:
    """Emit a guard that raises at runtime if forward AD is attempted on
    an op that has no forward derivative (or is an out= overload).

    Returns the empty string when no input could carry a forward grad.
    """
    cond = get_any_has_fw_grad_cond(derivative=None)
    if cond == "":
        return ""
    if is_out_fn:
        msg = "because it is an out= function"
    else:
        msg = (
            "because it has not been implemented yet.\\nPlease file an issue "
            "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
            "so that we can prioritize its implementation."
        )
    return FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=msg)
|
|
|
|
# Assemble the full generated function body, in dispatch order:
# unpacking, autograd setup, the actual call, history/forward-AD handling,
# output saving, and the final return.
body: List[str] = []
unpack_args_stats, unpacked_bindings = unpack_args(f)

body.extend(unpack_args_stats)
if requires_derivative:
    # Backward-AD bookkeeping must precede the op call.
    body.extend(emit_any_requires_grad())
    body.extend(emit_any_has_forward_grad())
    body.extend(emit_check_inplace())
    body.extend(emit_original_self_definition())
    body.extend(setup_derivative(differentiable_inputs))
body.append(declare_returned_variables(f))

body.append(emit_call(f, unpacked_bindings, try_jit_decomposition))
if requires_derivative:
    # set_flags has to appear after version_counter, because rebase_history
    # requires that the counter is incremented before it is called
    body.append(emit_history())
    body.extend(emit_check_if_in_complex_autograd_allowlist())

if is_out_fn:
    # out= overloads never support forward AD.
    body.append(emit_forbid_fw_derivatives(is_out_fn=True))
else:
    if requires_derivative and not try_jit_decomposition:
        if len(fw_derivatives) > 0:
            body.extend(emit_fw_derivatives())
        else:
            body.append(emit_forbid_fw_derivatives())

if requires_derivative:
    # Save only after the forward AD has been set up
    body.append(emit_save_outputs())

if str(f.func.name.name) in RESET_GRAD_ACCUMULATOR:
    # `inplace` implies that there is exactly one output named `self`,
    # so we can keep the generated code easy. If you need to
    # `reset_grad_accumulator` in an operator that's not `inplace`, you can
    # remove this assert but the code generation will get more elaborate
    assert inplace
    body.append("reset_grad_accumulator(self);")
if not returns_void:
    body.append(f"return {get_return_value(f)};")
return body
|