86 lines
2.7 KiB
C++
86 lines
2.7 KiB
C++
#include <c10/core/ConstantSymNodeImpl.h>
|
|
#include <c10/core/SymBool.h>
|
|
#include <c10/core/SymNodeImpl.h>
|
|
#include <array>
|
|
#include <utility>
|
|
|
|
namespace c10 {
|
|
|
|
// Returns this SymBool's node as a SymNode.
//
// Precondition (enforced by TORCH_CHECK): is_heap_allocated() must hold.
// The result is built by SymNode::reclaim_copy() from the pointer handed out
// by toSymNodeImplUnowned().
// NOTE(review): reclaim_copy presumably takes an additional reference rather
// than stealing ownership -- confirm against the intrusive_ptr documentation.
SymNode SymBool::toSymNodeImpl() const {
  TORCH_CHECK(is_heap_allocated());
  auto unowned_node = toSymNodeImplUnowned();
  return SymNode::reclaim_copy(unowned_node);
}
|
|
|
|
// Produces a SymNode for this SymBool.
//
// If maybe_as_bool() yields a value, that value is wrapped via
// base->wrap_bool(); otherwise this SymBool's own node is returned through
// toSymNodeImpl().
//
// @param base  node whose wrap_bool() is used to wrap a plain bool value.
SymNode SymBool::wrap_node(const SymNode& base) const {
  const auto constant = maybe_as_bool();
  if (!constant) {
    return toSymNodeImpl();
  }
  return base->wrap_bool(*constant);
}
|
|
|
|
// Stamps out a binary member `RET SymBool::API(const SymBool& sci) const`.
//
//   API    - name of the generated member function
//   OP     - callable applied to the two plain bools when both operands
//            yield a value from maybe_as_bool()
//   METHOD - node method invoked when at least one operand does not
//   RET    - return type; constructed from whichever result is produced
//
// When exactly one operand yields a plain bool, that bool is lifted with
// wrap_bool() on the other operand's node, so METHOD is always called
// node-on-node with the operands in their original (this, sci) order.
#define DEFINE_BINARY(API, OP, METHOD, RET)                            \
  RET SymBool::API(const SymBool& sci) const {                         \
    auto lhs_const = maybe_as_bool();                                  \
    auto rhs_const = sci.maybe_as_bool();                              \
    if (lhs_const && rhs_const) {                                      \
      return RET(OP(*lhs_const, *rhs_const));                          \
    }                                                                  \
    if (lhs_const) {                                                   \
      auto rhs_node = sci.toSymNodeImpl();                             \
      return RET(rhs_node->wrap_bool(*lhs_const)->METHOD(rhs_node));   \
    }                                                                  \
    if (rhs_const) {                                                   \
      auto lhs_node = toSymNodeImplUnowned();                          \
      return RET(lhs_node->METHOD(lhs_node->wrap_bool(*rhs_const)));   \
    }                                                                  \
    return RET(toSymNodeImplUnowned()->METHOD(sci.toSymNodeImpl()));   \
  }
|
|
|
|
// clang-format off
|
|
DEFINE_BINARY(sym_and, std::logical_and<>(), sym_and, SymBool)
|
|
DEFINE_BINARY(sym_or, std::logical_or<>(), sym_or, SymBool)
|
|
// clang-format on
|
|
|
|
// Logical negation.
//
// If maybe_as_bool() yields a value, the negated bool is wrapped in a new
// SymBool directly; otherwise the result is built from the node's sym_not().
SymBool SymBool::sym_not() const {
  const auto constant = maybe_as_bool();
  return constant ? SymBool(!*constant)
                  : SymBool(toSymNodeImpl()->sym_not());
}
|
|
|
|
std::ostream& operator<<(std::ostream& os, const SymBool& s) {
|
|
if (auto ma = s.maybe_as_bool()) {
|
|
os << *ma;
|
|
} else {
|
|
os << s.toSymNodeImpl()->str();
|
|
}
|
|
return os;
|
|
}
|
|
|
|
bool SymBool::guard_bool(const char* file, int64_t line) const {
|
|
if (auto ma = maybe_as_bool()) {
|
|
return *ma;
|
|
}
|
|
SymNode a = toSymNodeImpl();
|
|
return a->guard_bool(file, line);
|
|
}
|
|
|
|
bool SymBool::expect_true(const char* file, int64_t line) const {
|
|
if (auto ma = maybe_as_bool()) {
|
|
return *ma;
|
|
}
|
|
SymNode a = toSymNodeImpl();
|
|
return a->expect_true(file, line);
|
|
}
|
|
|
|
bool SymBool::has_hint() const {
|
|
if (auto ma = maybe_as_bool()) {
|
|
return true;
|
|
}
|
|
return toSymNodeImpl()->has_hint();
|
|
}
|
|
|
|
} // namespace c10
|