392 lines
13 KiB
C++
392 lines
13 KiB
C++
#include <torch/csrc/TypeInfo.h>
|
|
|
|
#include <torch/csrc/Exceptions.h>
|
|
#include <torch/csrc/utils/object_ptr.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
#include <torch/csrc/utils/python_arg_parser.h>
|
|
#include <torch/csrc/utils/python_numbers.h>
|
|
#include <torch/csrc/utils/python_strings.h>
|
|
#include <torch/csrc/utils/tensor_dtypes.h>
|
|
|
|
#include <c10/util/Exception.h>
|
|
|
|
#include <structmember.h>
|
|
#include <cstring>
|
|
#include <limits>
|
|
#include <sstream>
|
|
|
|
PyObject* THPFInfo_New(const at::ScalarType& type) {
|
|
auto finfo = (PyTypeObject*)&THPFInfoType;
|
|
auto self = THPObjectPtr{finfo->tp_alloc(finfo, 0)};
|
|
if (!self)
|
|
throw python_error();
|
|
auto self_ = reinterpret_cast<THPDTypeInfo*>(self.get());
|
|
self_->type = c10::toRealValueType(type);
|
|
return self.release();
|
|
}
|
|
|
|
PyObject* THPIInfo_New(const at::ScalarType& type) {
|
|
auto iinfo = (PyTypeObject*)&THPIInfoType;
|
|
auto self = THPObjectPtr{iinfo->tp_alloc(iinfo, 0)};
|
|
if (!self)
|
|
throw python_error();
|
|
auto self_ = reinterpret_cast<THPDTypeInfo*>(self.get());
|
|
self_->type = type;
|
|
return self.release();
|
|
}
|
|
|
|
PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
|
|
HANDLE_TH_ERRORS
|
|
static torch::PythonArgParser parser({
|
|
"finfo(ScalarType type)",
|
|
"finfo()",
|
|
});
|
|
|
|
torch::ParsedArgs<1> parsed_args;
|
|
auto r = parser.parse(args, kwargs, parsed_args);
|
|
TORCH_CHECK(r.idx < 2, "Not a type");
|
|
at::ScalarType scalar_type = at::ScalarType::Undefined;
|
|
if (r.idx == 1) {
|
|
scalar_type = torch::tensors::get_default_scalar_type();
|
|
// The default tensor type can only be set to a floating point type/
|
|
AT_ASSERT(at::isFloatingType(scalar_type));
|
|
} else {
|
|
scalar_type = r.scalartype(0);
|
|
if (!at::isFloatingType(scalar_type) && !at::isComplexType(scalar_type)) {
|
|
return PyErr_Format(
|
|
PyExc_TypeError,
|
|
"torch.finfo() requires a floating point input type. Use torch.iinfo to handle '%s'",
|
|
type->tp_name);
|
|
}
|
|
}
|
|
return THPFInfo_New(scalar_type);
|
|
END_HANDLE_TH_ERRORS
|
|
}
|
|
|
|
PyObject* THPIInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
|
|
HANDLE_TH_ERRORS
|
|
static torch::PythonArgParser parser({
|
|
"iinfo(ScalarType type)",
|
|
});
|
|
torch::ParsedArgs<1> parsed_args;
|
|
auto r = parser.parse(args, kwargs, parsed_args);
|
|
TORCH_CHECK(r.idx == 0, "Not a type");
|
|
|
|
at::ScalarType scalar_type = r.scalartype(0);
|
|
if (scalar_type == at::ScalarType::Bool) {
|
|
return PyErr_Format(
|
|
PyExc_TypeError, "torch.bool is not supported by torch.iinfo");
|
|
}
|
|
if (!at::isIntegralType(scalar_type, /*includeBool=*/false) &&
|
|
!at::isQIntType(scalar_type)) {
|
|
return PyErr_Format(
|
|
PyExc_TypeError,
|
|
"torch.iinfo() requires an integer input type. Use torch.finfo to handle '%s'",
|
|
type->tp_name);
|
|
}
|
|
return THPIInfo_New(scalar_type);
|
|
END_HANDLE_TH_ERRORS
|
|
}
|
|
|
|
PyObject* THPDTypeInfo_compare(THPDTypeInfo* a, THPDTypeInfo* b, int op) {
|
|
switch (op) {
|
|
case Py_EQ:
|
|
if (a->type == b->type) {
|
|
Py_RETURN_TRUE;
|
|
} else {
|
|
Py_RETURN_FALSE;
|
|
}
|
|
case Py_NE:
|
|
if (a->type != b->type) {
|
|
Py_RETURN_TRUE;
|
|
} else {
|
|
Py_RETURN_FALSE;
|
|
}
|
|
}
|
|
return Py_INCREF(Py_NotImplemented), Py_NotImplemented;
|
|
}
|
|
|
|
static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
|
|
uint64_t bits = elementSize(self->type) * 8;
|
|
return THPUtils_packUInt64(bits);
|
|
}
|
|
|
|
static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
|
|
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
|
|
at::kHalf,
|
|
at::ScalarType::BFloat16,
|
|
at::ScalarType::Float8_e4m3fn,
|
|
at::ScalarType::Float8_e5m2,
|
|
self->type,
|
|
"epsilon",
|
|
[] {
|
|
return PyFloat_FromDouble(
|
|
std::numeric_limits<
|
|
at::scalar_value_type<scalar_t>::type>::epsilon());
|
|
});
|
|
}
|
|
|
|
static PyObject* THPFInfo_max(THPFInfo* self, void*) {
|
|
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
|
|
at::kHalf,
|
|
at::ScalarType::BFloat16,
|
|
at::ScalarType::Float8_e4m3fn,
|
|
at::ScalarType::Float8_e5m2,
|
|
self->type,
|
|
"max",
|
|
[] {
|
|
return PyFloat_FromDouble(
|
|
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
|
|
});
|
|
}
|
|
|
|
static PyObject* THPFInfo_min(THPFInfo* self, void*) {
|
|
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
|
|
at::kHalf,
|
|
at::ScalarType::BFloat16,
|
|
at::ScalarType::Float8_e4m3fn,
|
|
at::ScalarType::Float8_e5m2,
|
|
self->type,
|
|
"lowest",
|
|
[] {
|
|
return PyFloat_FromDouble(
|
|
std::numeric_limits<
|
|
at::scalar_value_type<scalar_t>::type>::lowest());
|
|
});
|
|
}
|
|
|
|
static PyObject* THPIInfo_max(THPIInfo* self, void*) {
|
|
if (at::isIntegralType(self->type, /*includeBool=*/false)) {
|
|
return AT_DISPATCH_INTEGRAL_TYPES(self->type, "max", [] {
|
|
return THPUtils_packInt64(std::numeric_limits<scalar_t>::max());
|
|
});
|
|
}
|
|
// Quantized Type
|
|
return AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self->type, "max", [] {
|
|
return THPUtils_packInt64(std::numeric_limits<underlying_t>::max());
|
|
});
|
|
}
|
|
|
|
static PyObject* THPIInfo_min(THPIInfo* self, void*) {
|
|
if (at::isIntegralType(self->type, /*includeBool=*/false)) {
|
|
return AT_DISPATCH_INTEGRAL_TYPES(self->type, "min", [] {
|
|
return THPUtils_packInt64(std::numeric_limits<scalar_t>::lowest());
|
|
});
|
|
}
|
|
// Quantized Type
|
|
return AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self->type, "min", [] {
|
|
return THPUtils_packInt64(std::numeric_limits<underlying_t>::lowest());
|
|
});
|
|
}
|
|
|
|
static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
|
|
auto primary_name = torch::utils::getDtypeNames(self->type).first;
|
|
return AT_DISPATCH_INTEGRAL_TYPES(self->type, "dtype", [&primary_name] {
|
|
return PyUnicode_FromString(primary_name.data());
|
|
});
|
|
}
|
|
|
|
static PyObject* THPFInfo_smallest_normal(THPFInfo* self, void*) {
|
|
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
|
|
at::kHalf,
|
|
at::ScalarType::BFloat16,
|
|
at::ScalarType::Float8_e4m3fn,
|
|
at::ScalarType::Float8_e5m2,
|
|
self->type,
|
|
"smallest",
|
|
[] {
|
|
return PyFloat_FromDouble(
|
|
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
|
|
});
|
|
}
|
|
|
|
static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
|
|
// see gh-70909, essentially the array_api prefers smallest_normal over tiny
|
|
return THPFInfo_smallest_normal(self, nullptr);
|
|
}
|
|
|
|
static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
|
|
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
|
|
at::kHalf,
|
|
at::ScalarType::BFloat16,
|
|
at::ScalarType::Float8_e4m3fn,
|
|
at::ScalarType::Float8_e5m2,
|
|
self->type,
|
|
"digits10",
|
|
[] {
|
|
return PyFloat_FromDouble(std::pow(
|
|
10,
|
|
-std::numeric_limits<
|
|
at::scalar_value_type<scalar_t>::type>::digits10));
|
|
});
|
|
}
|
|
|
|
static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
|
|
auto primary_name = torch::utils::getDtypeNames(self->type).first;
|
|
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
|
|
at::kHalf,
|
|
at::ScalarType::BFloat16,
|
|
at::ScalarType::Float8_e4m3fn,
|
|
at::ScalarType::Float8_e5m2,
|
|
self->type,
|
|
"dtype",
|
|
[&primary_name] { return PyUnicode_FromString(primary_name.data()); });
|
|
}
|
|
|
|
PyObject* THPFInfo_str(THPFInfo* self) {
|
|
std::ostringstream oss;
|
|
oss << "finfo(resolution="
|
|
<< PyFloat_AsDouble(THPFInfo_resolution(self, nullptr));
|
|
oss << ", min=" << PyFloat_AsDouble(THPFInfo_min(self, nullptr));
|
|
oss << ", max=" << PyFloat_AsDouble(THPFInfo_max(self, nullptr));
|
|
oss << ", eps=" << PyFloat_AsDouble(THPFInfo_eps(self, nullptr));
|
|
oss << ", smallest_normal="
|
|
<< PyFloat_AsDouble(THPFInfo_smallest_normal(self, nullptr));
|
|
oss << ", tiny=" << PyFloat_AsDouble(THPFInfo_tiny(self, nullptr));
|
|
oss << ", dtype=" << PyUnicode_AsUTF8(THPFInfo_dtype(self, nullptr)) << ")";
|
|
|
|
return THPUtils_packString(oss.str().c_str());
|
|
}
|
|
|
|
PyObject* THPIInfo_str(THPIInfo* self) {
|
|
std::ostringstream oss;
|
|
|
|
oss << "iinfo(min=" << PyLong_AsDouble(THPIInfo_min(self, nullptr));
|
|
oss << ", max=" << PyLong_AsDouble(THPIInfo_max(self, nullptr));
|
|
oss << ", dtype=" << PyUnicode_AsUTF8(THPIInfo_dtype(self, nullptr)) << ")";
|
|
|
|
return THPUtils_packString(oss.str().c_str());
|
|
}
|
|
|
|
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
|
|
static struct PyGetSetDef THPFInfo_properties[] = {
|
|
{"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
|
|
{"eps", (getter)THPFInfo_eps, nullptr, nullptr, nullptr},
|
|
{"max", (getter)THPFInfo_max, nullptr, nullptr, nullptr},
|
|
{"min", (getter)THPFInfo_min, nullptr, nullptr, nullptr},
|
|
{"smallest_normal",
|
|
(getter)THPFInfo_smallest_normal,
|
|
nullptr,
|
|
nullptr,
|
|
nullptr},
|
|
{"tiny", (getter)THPFInfo_tiny, nullptr, nullptr, nullptr},
|
|
{"resolution", (getter)THPFInfo_resolution, nullptr, nullptr, nullptr},
|
|
{"dtype", (getter)THPFInfo_dtype, nullptr, nullptr, nullptr},
|
|
{nullptr}};
|
|
|
|
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
|
|
static PyMethodDef THPFInfo_methods[] = {
|
|
{nullptr} /* Sentinel */
|
|
};
|
|
|
|
PyTypeObject THPFInfoType = {
|
|
PyVarObject_HEAD_INIT(nullptr, 0) "torch.finfo", /* tp_name */
|
|
sizeof(THPFInfo), /* tp_basicsize */
|
|
0, /* tp_itemsize */
|
|
nullptr, /* tp_dealloc */
|
|
0, /* tp_vectorcall_offset */
|
|
nullptr, /* tp_getattr */
|
|
nullptr, /* tp_setattr */
|
|
nullptr, /* tp_reserved */
|
|
(reprfunc)THPFInfo_str, /* tp_repr */
|
|
nullptr, /* tp_as_number */
|
|
nullptr, /* tp_as_sequence */
|
|
nullptr, /* tp_as_mapping */
|
|
nullptr, /* tp_hash */
|
|
nullptr, /* tp_call */
|
|
(reprfunc)THPFInfo_str, /* tp_str */
|
|
nullptr, /* tp_getattro */
|
|
nullptr, /* tp_setattro */
|
|
nullptr, /* tp_as_buffer */
|
|
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
|
nullptr, /* tp_doc */
|
|
nullptr, /* tp_traverse */
|
|
nullptr, /* tp_clear */
|
|
(richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */
|
|
0, /* tp_weaklistoffset */
|
|
nullptr, /* tp_iter */
|
|
nullptr, /* tp_iternext */
|
|
THPFInfo_methods, /* tp_methods */
|
|
nullptr, /* tp_members */
|
|
THPFInfo_properties, /* tp_getset */
|
|
nullptr, /* tp_base */
|
|
nullptr, /* tp_dict */
|
|
nullptr, /* tp_descr_get */
|
|
nullptr, /* tp_descr_set */
|
|
0, /* tp_dictoffset */
|
|
nullptr, /* tp_init */
|
|
nullptr, /* tp_alloc */
|
|
THPFInfo_pynew, /* tp_new */
|
|
};
|
|
|
|
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
|
|
static struct PyGetSetDef THPIInfo_properties[] = {
|
|
{"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
|
|
{"max", (getter)THPIInfo_max, nullptr, nullptr, nullptr},
|
|
{"min", (getter)THPIInfo_min, nullptr, nullptr, nullptr},
|
|
{"dtype", (getter)THPIInfo_dtype, nullptr, nullptr, nullptr},
|
|
{nullptr}};
|
|
|
|
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
|
|
static PyMethodDef THPIInfo_methods[] = {
|
|
{nullptr} /* Sentinel */
|
|
};
|
|
|
|
PyTypeObject THPIInfoType = {
|
|
PyVarObject_HEAD_INIT(nullptr, 0) "torch.iinfo", /* tp_name */
|
|
sizeof(THPIInfo), /* tp_basicsize */
|
|
0, /* tp_itemsize */
|
|
nullptr, /* tp_dealloc */
|
|
0, /* tp_vectorcall_offset */
|
|
nullptr, /* tp_getattr */
|
|
nullptr, /* tp_setattr */
|
|
nullptr, /* tp_reserved */
|
|
(reprfunc)THPIInfo_str, /* tp_repr */
|
|
nullptr, /* tp_as_number */
|
|
nullptr, /* tp_as_sequence */
|
|
nullptr, /* tp_as_mapping */
|
|
nullptr, /* tp_hash */
|
|
nullptr, /* tp_call */
|
|
(reprfunc)THPIInfo_str, /* tp_str */
|
|
nullptr, /* tp_getattro */
|
|
nullptr, /* tp_setattro */
|
|
nullptr, /* tp_as_buffer */
|
|
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
|
nullptr, /* tp_doc */
|
|
nullptr, /* tp_traverse */
|
|
nullptr, /* tp_clear */
|
|
(richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */
|
|
0, /* tp_weaklistoffset */
|
|
nullptr, /* tp_iter */
|
|
nullptr, /* tp_iternext */
|
|
THPIInfo_methods, /* tp_methods */
|
|
nullptr, /* tp_members */
|
|
THPIInfo_properties, /* tp_getset */
|
|
nullptr, /* tp_base */
|
|
nullptr, /* tp_dict */
|
|
nullptr, /* tp_descr_get */
|
|
nullptr, /* tp_descr_set */
|
|
0, /* tp_dictoffset */
|
|
nullptr, /* tp_init */
|
|
nullptr, /* tp_alloc */
|
|
THPIInfo_pynew, /* tp_new */
|
|
};
|
|
|
|
void THPDTypeInfo_init(PyObject* module) {
|
|
if (PyType_Ready(&THPFInfoType) < 0) {
|
|
throw python_error();
|
|
}
|
|
Py_INCREF(&THPFInfoType);
|
|
if (PyModule_AddObject(module, "finfo", (PyObject*)&THPFInfoType) != 0) {
|
|
throw python_error();
|
|
}
|
|
if (PyType_Ready(&THPIInfoType) < 0) {
|
|
throw python_error();
|
|
}
|
|
Py_INCREF(&THPIInfoType);
|
|
if (PyModule_AddObject(module, "iinfo", (PyObject*)&THPIInfoType) != 0) {
|
|
throw python_error();
|
|
}
|
|
}
|