196 lines
7.4 KiB
C++
196 lines
7.4 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/impl/InlineDeviceGuard.h>
|
|
|
|
namespace c10 {
|
|
|
|
/// RAII guard that sets a certain default device in its constructor, and
|
|
/// changes it back to the device that was originally active upon destruction.
|
|
///
|
|
/// The device is always reset to the one that was active at the time of
|
|
/// construction of the guard. Even if you `set_device` after construction, the
|
|
/// destructor will still reset the device to the one that was active at
|
|
/// construction time.
|
|
///
|
|
/// This device guard does NOT have an uninitialized state; it is guaranteed
|
|
/// to reset a device on exit. If you are in a situation where you *might*
|
|
/// want to setup a guard (i.e., are looking for the moral equivalent
|
|
/// of optional<DeviceGuard>), see OptionalDeviceGuard.
|
|
class DeviceGuard {
|
|
public:
|
|
/// No default constructor; see Note [Omitted default constructor from RAII]
|
|
explicit DeviceGuard() = delete;
|
|
|
|
/// Set the current device to the passed Device.
|
|
explicit DeviceGuard(Device device) : guard_(device) {}
|
|
|
|
/// This constructor is for testing only.
|
|
explicit DeviceGuard(
|
|
Device device,
|
|
const impl::DeviceGuardImplInterface* impl)
|
|
: guard_(device, impl) {}
|
|
|
|
/// Copy is disallowed
|
|
DeviceGuard(const DeviceGuard&) = delete;
|
|
DeviceGuard& operator=(const DeviceGuard&) = delete;
|
|
|
|
/// Move is disallowed, as DeviceGuard does not have an uninitialized state,
|
|
/// which is required for moves on types with nontrivial destructors.
|
|
DeviceGuard(DeviceGuard&& other) = delete;
|
|
DeviceGuard& operator=(DeviceGuard&& other) = delete;
|
|
|
|
/// Sets the device to the given one. The specified device must be consistent
|
|
/// with the device type originally specified during guard construction.
|
|
///
|
|
/// TODO: The consistency check here is inconsistent with StreamGuard's
|
|
/// behavior with set_stream, where a stream on a different device than
|
|
/// the original one isn't an error; we just reset the stream and then
|
|
/// switch devices.
|
|
void reset_device(at::Device device) {
|
|
guard_.reset_device(device);
|
|
}
|
|
|
|
/// This method is for testing only.
|
|
void reset_device(
|
|
at::Device device,
|
|
const impl::DeviceGuardImplInterface* impl) {
|
|
guard_.reset_device(device, impl);
|
|
}
|
|
|
|
/// Sets the device index to the given one. The device type is inferred
|
|
/// from the original device type the guard was constructed with.
|
|
void set_index(DeviceIndex index) {
|
|
guard_.set_index(index);
|
|
}
|
|
|
|
/// Returns the device that was set at the time the guard was constructed.
|
|
Device original_device() const {
|
|
return guard_.original_device();
|
|
}
|
|
|
|
/// Returns the most recent device that was set using this device guard,
|
|
/// either from construction, or via set_device.
|
|
Device current_device() const {
|
|
return guard_.current_device();
|
|
}
|
|
|
|
private:
|
|
impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_;
|
|
};
|
|
|
|
/**
|
|
* A OptionalDeviceGuard is an RAII class that sets a device to some value on
|
|
* initialization, and resets the device to its original value on destruction.
|
|
* Morally, a OptionalDeviceGuard is equivalent to optional<DeviceGuard>, but
|
|
* with extra constructors and methods as appropriate.
|
|
*
|
|
* Besides its obvious use (optionally applying a DeviceGuard),
|
|
* OptionalDeviceGuard is often also used for the following idiom:
|
|
*
|
|
* OptionalDeviceGuard g;
|
|
* for (const auto& t : tensors) {
|
|
* g.set_device(t.device());
|
|
* do_something_with(t);
|
|
* }
|
|
*
|
|
* This usage is marginally more efficient than constructing a DeviceGuard every
|
|
* iteration of the for loop, as it avoids an unnecessary device reset.
|
|
*
|
|
* Unlike DeviceGuard, a OptionalDeviceGuard may be uninitialized. This occurs
|
|
* when you use the nullary constructor, or pass a nullopt to the constructor.
|
|
* Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the
|
|
* original device was and they do not reset on destruction. This is why
|
|
* original_device() and current_device() return optional<Device> rather than
|
|
* Device (as they do in DeviceGuard), and also is why we didn't just
|
|
* provide OptionalDeviceGuard by default and hide DeviceGuard from users.
|
|
*
|
|
* The semantics of an OptionalDeviceGuard are exactly explained by thinking
|
|
* of it as an optional<DeviceGuard>. In particular, an initialized
|
|
* OptionalDeviceGuard doesn't restore device to its value at construction; it
|
|
* restores device to its value *at initialization*. So if you have the
|
|
* program:
|
|
*
|
|
* setDevice(1);
|
|
* OptionalDeviceGuard g;
|
|
* setDevice(2);
|
|
* g.reset_device(Device(DeviceType::CUDA, 3)); // initializes!
|
|
*
|
|
* On destruction, g will reset device to 2, rather than 1.
|
|
*
|
|
* An uninitialized OptionalDeviceGuard is distinct from a (initialized)
|
|
* DeviceGuard whose original_device_ and current_device_ match, since the
|
|
* DeviceGuard will still reset the device to original_device_.
|
|
*/
|
|
class OptionalDeviceGuard {
|
|
public:
|
|
/// Create an uninitialized guard. Set the guard later using reset_device.
|
|
explicit OptionalDeviceGuard() = default;
|
|
|
|
/// Initialize the guard, setting the current device to the passed Device.
|
|
explicit OptionalDeviceGuard(Device device) : guard_(device) {}
|
|
|
|
/// Initialize the guard if a Device is passed; otherwise leave the
|
|
/// guard uninitialized.
|
|
explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {}
|
|
|
|
/// Constructor for testing only.
|
|
explicit OptionalDeviceGuard(
|
|
Device device,
|
|
const impl::DeviceGuardImplInterface* impl)
|
|
: guard_(device, impl) {}
|
|
|
|
/// Copy is disallowed
|
|
OptionalDeviceGuard(const OptionalDeviceGuard&) = delete;
|
|
OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete;
|
|
|
|
/// Move is disallowed
|
|
/// See Note [Explicit initialization of optional fields]
|
|
/// and // Note [Move construction for RAII guards is tricky]
|
|
/// for rationale.
|
|
OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete;
|
|
OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete;
|
|
|
|
/// Sets the device to the given one. The specified device must be consistent
|
|
/// with the device type originally specified during guard construction.
|
|
void reset_device(at::Device device) {
|
|
guard_.reset_device(device);
|
|
}
|
|
|
|
/// For testing only
|
|
void reset_device(
|
|
at::Device device,
|
|
const impl::DeviceGuardImplInterface* impl) {
|
|
guard_.reset_device(device, impl);
|
|
}
|
|
|
|
/// Returns the device that was set at the time the guard was constructed.
|
|
optional<Device> original_device() const {
|
|
return guard_.original_device();
|
|
}
|
|
|
|
/// Returns the most recent device that was set using this device guard,
|
|
/// either from construction, or via reset_device.
|
|
optional<Device> current_device() const {
|
|
return guard_.current_device();
|
|
}
|
|
|
|
private:
|
|
impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_{};
|
|
};
|
|
|
|
// Note [Whither the DeviceGuard boilerplate]
|
|
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
// Design note: in principle, we could avoid these wrappers using:
|
|
//
|
|
// using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>;
|
|
// using OptionalDeviceGuard =
|
|
// impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
|
|
//
|
|
// But the error messages are worse, and our users can't just look at the
|
|
// header file to find out what's going on. Furthermore, for specializations
|
|
// like CUDAStreamGuard, it can be profitable to replace some interfaces with
|
|
// refined types (e.g., return CUDAStream instead of Stream). So, we eat
|
|
// the boilerplate and write out the API explicitly.
|
|
|
|
} // namespace c10
|