351 lines
9.8 KiB
C++
351 lines
9.8 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/DeviceType.h>
|
|
#include <c10/core/DispatchKey.h>
|
|
#include <c10/core/DispatchKeySet.h>
|
|
#include <c10/util/Exception.h>
|
|
|
|
#include <stdexcept>
|
|
|
|
namespace c10 {
|
|
|
|
/**
|
|
* This legacy enum class defines the set of backends supported by old school,
|
|
* code generated Type-based ATen. A "backend" in this sense roughly
|
|
* corresponds to the cartesian product of (device type, layout), but restricted
|
|
* only to combinations which we actually have kernels for. Backend does NOT
|
|
* include dtype.
|
|
*
|
|
* The reason we are sunsetting this enum class is because it doesn't allow for
|
|
* open registration; e.g., if you want to add SparseXLA, you'd have to
|
|
* edit this enum; you wouldn't be able to do it out of tree. DispatchKey is
|
|
* the replacement for Backend which supports open registration.
|
|
*
|
|
* NB: The concept of 'Backend' here disagrees with the notion of backend
|
|
* exposed to users in torch.backends. Backend here is something like "CPU"
|
|
* or "SparseCUDA"; backend in torch.backends is something like "MKL" or
|
|
* "CUDNN".
|
|
*/
|
|
enum class Backend {
  // Dense per-device backends.
  CPU,
  CUDA,
  HIP,
  VE,
  FPGA,
  IPU,
  XPU,
  // Sparse (COO layout) backends.
  SparseCPU,
  SparseCUDA,
  // Sparse compressed-row (CSR layout) backends.
  SparseCsrCPU,
  SparseCsrCUDA,
  // More sparse (COO) backends; appended here rather than grouped above.
  // NOTE(review): enumerator order fixes the numeric values — confirm nothing
  // serializes them before tidying the ordering.
  SparseHIP,
  SparseVE,
  SparseXPU,
  SparsePrivateUse1,
  ORT,
  XLA,
  Vulkan,
  Metal,
  // Meta tensors carry no data; used for shape/dtype propagation.
  Meta,
  // Quantized backends.
  QuantizedCPU,
  QuantizedCUDA,
  QuantizedXPU,
  QuantizedPrivateUse1,
  // No backend (e.g. an undefined tensor).
  Undefined,
  MkldnnCPU,
  MPS,
  HPU,
  Lazy,
  MTIA,
  // Out-of-tree backend slot (see DispatchKey::PrivateUse1).
  PrivateUse1,
  // Sentinel counting the enumerators above; keep last.
  NumOptions
};
|
|
|
|
// Maps a DispatchKey to the legacy Backend enum.  Each functionality key and
// its corresponding Autograd key (where one is handled) map to the same
// Backend.  Raises for keys with no Backend equivalent.
static inline Backend dispatchKeyToBackend(DispatchKey t) {
  switch (t) {
    case DispatchKey::CPU:
    case DispatchKey::AutogradCPU:
      return Backend::CPU;
    case DispatchKey::CUDA:
    case DispatchKey::AutogradCUDA:
      return Backend::CUDA;
    case DispatchKey::HIP:
      return Backend::HIP;
    case DispatchKey::VE:
      return Backend::VE;
    case DispatchKey::FPGA:
      return Backend::FPGA;
    case DispatchKey::ORT:
      return Backend::ORT;
    case DispatchKey::XLA:
    case DispatchKey::AutogradXLA:
      return Backend::XLA;
    case DispatchKey::Lazy:
    case DispatchKey::AutogradLazy:
      return Backend::Lazy;
    case DispatchKey::MPS:
    case DispatchKey::AutogradMPS:
      return Backend::MPS;
    case DispatchKey::IPU:
    case DispatchKey::AutogradIPU:
      return Backend::IPU;
    case DispatchKey::XPU:
    case DispatchKey::AutogradXPU:
      return Backend::XPU;
    case DispatchKey::HPU:
    case DispatchKey::AutogradHPU:
      return Backend::HPU;
    case DispatchKey::MTIA:
    case DispatchKey::AutogradMTIA:
      return Backend::MTIA;
    case DispatchKey::PrivateUse1:
    case DispatchKey::AutogradPrivateUse1:
      return Backend::PrivateUse1;
    case DispatchKey::Vulkan:
      return Backend::Vulkan;
    case DispatchKey::Metal:
      return Backend::Metal;
    case DispatchKey::Meta:
      return Backend::Meta;
    // Sparse (COO) keys.
    case DispatchKey::SparseCPU:
      return Backend::SparseCPU;
    case DispatchKey::SparseCUDA:
      return Backend::SparseCUDA;
    case DispatchKey::SparseHIP:
      return Backend::SparseHIP;
    case DispatchKey::SparseVE:
      return Backend::SparseVE;
    case DispatchKey::SparseXPU:
      return Backend::SparseXPU;
    case DispatchKey::SparsePrivateUse1:
      return Backend::SparsePrivateUse1;
    // Sparse CSR keys.
    case DispatchKey::SparseCsrCPU:
      return Backend::SparseCsrCPU;
    case DispatchKey::SparseCsrCUDA:
      return Backend::SparseCsrCUDA;
    case DispatchKey::MkldnnCPU:
      return Backend::MkldnnCPU;
    // Quantized keys.
    case DispatchKey::QuantizedCPU:
      return Backend::QuantizedCPU;
    case DispatchKey::QuantizedCUDA:
      return Backend::QuantizedCUDA;
    case DispatchKey::QuantizedXPU:
      return Backend::QuantizedXPU;
    case DispatchKey::QuantizedPrivateUse1:
      return Backend::QuantizedPrivateUse1;
    case DispatchKey::Undefined:
      return Backend::Undefined;
    default:
      TORCH_CHECK(false, "Unrecognized tensor type ID: ", t);
  }
}
|
|
|
|
// Maps a legacy Backend to its (non-Autograd) DispatchKey.  Inverse of
// dispatchKeyToBackend for the keys that enum handles.
//
// Fix: Backend::QuantizedXPU previously fell through to the default and threw
// "Unknown backend", even though dispatchKeyToBackend, backendToDeviceType and
// toString all handle QuantizedXPU — the mapping now round-trips.
//
// Throws std::runtime_error for Backend::NumOptions or out-of-range values.
static inline DispatchKey backendToDispatchKey(Backend b) {
  switch (b) {
    case Backend::CPU:
      return DispatchKey::CPU;
    case Backend::CUDA:
      return DispatchKey::CUDA;
    case Backend::HIP:
      return DispatchKey::HIP;
    case Backend::VE:
      return DispatchKey::VE;
    case Backend::FPGA:
      return DispatchKey::FPGA;
    case Backend::ORT:
      return DispatchKey::ORT;
    case Backend::XLA:
      return DispatchKey::XLA;
    case Backend::Lazy:
      return DispatchKey::Lazy;
    case Backend::IPU:
      return DispatchKey::IPU;
    case Backend::XPU:
      return DispatchKey::XPU;
    case Backend::SparseXPU:
      return DispatchKey::SparseXPU;
    case Backend::SparseCPU:
      return DispatchKey::SparseCPU;
    case Backend::SparseCUDA:
      return DispatchKey::SparseCUDA;
    case Backend::SparseHIP:
      return DispatchKey::SparseHIP;
    case Backend::SparseVE:
      return DispatchKey::SparseVE;
    case Backend::SparsePrivateUse1:
      return DispatchKey::SparsePrivateUse1;
    case Backend::SparseCsrCPU:
      return DispatchKey::SparseCsrCPU;
    case Backend::SparseCsrCUDA:
      return DispatchKey::SparseCsrCUDA;
    case Backend::MkldnnCPU:
      return DispatchKey::MkldnnCPU;
    case Backend::Vulkan:
      return DispatchKey::Vulkan;
    case Backend::Metal:
      return DispatchKey::Metal;
    case Backend::Meta:
      return DispatchKey::Meta;
    case Backend::QuantizedCPU:
      return DispatchKey::QuantizedCPU;
    case Backend::QuantizedCUDA:
      return DispatchKey::QuantizedCUDA;
    // Previously missing: made QuantizedXPU unrepresentable as a DispatchKey.
    case Backend::QuantizedXPU:
      return DispatchKey::QuantizedXPU;
    case Backend::QuantizedPrivateUse1:
      return DispatchKey::QuantizedPrivateUse1;
    case Backend::Undefined:
      return DispatchKey::Undefined;
    case Backend::MPS:
      return DispatchKey::MPS;
    case Backend::HPU:
      return DispatchKey::HPU;
    case Backend::MTIA:
      return DispatchKey::MTIA;
    case Backend::PrivateUse1:
      return DispatchKey::PrivateUse1;
    default:
      // Kept as std::runtime_error (not TORCH_CHECK) so existing callers that
      // catch std::runtime_error keep working.
      throw std::runtime_error("Unknown backend");
  }
}
|
|
|
|
// Returns the DeviceType a legacy Backend's tensors live on.  All layout
// variants of a device (sparse, sparse-CSR, quantized, mkldnn) collapse onto
// the same DeviceType.  Raises for Undefined and for unknown values.
static inline DeviceType backendToDeviceType(Backend b) {
  switch (b) {
    // CPU family.
    case Backend::CPU:
    case Backend::MkldnnCPU:
    case Backend::SparseCPU:
    case Backend::SparseCsrCPU:
    case Backend::QuantizedCPU:
      return DeviceType::CPU;
    // CUDA family.
    case Backend::CUDA:
    case Backend::SparseCUDA:
    case Backend::SparseCsrCUDA:
    case Backend::QuantizedCUDA:
      return DeviceType::CUDA;
    // HIP family.
    case Backend::HIP:
    case Backend::SparseHIP:
      return DeviceType::HIP;
    // VE family.
    case Backend::VE:
    case Backend::SparseVE:
      return DeviceType::VE;
    // XPU family.
    case Backend::XPU:
    case Backend::SparseXPU:
    case Backend::QuantizedXPU:
      return DeviceType::XPU;
    // Out-of-tree (PrivateUse1) family.
    case Backend::PrivateUse1:
    case Backend::SparsePrivateUse1:
    case Backend::QuantizedPrivateUse1:
      return DeviceType::PrivateUse1;
    // Backends with a single layout.
    case Backend::FPGA:
      return DeviceType::FPGA;
    case Backend::ORT:
      return DeviceType::ORT;
    case Backend::XLA:
      return DeviceType::XLA;
    case Backend::Lazy:
      return DeviceType::Lazy;
    case Backend::IPU:
      return DeviceType::IPU;
    case Backend::Vulkan:
      return DeviceType::Vulkan;
    case Backend::Metal:
      return DeviceType::Metal;
    case Backend::Meta:
      return DeviceType::Meta;
    case Backend::MPS:
      return DeviceType::MPS;
    case Backend::HPU:
      return DeviceType::HPU;
    case Backend::MTIA:
      return DeviceType::MTIA;
    case Backend::Undefined:
      TORCH_CHECK(false, "Undefined backend is not a valid device type");
    default:
      TORCH_CHECK(false, "Unknown backend");
  }
}
|
|
|
|
// TODO: This probably shouldn't actually be static inline
|
|
static inline const char* toString(Backend b) {
|
|
switch (b) {
|
|
case Backend::CPU:
|
|
return "CPU";
|
|
case Backend::CUDA:
|
|
return "CUDA";
|
|
case Backend::HIP:
|
|
return "HIP";
|
|
case Backend::VE:
|
|
return "VE";
|
|
case Backend::FPGA:
|
|
return "FPGA";
|
|
case Backend::XPU:
|
|
return "XPU";
|
|
case Backend::IPU:
|
|
return "IPU";
|
|
case Backend::ORT:
|
|
return "ORT";
|
|
case Backend::XLA:
|
|
return "XLA";
|
|
case Backend::Lazy:
|
|
return "Lazy";
|
|
case Backend::MPS:
|
|
return "MPS";
|
|
case Backend::SparseCPU:
|
|
return "SparseCPU";
|
|
case Backend::SparseCUDA:
|
|
return "SparseCUDA";
|
|
case Backend::SparseHIP:
|
|
return "SparseHIP";
|
|
case Backend::SparseVE:
|
|
return "SparseVE";
|
|
case Backend::SparseXPU:
|
|
return "SparseXPU";
|
|
case Backend::SparsePrivateUse1:
|
|
return "SparsePrivateUse1";
|
|
case Backend::SparseCsrCPU:
|
|
return "SparseCsrCPU";
|
|
case Backend::SparseCsrCUDA:
|
|
return "SparseCsrCUDA";
|
|
case Backend::MkldnnCPU:
|
|
return "MkldnnCPU";
|
|
case Backend::Vulkan:
|
|
return "Vulkan";
|
|
case Backend::Metal:
|
|
return "Metal";
|
|
case Backend::Meta:
|
|
return "Meta";
|
|
case Backend::QuantizedCPU:
|
|
return "QuantizedCPU";
|
|
case Backend::QuantizedCUDA:
|
|
return "QuantizedCUDA";
|
|
case Backend::QuantizedXPU:
|
|
return "QuantizedXPU";
|
|
case Backend::QuantizedPrivateUse1:
|
|
return "QuantizedPrivateUse1";
|
|
case Backend::HPU:
|
|
return "HPU";
|
|
case Backend::MTIA:
|
|
return "MTIA";
|
|
case Backend::PrivateUse1:
|
|
return "PrivateUseOne";
|
|
default:
|
|
return "UNKNOWN_BACKEND";
|
|
}
|
|
}
|
|
|
|
// True iff `b` is one of the sparse (COO layout) backends.  Sparse CSR
// backends are deliberately excluded; isSparseCsr covers those.
static inline bool isSparse(Backend b) {
  return b == Backend::SparseCPU || b == Backend::SparseCUDA ||
      b == Backend::SparseHIP || b == Backend::SparseVE ||
      b == Backend::SparseXPU || b == Backend::SparsePrivateUse1;
}
|
|
|
|
// True iff `b` is one of the sparse compressed-row (CSR layout) backends.
static inline bool isSparseCsr(Backend b) {
  return b == Backend::SparseCsrCPU || b == Backend::SparseCsrCUDA;
}
|
|
|
|
} // namespace c10
|