283 lines
8.7 KiB
C++
283 lines
8.7 KiB
C++
#include <torch/csrc/Device.h>
|
|
|
|
#include <torch/csrc/Exceptions.h>
|
|
#include <torch/csrc/utils/object_ptr.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
#include <torch/csrc/utils/python_arg_parser.h>
|
|
#include <torch/csrc/utils/python_numbers.h>
|
|
#include <torch/csrc/utils/python_strings.h>
|
|
|
|
#include <ATen/Device.h>
|
|
#include <c10/util/Exception.h>
|
|
|
|
#include <structmember.h>
|
|
#include <cstring>
|
|
#include <limits>
|
|
#include <sstream>
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
|
PyObject* THPUpperModuleOfDevice = nullptr;
|
|
|
|
// Allocates a new Python `torch.device` object wrapping a copy of `device`.
//
// Returns a new reference owned by the caller. Throws python_error when the
// type's tp_alloc slot fails to produce an object.
PyObject* THPDevice_New(const at::Device& device) {
  PyTypeObject* const device_type = &THPDeviceType;
  THPObjectPtr instance{device_type->tp_alloc(device_type, 0)};
  if (!instance) {
    throw python_error();
  }
  // Fill in the payload while the smart pointer still owns the object, then
  // hand ownership to the caller.
  reinterpret_cast<THPDevice*>(instance.get())->device = device;
  return instance.release();
}
|
|
|
|
// tp_repr slot: renders `device(type='<type>')`, or
// `device(type='<type>', index=<n>)` when the device carries an index.
PyObject* THPDevice_repr(THPDevice* self) {
  std::ostringstream out;
  out << "device(type=\'" << self->device.type() << "\'";
  if (self->device.has_index()) {
    // Widen the index to a 16-bit integer before streaming it: an 8-bit
    // integer would be printed as an ASCII character rather than a number.
    // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout
    out << ", index=" << static_cast<uint16_t>(self->device.index());
  }
  out << ")";
  const std::string text = out.str();
  return THPUtils_packString(text.c_str());
}
|
|
|
|
// tp_str slot: returns the text produced by streaming the wrapped
// at::Device into an output stream.
PyObject* THPDevice_str(THPDevice* self) {
  std::ostringstream out;
  out << self->device;
  const std::string text = out.str();
  return THPUtils_packString(text.c_str());
}
|
|
|
|
// tp_new slot for `torch.device`.
//
// Accepted call signatures (see the parser definition below):
//   device(Device device)
//   device(string type, int? index)
//
// For the second form the type string must not already contain an index,
// and an explicitly passed index must be non-negative and representable as
// c10::DeviceIndex. Returns a new reference to the constructed device
// object; failures are raised as exceptions inside the HANDLE_TH_ERRORS
// region.
PyObject* THPDevice_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static torch::PythonArgParser parser(
      {"device(Device device)",
       "device(c10::string_view type, int64_t? index=-1)"});
  torch::ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPUpperModuleOfDevice, "torch");
  }
  if (r.idx == 0) {
    auto device = r.device(0);
    return THPDevice_New(device);
  } else if (r.idx == 1) {
    auto as_device = r.device(0); // this works, because device can take strings
    if (as_device.has_index()) {
      auto device_type = r.string(0);
      throw std::runtime_error(
          "type (string) must not include an index because index "
          "was passed explicitly: " +
          device_type);
    }
    int64_t device_index = -1;
    if (!r.isNone(1)) {
      device_index = r.toInt64(1);
      // -1 is allowed in ATen/C++, to mean the default device, but not in
      // Python.
      TORCH_CHECK(device_index >= 0, "Device index must not be negative");
      // The index is narrowed to c10::DeviceIndex below. Reject values that
      // do not fit, so an oversized index cannot silently wrap around and
      // name a different device.
      TORCH_CHECK(
          device_index <= std::numeric_limits<c10::DeviceIndex>::max(),
          "Device index out of range: ",
          device_index);
    }
    at::Device device(
        as_device.type(), static_cast<c10::DeviceIndex>(device_index));
    return THPDevice_New(device);
  }
  // Reached only if the parser reports an overload index other than 0 or 1.
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Getter for the `type` property: returns the device type formatted as a
// Python string (the text produced by streaming `device.type()`).
// `noargs` is unused.
PyObject* THPDevice_type(THPDevice* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  std::ostringstream oss;
  oss << self->device.type();
  return THPUtils_packString(oss.str().c_str());
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Getter for the `index` property: returns the device index as a Python
// integer, or None when the device has no index. `noargs` is unused.
PyObject* THPDevice_index(THPDevice* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  if (!self->device.has_index()) {
    Py_RETURN_NONE;
  }
  return THPUtils_packInt64(self->device.index());
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// tp_hash slot: hashes the wrapped at::Device with std::hash and reduces
// the result modulo the largest Py_ssize_t value, so the returned hash is
// never negative. Returns -1 through END_HANDLE_TH_ERRORS_RET on error.
static Py_ssize_t THPDevice_hash(THPDevice* self) {
  HANDLE_TH_ERRORS
  const auto raw_hash = std::hash<at::Device>{}(self->device);
  const auto reduced = raw_hash % std::numeric_limits<Py_ssize_t>::max();
  return static_cast<Py_ssize_t>(reduced);
  END_HANDLE_TH_ERRORS_RET(-1)
}
|
|
|
|
// tp_richcompare slot.
//
// Returns NotImplemented unless both operands are device objects. Only
// equality and inequality are supported; ordering comparisons and unknown
// operators raise torch::TypeError.
PyObject* THPDevice_rc(PyObject* a, PyObject* b, int op) {
  HANDLE_TH_ERRORS
  if (!(THPDevice_Check(a) && THPDevice_Check(b))) {
    // Hand the decision back to the interpreter with a new reference to
    // the NotImplemented singleton.
    Py_INCREF(Py_NotImplemented);
    return Py_NotImplemented;
  }
  THPDevice* lhs = reinterpret_cast<THPDevice*>(a);
  THPDevice* rhs = reinterpret_cast<THPDevice*>(b);

  if (op == Py_EQ || op == Py_NE) {
    const bool same = lhs->device == rhs->device;
    // Py_EQ answers `same`; Py_NE answers its negation.
    if (same == (op == Py_EQ)) {
      Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
  }
  if (op == Py_LT || op == Py_LE || op == Py_GT || op == Py_GE) {
    throw torch::TypeError("comparison not implemented");
  }
  throw torch::TypeError("unexpected comparison op");
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Implements `__reduce__`: returns the 2-tuple
// `(torch.device, (type_string[, index]))`. `noargs` is unused.
//
// Throws python_error if a tuple cannot be allocated or built.
PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto* dev_obj = reinterpret_cast<THPDevice*>(_self);
  THPObjectPtr result{PyTuple_New(2)};
  if (!result) {
    throw python_error();
  }

  // Slot 0: the callable, looked up as `torch.device`.
  py::object torch_mod = py::module::import("torch");
  py::object device_ctor = torch_mod.attr("device");
  PyTuple_SET_ITEM(result.get(), 0, device_ctor.release().ptr());

  // Slot 1: the argument tuple; the index is included only when present.
  std::ostringstream type_stream;
  type_stream << dev_obj->device.type();
  const std::string type_name = type_stream.str();
  THPObjectPtr ctor_args{
      dev_obj->device.has_index()
          ? Py_BuildValue(
                "(si)",
                type_name.c_str(),
                static_cast<int>(dev_obj->device.index()))
          : Py_BuildValue("(s)", type_name.c_str())};
  if (!ctor_args) {
    throw python_error();
  }
  PyTuple_SET_ITEM(result.get(), 1, ctor_args.release());

  return result.release();
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Implements `__enter__`: builds `torch.utils._device.DeviceContext(self)`
// and pushes it onto the Python torch-function mode stack, then returns
// `self` as a new reference. `noargs` is unused.
PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  py::object device_utils = py::module::import("torch.utils._device");
  py::object context = device_utils.attr("DeviceContext")(py::handle(self));
  // The reference is released from pybind11 and handed to SafePyObject.
  // NOTE(review): presumably SafePyObject takes ownership of it; confirm.
  PyObject* context_ptr = context.release().ptr();
  at::impl::PythonTorchFunctionTLS::push_onto_stack(
      std::make_shared<c10::SafePyObject>(context_ptr, getPyInterpreter()));
  // So that with torch.device('cuda') as dev: works
  Py_INCREF(self);
  return self;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Implements `__exit__`: pops the top entry of the Python torch-function
// mode stack and returns None. Neither `self` nor the argument tuple
// (`unused`) is inspected.
PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
  HANDLE_TH_ERRORS
  // The value returned by pop_stack() is intentionally discarded.
  at::impl::PythonTorchFunctionTLS::pop_stack();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Call implementation for device objects: forwards to
// `torch.utils._device.device_decorator(self, *args, **kwargs)` and returns
// its result as a new reference.
PyObject* THPDevice_call(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_TH_ERRORS
  py::object device_utils = py::module::import("torch.utils._device");
  py::object make_decorator = device_utils.attr("device_decorator");
  py::object outcome =
      make_decorator(py::handle(self), *py::handle(args), **py::handle(kwargs));
  return outcome.release().ptr();
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Function-pointer type used to cast the property accessors stored in
// THPDevice_properties below.
typedef PyObject* (*getter)(PyObject*, void*);
|
|
|
|
// NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
|
|
static struct PyGetSetDef THPDevice_properties[] = {
|
|
{"type", (getter)THPDevice_type, nullptr, nullptr, nullptr},
|
|
{"index", (getter)THPDevice_index, nullptr, nullptr, nullptr},
|
|
{nullptr}};
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
|
|
static PyMethodDef THPDevice_methods[] = {
|
|
{"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr},
|
|
{"__enter__", THPDevice_enter, METH_NOARGS, nullptr},
|
|
{"__exit__", THPDevice_exit, METH_VARARGS, nullptr},
|
|
{nullptr} /* Sentinel */
|
|
};
|
|
|
|
// Static type object backing `torch.device`. The initializers are
// positional, so their order must follow the PyTypeObject slot layout named
// in the trailing comments. Slots left null/zero are not customised here.
PyTypeObject THPDeviceType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch.device", /* tp_name */
    sizeof(THPDevice), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    reinterpret_cast<reprfunc>(THPDevice_repr), /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    reinterpret_cast<hashfunc>(THPDevice_hash), /* tp_hash */
    // TODO: tp_call is deliberately left empty. Installing THPDevice_call
    // here would make callable() report true for device objects, which is
    // unexpected; it can be enabled later if that turns out to be wanted.
    nullptr, /* tp_call */
    reinterpret_cast<reprfunc>(THPDevice_str), /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    reinterpret_cast<richcmpfunc>(THPDevice_rc), /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPDevice_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPDevice_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPDevice_pynew, /* tp_new */
};
|
|
|
|
// Finalises THPDeviceType and registers it on `module` under the attribute
// name "device". Also records `module` in THPUpperModuleOfDevice.
//
// Throws python_error if the type cannot be readied or added to the module.
void THPDevice_init(PyObject* module) {
  const int ready_status = PyType_Ready(&THPDeviceType);
  if (ready_status < 0) {
    throw python_error();
  }
  // Take a reference to the type object before handing it to the module.
  Py_INCREF(&THPDeviceType);
  THPUpperModuleOfDevice = module;
  auto* type_object = reinterpret_cast<PyObject*>(&THPDeviceType);
  if (PyModule_AddObject(module, "device", type_object) != 0) {
    throw python_error();
  }
}
|