558 lines
23 KiB
C++
558 lines
23 KiB
C++
#pragma once
|
|
|
|
#include <c10/util/BFloat16.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Half.h>
#include <c10/util/bits.h>
#include <c10/util/complex.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint2x4.h>
#include <c10/util/quint4x2.h>
#include <c10/util/quint8.h>

#include <complex>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <ostream>
#include <type_traits>
|
|
|
|
namespace c10 {
|
|
|
|
// For the macros below:
//
// NB: If you want to macro some code for all non-QInt scalar types (i.e.
// types with complete information), you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
// designed to behave similarly to the Dispatch macros with the same name.
|
|
|
|
// Master X-macro list: invokes `_(cpp_type, Name)` once per scalar type.
//
// NB: The order of the entries is load-bearing. The ScalarType enum is
// generated from this list, so an entry's position (shown in the trailing
// comment) is its enumerator value; _promoteTypesLookup and the serialization
// format both rely on it.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
  _(uint8_t, Byte) /* 0 */                               \
  _(int8_t, Char) /* 1 */                                \
  _(int16_t, Short) /* 2 */                              \
  _(int, Int) /* 3 */                                    \
  _(int64_t, Long) /* 4 */                               \
  _(at::Half, Half) /* 5 */                              \
  _(float, Float) /* 6 */                                \
  _(double, Double) /* 7 */                              \
  _(c10::complex<c10::Half>, ComplexHalf) /* 8 */        \
  _(c10::complex<float>, ComplexFloat) /* 9 */           \
  _(c10::complex<double>, ComplexDouble) /* 10 */        \
  _(bool, Bool) /* 11 */                                 \
  _(c10::qint8, QInt8) /* 12 */                          \
  _(c10::quint8, QUInt8) /* 13 */                        \
  _(c10::qint32, QInt32) /* 14 */                        \
  _(at::BFloat16, BFloat16) /* 15 */                     \
  _(c10::quint4x2, QUInt4x2) /* 16 */                    \
  _(c10::quint2x4, QUInt2x4) /* 17 */                    \
  _(c10::bits1x8, Bits1x8) /* 18 */                      \
  _(c10::bits2x4, Bits2x4) /* 19 */                      \
  _(c10::bits4x2, Bits4x2) /* 20 */                      \
  _(c10::bits8, Bits8) /* 21 */                          \
  _(c10::bits16, Bits16) /* 22 */                        \
  _(c10::Float8_e5m2, Float8_e5m2) /* 23 */              \
  _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */
|
|
|
|
// X-macro list of the non-quantized, non-bits scalar types, with ComplexHalf
// deliberately left out. To support ComplexHalf for real, add it to this list
// (and rename the macro) -- but beware that convert() does not work for all
// the conversions that would then be needed.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \
  _(uint8_t, Byte)                                                 \
  _(int8_t, Char)                                                  \
  _(int16_t, Short)                                                \
  _(int, Int)                                                      \
  _(int64_t, Long)                                                 \
  _(at::Half, Half)                                                \
  _(float, Float)                                                  \
  _(double, Double)                                                \
  _(c10::complex<float>, ComplexFloat)                             \
  _(c10::complex<double>, ComplexDouble)                           \
  _(bool, Bool)                                                    \
  _(at::BFloat16, BFloat16)                                        \
  _(at::Float8_e5m2, Float8_e5m2)                                  \
  _(at::Float8_e4m3fn, Float8_e4m3fn)
|
|
|
|
// X-macro list of the non-quantized, non-bits scalar types, including all
// three complex types. Invokes `_(cpp_type, Name)` once per entry.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
  _(uint8_t, Byte)                             \
  _(int8_t, Char)                              \
  _(int16_t, Short)                            \
  _(int, Int)                                  \
  _(int64_t, Long)                             \
  _(at::Half, Half)                            \
  _(float, Float)                              \
  _(double, Double)                            \
  _(c10::complex<c10::Half>, ComplexHalf)      \
  _(c10::complex<float>, ComplexFloat)         \
  _(c10::complex<double>, ComplexDouble)       \
  _(bool, Bool)                                \
  _(at::BFloat16, BFloat16)                    \
  _(at::Float8_e5m2, Float8_e5m2)              \
  _(at::Float8_e4m3fn, Float8_e4m3fn)
|
|
|
|
enum class ScalarType : int8_t {
|
|
#define DEFINE_ENUM(_1, n) n,
|
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ENUM)
|
|
#undef DEFINE_ENUM
|
|
Undefined,
|
|
NumOptions
|
|
};
|
|
|
|
// Numeric value of ScalarType::NumOptions, i.e. the number of enumerators
// declared before it (Undefined included).
constexpr uint16_t NumScalarTypes{
    static_cast<uint16_t>(ScalarType::NumOptions)};
