66 lines
1.5 KiB
C++
66 lines
1.5 KiB
C++
#include "caffe2/operators/assert_op.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
REGISTER_CPU_OPERATOR(Assert, AssertOp<CPUContext>);
|
|
|
|
OPERATOR_SCHEMA(Assert)
|
|
.NumInputs(1)
|
|
.NumOutputs(0)
|
|
.SetDoc(R"DOC(
|
|
Takes in a tensor of type *bool*, *int*, *long*, or *long long* and checks if all values are True when coerced into a boolean. In other words, for non-bool types this asserts that all values in the tensor are non-zero. If a value is False after coerced into a boolean, the operator throws an error. Else, if all values are True, nothing is returned. For tracability, a custom error message can be set using the `error_msg` argument.
|
|
|
|
Github Links:
|
|
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/assert_op.cc
|
|
|
|
<details>
|
|
|
|
<summary> <b>Example</b> </summary>
|
|
|
|
**Code**
|
|
|
|
```
|
|
|
|
workspace.ResetWorkspace()
|
|
|
|
op = core.CreateOperator(
|
|
"Assert",
|
|
["A"],
|
|
[],
|
|
error_msg="Failed assertion from Assert operator"
|
|
)
|
|
|
|
workspace.FeedBlob("A", np.random.randint(10, size=(3,3)).astype(np.int32))
|
|
print("A:", workspace.FetchBlob("A"))
|
|
try:
|
|
workspace.RunOperatorOnce(op)
|
|
except RuntimeError:
|
|
print("Assertion Failed!")
|
|
else:
|
|
print("Assertion Passed!")
|
|
|
|
```
|
|
|
|
**Result**
|
|
|
|
```
|
|
|
|
A:
|
|
[[7 5 6]
|
|
[1 2 4]
|
|
[5 3 7]]
|
|
Assertion Passed!
|
|
|
|
```
|
|
|
|
</details>
|
|
|
|
)DOC")
|
|
.Arg(
|
|
"error_msg",
|
|
"(*string*): custom error message to be thrown when the input does not pass assertion",
|
|
false)
|
|
.Input(0,"X","(*Tensor*): input tensor");
|
|
|
|
} // namespace caffe2
|