243 lines
9.5 KiB
C++
243 lines
9.5 KiB
C++
#pragma once
|
|
#include <ATen/ThreadLocalState.h>
|
|
#include <ATen/core/Tensor.h>
|
|
#include <c10/util/ThreadLocal.h>
|
|
#include <torch/csrc/autograd/input_buffer.h>
|
|
#include <torch/csrc/autograd/utils/warnings.h>
|
|
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace autograd {
|
|
|
|
// A list of graph edges; used in this header as the type of the `outputs`
// argument of GraphTask::init_to_execute().
using edge_list = std::vector<Edge>;
|
|
// Forward declaration only: this header refers to ReadyQueue solely through
// std::shared_ptr<ReadyQueue>, so the full definition is not needed here.
struct ReadyQueue;
|
|
|
|
// Sentinel meaning "no device". The GraphTask constructor initializes owner_
// to this value.
static constexpr int NO_DEVICE{-2};
|
|
// Device value standing for the CPU.
// NOTE(review): presumably the worker_device / owner_ value used for CPU
// worker threads -- confirm against the engine implementation; it is not
// referenced elsewhere in this header.
static constexpr int CPU_DEVICE{-1};
|
|
|
|
// Monotonically increasing counter from which every GraphTask draws its id_
// (the GraphTask constructor does a fetch_add(1) on it).
//
// This is a header, so the counter must NOT live in an anonymous namespace:
// that gives each translation unit including the header its own private
// counter, which lets ids repeat across translation units and makes the
// inline GraphTask constructor refer to a different entity in each of them.
// A C++17 inline variable keeps the same name and type but has a single
// definition shared by all translation units.
inline std::atomic<uint64_t> graph_task_id{0};
|
|
|
|
// GraphTask holds metadata needed for a single execution of backward()
|
|
struct GraphTask : std::enable_shared_from_this<GraphTask> {
|
|
std::atomic<uint64_t> outstanding_tasks_{0};
|
|
// Indicates if an error occurred while executing any task. When this is
|
|
// true, it signals all threads to stop executing.
|
|
std::atomic_bool has_error_{false};
|
|
std::atomic_bool future_completed_{false};
|
|
// It is safe to read keep_graph_ without synchronization
|
|
bool keep_graph_;
|
|
|
|
// To protect reads/writes to not_ready_, dependencies_, captured_vars_,
|
|
// has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
|
|
std::mutex mutex_;
|
|
std::unordered_map<Node*, InputBuffer> not_ready_;
|
|
std::unordered_map<Node*, int> dependencies_;
|
|
|
|
// Records the nodes that are in the graph
|
|
std::unordered_set<Node*> nodes_in_graph_;
|
|
c10::SmallVector<Node*, 4> graph_roots_;
|
|
// Note [Exec info]
|
|
// Exec info is created for each GraphTask, which allows filtering paths on
|
|
// the graph that are not needed. It has a bit complicated semantics. If it's
|
|
// empty, it means the task is run in a "default" mode, which means that all
|
|
// next_edges we encounter should get executed. If it's not empty, only
|
|
// functions that have an entry and this entry has needed == True should be
|
|
// executed. exec_info is only empty when the graph is executed via
|
|
// .backward() and the inputs parameter is not passed. Otherwise, when
|
|
// executed through .grad(), or when inputs arg is specified for .backward(),
|
|
// exec_info will be non-empty.
|
|
//
|
|
struct ExecInfo {
|
|
struct Capture {
|
|
Capture(const Capture&) = delete;
|
|
Capture(Capture&&) = default;
|
|
|
|
Capture(int input_idx, int output_idx)
|
|
: input_idx_(input_idx), output_idx_(output_idx) {}
|
|
int input_idx_; // within Node inputs
|
|
int output_idx_; // within the output vector of a GraphTask
|
|
|
|
// This hook will be executed after a grad is captured. The captured
|
|
// grad will be replaced by the return value of the hook.
|
|
struct GradCaptureHook {
|
|
virtual ~GradCaptureHook() = default;
|
|
virtual at::Tensor operator()(const at::Tensor& grad) = 0;
|
|
};
|
|
// NOTE [Deprecated capture hooks]
|
|
//
|
|
// The current status of capture hooks is that we continue to support
|
|
// the single usage of it by distributed in the dist_engine. If anyone
|
|
// else needs to use it for other purposes, they should file an issue.
|
|
//
|
|
// Capture hooks were originally created because there did not exist
|
|
// any way to register pre/post hooks to grad_fn in a way such that it
|
|
// would still be executed even if that is the grad_fn of a Tensor
|
|
// passed as input= of .grad. As far as I know, only dist_engine uses
|
|
// this hook.
|
|
//
|
|
// However, there are other alternatives today like tensor hooks that can
|
|
// replace the usage that originally motivated its creation. Also,
|
|
// Captures hooks are an outlier in terms of the types of hook that
|
|
// autograd offers in how it is registered and behaves, e.g. it is a hook
|
|
// registered not to the graph, but to a particular graph_task! This makes
|
|
// it a burden to maintain.
|
|
//
|
|
// It would be very nice to clean up/do a migration from pre/post
|
|
// hooks used in distributed to use tensor hooks, but for now we just
|
|
// mark this method as deprecated to prevent additional usage.
|
|
//
|
|
// If you still think you really need to capture hooks, please file an
|
|
// issue (and tag autograd).
|
|
const std::vector<std::unique_ptr<GradCaptureHook>>&
|
|
DO_NOT_USE_DEPRECATED_get_capture_hooks() const {
|
|
return hooks_;
|
|
}
|
|
// See NOTE [deprecated capture hooks]
|
|
void DO_NOT_USE_DEPRECATED_register_capture_hook(
|
|
std::unique_ptr<GradCaptureHook> hook) {
|
|
hooks_.push_back(std::move(hook));
|
|
}
|
|
|
|
private:
|
|
// The hooks will be called one by one in the order as they were added.
|
|
// The input grad of a hook will be the output of its preceding hook. The
|
|
// first hook will take the captured grad as the input. The output of the
|
|
// last hook will replace the captured grad.
|
|
std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
|
|
};
|
|
|
|
bool should_execute() const {
|
|
return needed_ || captures_;
|
|
}
|
|
|
|
bool needed_ = false;
|
|
std::unique_ptr<std::vector<Capture>> captures_;
|
|
};
|
|
// exec_info_ is safe to read without synchronization
|
|
std::unordered_map<Node*, ExecInfo> exec_info_;
|
|
// Captures variables are grads captured that we return to the user. After
|
|
// execution of the GraphTask is completed, the captured_vars_ are moved
|
|
// out of the GraphTask and are no longer valid.
|
|
std::vector<Variable> captured_vars_;
|
|
|
|
// Note: this field is not ready to be used until the proper
|
|
// `thread_locals_.set_grad_mode()` call in the constructor.
|
|
at::ThreadLocalState thread_locals_ = at::ThreadLocalState();
|
|
|
|
std::unordered_set<c10::Stream> leaf_streams;
|
|
|
|
// Per-device current streams of the execute() that called this GraphTask.
|
|
// These will be synced with leaf_streams in exec_post_processing.
|
|
std::vector<c10::optional<c10::Stream>> caller_current_streams_;
|
|
|
|
// Collects caller_current_streams_
|
|
void stash_current_streams();
|
|
|
|
void init_to_execute(
|
|
Node& graph_root,
|
|
const edge_list& outputs,
|
|
bool accumulate_grad,
|
|
uint64_t min_topo_nr);
|
|
|
|
// The value of worker_device in the thread that created this task.
|
|
// See Note [Reentrant backwards]
|
|
// Safe to read owner_ and reentrant_depth_ without synchronization
|
|
int owner_;
|
|
// The number of parent graph tasks for this graph task
|
|
const int reentrant_depth_;
|
|
|
|
bool can_checkpoint() const {
|
|
return exec_info_.empty();
|
|
}
|
|
|
|
// check if the GraphTask is completed or not
|
|
bool completed();
|
|
// mark the graph task as completed and trigger post processing
|
|
void mark_as_completed_and_run_post_processing();
|
|
|
|
// Set an appropriate exception on this graph_task which was encountered while
|
|
// running the provided function.
|
|
void set_exception(std::exception_ptr eptr, const std::shared_ptr<Node>& fn);
|
|
|
|
// Set an appropriate exception on this graph_task which was encountered while
|
|
// running the provided function. But doesn't signal completion on
|
|
// 'future_result_' right away. The user needs to explicitly mark
|
|
// 'future_result_' completed with an appropriate exception.
|
|
void set_exception_without_signal(const std::shared_ptr<Node>& fn);
|
|
|
|
// Whether or not to stop execution for this GraphTask when an error is
|
|
// encountered. When set to true, this would cause Engine::execute() to throw
|
|
// an exception as soon as the autograd engine receives an exception.
|
|
bool exit_on_error_;
|
|
|
|
// CPU threads are dedicated to processing CPU work for the backward they
|
|
// invoked. So any given graph task maintains its own cpu_ready_queue_ where
|
|
// you should send work for it to be done. We memoize the cpu_ready_queue_ per
|
|
// GraphTask so that we know which ready queue we should push to if we are on
|
|
// device thread (i.e. GPU) and but next NodeTask should be run on CPU.
|
|
std::shared_ptr<ReadyQueue> cpu_ready_queue_;
|
|
|
|
// Future representing the completion of the graph task. Notified when all
|
|
// tasks are done.
|
|
c10::intrusive_ptr<at::ivalue::Future> future_result_;
|
|
|
|
// Final callbacks installed during execution of this GraphTask
|
|
std::vector<std::function<void()>> final_callbacks_;
|
|
// To protect reads and writes to final_callbacks_. Intentionally no reusing
|
|
// mutex_ as the two are protecting different data structures.
|
|
std::mutex final_callbacks_lock_;
|
|
|
|
utils::DelayWarningHandler warning_handler_;
|
|
|
|
uint64_t id_;
|
|
|
|
GraphTask(
|
|
bool keep_graph,
|
|
bool grad_mode,
|
|
int reentrant_depth,
|
|
std::shared_ptr<ReadyQueue> cpu_ready_queue,
|
|
c10::SmallVector<Node*, 4> graph_roots,
|
|
bool exit_on_error = false)
|
|
: keep_graph_(keep_graph),
|
|
graph_roots_(std::move(graph_roots)),
|
|
owner_(NO_DEVICE),
|
|
reentrant_depth_(reentrant_depth),
|
|
exit_on_error_(exit_on_error),
|
|
cpu_ready_queue_(std::move(cpu_ready_queue)),
|
|
future_result_(c10::make_intrusive<at::ivalue::Future>(
|
|
c10::ListType::create(c10::TensorType::get()))),
|
|
id_(graph_task_id.fetch_add(1, std::memory_order_relaxed)) {
|
|
thread_locals_.set_grad_mode(grad_mode);
|
|
}
|
|
|
|
private:
|
|
// run GraphTask post processing
|
|
void exec_post_processing();
|
|
};
|
|
|
|
// The guard that sets and restores current_graph_task.
|
|
class GraphTaskGuard {
|
|
public:
|
|
explicit GraphTaskGuard(std::shared_ptr<GraphTask> graph_task);
|
|
~GraphTaskGuard();
|
|
|
|
void restore_current_graph_task();
|
|
|
|
private:
|
|
std::shared_ptr<GraphTask> last_graph_task_;
|
|
};
|
|
|
|
TORCH_API const std::unordered_map<Node*, GraphTask::ExecInfo>*
|
|
get_current_graph_task_exec_info();
|
|
TORCH_API const std::unordered_set<Node*>*
|
|
get_current_graph_task_nodes_in_graph();
|
|
TORCH_API bool get_current_graph_task_keep_graph();
|
|
TORCH_API std::vector<Node*> get_current_graph_task_execution_order();
|
|
TORCH_API int get_current_graph_task_id();
|
|
void add_node_to_current_graph_task_exec_info(Node* fn);
|
|
|
|
} // namespace autograd
|
|
} // namespace torch
|