|
|
|
|
namespace impl {
|
|
|
|
// These are used to map ScalarTypes to C++ types.
|
|
|
|
template <c10::ScalarType N>
|
|
struct ScalarTypeToCPPType;
|
|
|
|
#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
|
|
template <> \
|
|
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
|
|
using type = cpp_type; \
|
|
\
|
|
/* This is a workaround for the CUDA bug which prevents */ \
|
|
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
|
|
/* ambiguous reference which can't to be resolved. For some reason it */ \
|
|
/* can't pick between at::detail and at::cuda::detail. */ \
|
|
/* For repro example, please see: */ \
|
|
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
|
|
/* TODO: remove once the bug is fixed. */ \
|
|
static type t; \
|
|
};
|
|
|
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
|
|
|
|
#undef SPECIALIZE_ScalarTypeToCPPType
|
|
|
|
template <c10::ScalarType N>
|
|
using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
|
|
|
|
} // namespace impl
|
|
|
|
template <typename T>
|
|
struct CppTypeToScalarType;
|
|
|
|
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
|
|
template <> \
|
|
struct CppTypeToScalarType<cpp_type> \
|
|
: std:: \
|
|
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
|
|
};
|
|
|
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
|
|
|
#undef SPECIALIZE_CppTypeToScalarType
|
|
|
|
// X-macro list of the five integer scalar types (Bool is not included).
#define AT_FORALL_INT_TYPES(_) \
  _(uint8_t, Byte)             \
  _(int8_t, Char)              \
  _(int16_t, Short)            \
  _(int, Int)                  \
  _(int64_t, Long)
|
|
|
|
// X-macro list of the built-in integer types plus float and double.
#define AT_FORALL_SCALAR_TYPES(_) \
  _(uint8_t, Byte)                \
  _(int8_t, Char)                 \
  _(int16_t, Short)               \
  _(int, Int)                     \
  _(int64_t, Long)                \
  _(float, Float)                 \
  _(double, Double)
|
|
|
|
// The AT_FORALL_SCALAR_TYPES entries followed by one extra entry for
// SCALARTYPE (a bare ScalarType enumerator name). The extra entry's C++ type
// is recovered with decltype on ScalarTypeToCPPType<...>::t.
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
  _(uint8_t, Byte)                                \
  _(int8_t, Char)                                 \
  _(int16_t, Short)                               \
  _(int, Int)                                     \
  _(int64_t, Long)                                \
  _(float, Float)                                 \
  _(double, Double)                               \
  _(decltype(::c10::impl::ScalarTypeToCPPType<    \
             ::c10::ScalarType::SCALARTYPE>::t),  \
    SCALARTYPE)
|
|
|
|
// The AT_FORALL_SCALAR_TYPES entries followed by two extra entries, one per
// SCALARTYPEn argument (bare ScalarType enumerator names), in argument order.
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
  _(uint8_t, Byte)                                               \
  _(int8_t, Char)                                                \
  _(int16_t, Short)                                              \
  _(int, Int)                                                    \
  _(int64_t, Long)                                               \
  _(float, Float)                                                \
  _(double, Double)                                              \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                   \
             ::c10::ScalarType::SCALARTYPE1>::t),                \
    SCALARTYPE1)                                                 \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                   \
             ::c10::ScalarType::SCALARTYPE2>::t),                \
    SCALARTYPE2)
|
|
|
|
// The AT_FORALL_SCALAR_TYPES entries followed by three extra entries, one per
// SCALARTYPEn argument (bare ScalarType enumerator names), in argument order.
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
  _(uint8_t, Byte)                                                            \
  _(int8_t, Char)                                                             \
  _(int16_t, Short)                                                           \
  _(int, Int)                                                                 \
  _(int64_t, Long)                                                            \
  _(float, Float)                                                             \
  _(double, Double)                                                           \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                                \
             ::c10::ScalarType::SCALARTYPE1>::t),                             \
    SCALARTYPE1)                                                              \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                                \
             ::c10::ScalarType::SCALARTYPE2>::t),                             \
    SCALARTYPE2)                                                              \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                                \
             ::c10::ScalarType::SCALARTYPE3>::t),                             \
    SCALARTYPE3)
|
|
|
|
// The AT_FORALL_SCALAR_TYPES entries followed by four extra entries, one per
// SCALARTYPEn argument (bare ScalarType enumerator names), in argument order.
#define AT_FORALL_SCALAR_TYPES_AND4(                       \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, _) \
  _(uint8_t, Byte)                                         \
  _(int8_t, Char)                                          \
  _(int16_t, Short)                                        \
  _(int, Int)                                              \
  _(int64_t, Long)                                         \
  _(float, Float)                                          \
  _(double, Double)                                        \
  _(decltype(::c10::impl::ScalarTypeToCPPType<             \
             ::c10::ScalarType::SCALARTYPE1>::t),          \
    SCALARTYPE1)                                           \
  _(decltype(::c10::impl::ScalarTypeToCPPType<             \
             ::c10::ScalarType::SCALARTYPE2>::t),          \
    SCALARTYPE2)                                           \
  _(decltype(::c10::impl::ScalarTypeToCPPType<             \
             ::c10::ScalarType::SCALARTYPE3>::t),          \
    SCALARTYPE3)                                           \
  _(decltype(::c10::impl::ScalarTypeToCPPType<             \
             ::c10::ScalarType::SCALARTYPE4>::t),          \
    SCALARTYPE4)
|
|
|
|
// The AT_FORALL_SCALAR_TYPES entries followed by five extra entries, one per
// SCALARTYPEn argument (bare ScalarType enumerator names), in argument order.
#define AT_FORALL_SCALAR_TYPES_AND5(                                    \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, _) \
  _(uint8_t, Byte)                                                      \
  _(int8_t, Char)                                                       \
  _(int16_t, Short)                                                     \
  _(int, Int)                                                           \
  _(int64_t, Long)                                                      \
  _(float, Float)                                                       \
  _(double, Double)                                                     \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                          \
             ::c10::ScalarType::SCALARTYPE1>::t),                       \
    SCALARTYPE1)                                                        \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                          \
             ::c10::ScalarType::SCALARTYPE2>::t),                       \
    SCALARTYPE2)                                                        \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                          \
             ::c10::ScalarType::SCALARTYPE3>::t),                       \
    SCALARTYPE3)                                                        \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                          \
             ::c10::ScalarType::SCALARTYPE4>::t),                       \
    SCALARTYPE4)                                                        \
  _(decltype(::c10::impl::ScalarTypeToCPPType<                          \
             ::c10::ScalarType::SCALARTYPE5>::t),                       \
    SCALARTYPE5)
|
|
|
|
// X-macro list of the quantized integer scalar types.
#define AT_FORALL_QINT_TYPES(_) \
  _(c10::qint8, QInt8)          \
  _(c10::quint8, QUInt8)        \
  _(c10::qint32, QInt32)        \
  _(c10::quint4x2, QUInt4x2)    \
  _(c10::quint2x4, QUInt2x4)
|
|
|
|
// X-macro list of the complex scalar types; ComplexHalf is not part of it.
#define AT_FORALL_COMPLEX_TYPES(_)     \
  _(c10::complex<float>, ComplexFloat) \
  _(c10::complex<double>, ComplexDouble)
|
|
|
|
#define DEFINE_CONSTANT(_, name) \
|
|
constexpr ScalarType k##name = ScalarType::name;
|
|
|
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
|
|
#undef DEFINE_CONSTANT
|
|
|
|
// Returns the enumerator name of `t` as a C string (e.g. "Float"). Any value
// without an entry in the master list -- Undefined and NumOptions included --
// yields "UNKNOWN_SCALAR".
static inline const char* toString(ScalarType t) {
  switch (t) {
#define C10_SCALAR_TYPE_NAME_CASE(_cpp_type, name) \
  case ScalarType::name:                           \
    return #name;
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(C10_SCALAR_TYPE_NAME_CASE)
#undef C10_SCALAR_TYPE_NAME_CASE
    default:
      break;
  }
  return "UNKNOWN_SCALAR";
}
|
|
|
|
// Returns sizeof() of the C++ type that backs `t`. Values without an entry in
// the master list (e.g. Undefined) fail a TORCH_CHECK.
static inline size_t elementSize(ScalarType t) {
  switch (t) {
#define C10_SCALAR_TYPE_SIZE_CASE(cpp_type, name) \
  case ScalarType::name:                          \
    return sizeof(cpp_type);
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(C10_SCALAR_TYPE_SIZE_CASE)
#undef C10_SCALAR_TYPE_SIZE_CASE
    default:
      TORCH_CHECK(false, "Unknown ScalarType");
  }
}
|
|
|
|
// True for Byte, Char, Short, Int and Long. Bool counts as integral only when
// `includeBool` is set.
static inline bool isIntegralType(ScalarType t, bool includeBool) {
  switch (t) {
    case ScalarType::Byte:
    case ScalarType::Char:
    case ScalarType::Short:
    case ScalarType::Int:
    case ScalarType::Long:
      return true;
    case ScalarType::Bool:
      return includeBool;
    default:
      return false;
  }
}
|
|
|
|
// Deprecated single-argument form: behaves like the two-argument overload
// with Bool excluded.
C10_DEPRECATED_MESSAGE(
    "isIntegralType is deprecated. Please use the overload with 'includeBool' parameter instead.")
static inline bool isIntegralType(ScalarType t) {
  constexpr bool includeBool = false;
  return isIntegralType(t, includeBool);
}
|
|
|
|
// True for every real floating-point type: Double, Float, Half, BFloat16 and
// the two Float8 variants. Complex types are not included.
static inline bool isFloatingType(ScalarType t) {
  switch (t) {
    case ScalarType::Double:
    case ScalarType::Float:
    case ScalarType::Half:
    case ScalarType::BFloat16:
    case ScalarType::Float8_e5m2:
    case ScalarType::Float8_e4m3fn:
      return true;
    default:
      return false;
  }
}
|
|
|
|
// True for the two 8-bit floating-point types, Float8_e5m2 and Float8_e4m3fn.
static inline bool isFloat8Type(ScalarType t) {
  switch (t) {
    case ScalarType::Float8_e5m2:
    case ScalarType::Float8_e4m3fn:
      return true;
    default:
      return false;
  }
}
|
|
// True for Half, BFloat16 and the Float8 types.
static inline bool isReducedFloatingType(ScalarType t) {
  if (isFloat8Type(t)) {
    return true;
  }
  return t == ScalarType::BFloat16 || t == ScalarType::Half;
}
|
|
|
|
// True for ComplexHalf, ComplexFloat and ComplexDouble.
static inline bool isComplexType(ScalarType t) {
  switch (t) {
    case ScalarType::ComplexHalf:
    case ScalarType::ComplexFloat:
    case ScalarType::ComplexDouble:
      return true;
    default:
      return false;
  }
}
|
|
|
|
// True for the quantized integer types.
// Don't forget to extend this when adding new QInt types.
static inline bool isQIntType(ScalarType t) {
  switch (t) {
    case ScalarType::QInt8:
    case ScalarType::QUInt8:
    case ScalarType::QInt32:
    case ScalarType::QUInt4x2:
    case ScalarType::QUInt2x4:
      return true;
    default:
      return false;
  }
}
|
|
|
|
// True for the five bits types: Bits1x8, Bits2x4, Bits4x2, Bits8 and Bits16.
static inline bool isBitsType(ScalarType t) {
  switch (t) {
    case ScalarType::Bits1x8:
    case ScalarType::Bits2x4:
    case ScalarType::Bits4x2:
    case ScalarType::Bits8:
    case ScalarType::Bits16:
      return true;
    default:
      return false;
  }
}
|
|
|
|
// Maps Byte -> QUInt8, Char -> QInt8 and Int -> QInt32. Every other type is
// returned unchanged.
static inline ScalarType toQIntType(ScalarType t) {
  if (t == ScalarType::Byte) {
    return ScalarType::QUInt8;
  }
  if (t == ScalarType::Char) {
    return ScalarType::QInt8;
  }
  if (t == ScalarType::Int) {
    return ScalarType::QInt32;
  }
  return t;
}
|
|
|
|
// Maps a quantized type to the plain integer type it is paired with here:
// QUInt8, QUInt4x2 and QUInt2x4 -> Byte, QInt8 -> Char, QInt32 -> Int.
// Every other type is returned unchanged.
static inline ScalarType toUnderlying(ScalarType t) {
  switch (t) {
    case ScalarType::QUInt8:
    case ScalarType::QUInt4x2:
    case ScalarType::QUInt2x4:
      return ScalarType::Byte;
    case ScalarType::QInt8:
      return ScalarType::Char;
    case ScalarType::QInt32:
      return ScalarType::Int;
    default:
      return t;
  }
}
|
|
|
|
// Reports whether `t` is a signed type. Complex types count as signed; the
// remaining supported types defer to std::numeric_limits<T>::is_signed of
// their C++ type. Quantized types, bits types and values with no case below
// fail a TORCH_CHECK.
static inline bool isSignedType(ScalarType t) {
  TORCH_CHECK(!isQIntType(t), "isSignedType not supported for quantized types");

  switch (t) {
    case ScalarType::Bits1x8:
    case ScalarType::Bits2x4:
    case ScalarType::Bits4x2:
    case ScalarType::Bits8:
    case ScalarType::Bits16:
      TORCH_CHECK(false, "Bits types are undefined");
    case ScalarType::ComplexHalf:
    case ScalarType::ComplexFloat:
    case ScalarType::ComplexDouble:
      return true;
#define C10_SCALAR_TYPE_SIGNED_CASE(cpp_type, name) \
  case ScalarType::name:                            \
    return std::numeric_limits<cpp_type>::is_signed;
      AT_FORALL_SCALAR_TYPES_AND5(
          Half,
          Bool,
          BFloat16,
          Float8_e5m2,
          Float8_e4m3fn,
          C10_SCALAR_TYPE_SIGNED_CASE)
#undef C10_SCALAR_TYPE_SIGNED_CASE
    default:
      TORCH_CHECK(false, "Unknown ScalarType");
  }
}
|
|
|
|
// True when `type` is what toUnderlying() returns for `qtype`.
static inline bool isUnderlying(ScalarType type, ScalarType qtype) {
  const ScalarType underlying = toUnderlying(qtype);
  return underlying == type;
}
|
|
|
|
// Maps a complex type to its component type (ComplexHalf -> Half,
// ComplexFloat -> Float, ComplexDouble -> Double). Every other type is
// returned unchanged.
static inline ScalarType toRealValueType(ScalarType t) {
  if (t == ScalarType::ComplexHalf) {
    return ScalarType::Half;
  }
  if (t == ScalarType::ComplexFloat) {
    return ScalarType::Float;
  }
  if (t == ScalarType::ComplexDouble) {
    return ScalarType::Double;
  }
  return t;
}
|
|
|
|
// Maps a floating-point type to the matching complex type; complex types map
// to themselves. Any other input fails a TORCH_CHECK.
static inline ScalarType toComplexType(ScalarType t) {
  switch (t) {
    case ScalarType::Half:
    case ScalarType::ComplexHalf:
      return ScalarType::ComplexHalf;
    case ScalarType::BFloat16:
      // BFloat16 has range equivalent to Float, so it shares Float's mapping.
    case ScalarType::Float:
    case ScalarType::ComplexFloat:
      return ScalarType::ComplexFloat;
    case ScalarType::Double:
    case ScalarType::ComplexDouble:
      return ScalarType::ComplexDouble;
    default:
      TORCH_CHECK(false, "Unknown Complex ScalarType for ", t);
  }
}
|
|
|
|
// see tensor_attributes.rst for detailed explanation and examples
|
|
// of casting rules.
|
|
static inline bool canCast(const ScalarType from, const ScalarType to) {
|
|
// We disallow complex -> non complex, e.g., float_tensor *= complex is
|
|
// disallowed.
|
|
if (isComplexType(from) && !isComplexType(to)) {
|
|
return false;
|
|
}
|
|
// We disallow float -> integral, e.g., int_tensor *= float is disallowed.
|
|
if (isFloatingType(from) && isIntegralType(to, false)) {
|
|
return false;
|
|
}
|
|
|
|
// Treat bool as a distinct "category," to be consistent with type promotion
|
|
// rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same
|
|
// category as `bool_tensor`, we would not promote. Differing categories
|
|
// implies `bool_tensor += 5` is disallowed.
|
|
//
|
|
// NB: numpy distinguishes "unsigned" as a category to get the desired
|
|
// `bool_tensor + 5 -> int64_tensor` behavior. We don't, because:
|
|
// * We don't want the performance hit of checking the runtime sign of
|
|
// Scalars.
|
|
// * `uint8_tensor + 5 -> int64_tensor` would be undesirable.
|
|
if (from != ScalarType::Bool && to == ScalarType::Bool) {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Returns the type that results from combining operands of types `a` and `b`.
//
// * Undefined combined with anything is Undefined.
// * A quantized type promotes only with itself; any other combination that
//   involves a quantized type fails a TORCH_CHECK.
// * A bits type promotes only with itself; any other combination that
//   involves a bits type is Undefined.
// * All remaining types are resolved through the lookup table below, in which
//   Undefined marks pairs whose correct promotion is not settled.
// * Values that have no slot in the table (e.g. NumOptions) yield Undefined.
static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
  // This is generated according to NumPy's promote_types
  constexpr auto u1 = ScalarType::Byte;
  constexpr auto i1 = ScalarType::Char;
  constexpr auto i2 = ScalarType::Short;
  constexpr auto i4 = ScalarType::Int;
  constexpr auto i8 = ScalarType::Long;
  constexpr auto f2 = ScalarType::Half;
  constexpr auto f4 = ScalarType::Float;
  constexpr auto f8 = ScalarType::Double;
  constexpr auto c2 = ScalarType::ComplexHalf;
  constexpr auto c4 = ScalarType::ComplexFloat;
  constexpr auto c8 = ScalarType::ComplexDouble;
  constexpr auto b1 = ScalarType::Bool;
  constexpr auto bf = ScalarType::BFloat16;
  constexpr auto b8 = ScalarType::Float8_e5m2;
  constexpr auto h8 = ScalarType::Float8_e4m3fn;
  constexpr auto ud = ScalarType::Undefined;
  if (a == ud || b == ud) {
    return ScalarType::Undefined;
  }

  // For QInt types, we only allow exact match
  if (isQIntType(a) && a == b) {
    return a;
  }

  if (isQIntType(a) || isQIntType(b)) {
    TORCH_CHECK(
        false,
        "promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: ",
        toString(a),
        " ",
        toString(b));
  }

  if (isBitsType(a) && a == b) {
    return a;
  } else if (isBitsType(a) || isBitsType(b)) {
    return ScalarType::Undefined;
  }

  // Layout of the lookup table. Slots 0..15 are the enum values of Byte
  // through BFloat16 (the q1/q2/q3 slots are placeholders that keep those
  // indices aligned; quantized inputs never get this far). The Float8 types
  // sit after the QUInt4x2/QUInt2x4 and bits types in the enum, so their enum
  // values cannot be used as indices directly; they are remapped onto the two
  // slots that follow BFloat16.
  constexpr int kNumPromoteTypes = 18;
  constexpr int kFloat8e5m2Slot = 16;
  constexpr int kFloat8e4m3fnSlot = 17;
  static_assert(
      static_cast<int>(ScalarType::Byte) == 0 &&
          static_cast<int>(ScalarType::BFloat16) == kFloat8e5m2Slot - 1,
      "_promoteTypesLookup assumes Byte..BFloat16 have enum values 0..15");

  // this matrix has to be consistent with
  // AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS (modulo the Float8 remapping
  // described above); undefined is used where we are not sure about the
  // correct value for type promotion.
  // clang-format off
  static constexpr ScalarType _promoteTypesLookup[
      kNumPromoteTypes][kNumPromoteTypes] = {
      /*        u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf  b8  h8*/
      /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf, b8, h8},
      /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf, b8, h8},
      /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf, b8, h8},
      /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf, b8, h8},
      /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf, b8, h8},
      /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4, f4, f4},
      /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4, f4, f4},
      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8, f8, f8},
      /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4, c4, c4},
      /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4, c4, c4},
      /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8, c8, c8},
      /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf, b8, h8},
      /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
      /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
      /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
      /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf, bf, bf},
      /* b8 */ {b8, b8, b8, b8, b8, f4, f4, f8, c4, c4, c8, b8, ud, ud, ud, bf, b8, ud},
      /* h8 */ {h8, h8, h8, h8, h8, f4, f4, f8, c4, c4, c8, h8, ud, ud, ud, bf, ud, h8},
  };
  // clang-format on

  // Translates a ScalarType into its table slot, or -1 when it has none.
  const auto tableSlot = [](ScalarType t) -> int {
    switch (t) {
      case ScalarType::Float8_e5m2:
        return kFloat8e5m2Slot;
      case ScalarType::Float8_e4m3fn:
        return kFloat8e4m3fnSlot;
      default: {
        const int raw = static_cast<int>(t);
        return (raw >= 0 && raw < kFloat8e5m2Slot) ? raw : -1;
      }
    }
  };

  const int slot_a = tableSlot(a);
  const int slot_b = tableSlot(b);
  if (slot_a < 0 || slot_b < 0) {
    // Not a promotable type (e.g. NumOptions or a corrupted value): report
    // Undefined rather than reading outside the table.
    return ScalarType::Undefined;
  }
  return _promoteTypesLookup[slot_a][slot_b];
}
|
|
|
|
inline std::ostream& operator<<(
|
|
std::ostream& stream,
|
|
at::ScalarType scalar_type) {
|
|
return stream << toString(scalar_type);
|
|
}
|
|
|
|
} // namespace c10
|