241 lines
6.4 KiB
C++
241 lines
6.4 KiB
C++
#include "caffe2/operators/cast_op.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
// Default per-element conversion used by the CPU Cast kernel: the source
// value is converted with a plain static_cast to the destination type.
template <typename DstType, typename SrcType>
struct CastHelper {
  static DstType call(SrcType value) {
    DstType converted = static_cast<DstType>(value);
    return converted;
  }
};
|
|
|
|
template <typename SrcType>
|
|
struct CastHelper<std::string, SrcType> {
|
|
static std::string call(SrcType data) {
|
|
return caffe2::to_string(data);
|
|
}
|
|
};
|
|
|
|
template <>
|
|
template <typename DstType, typename SrcType>
|
|
bool CastOp<CPUContext>::DoRunWithType() {
|
|
auto& input = Input(0);
|
|
|
|
auto* output = Output(0, input.sizes(), at::dtype<DstType>());
|
|
const auto* data = input.template data<SrcType>();
|
|
auto* out = output->template mutable_data<DstType>();
|
|
auto N = input.numel();
|
|
for (int64_t i = 0; i < N; ++i) {
|
|
out[i] = CastHelper<DstType, SrcType>::call(data[i]);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
template <>
|
|
void CastOp<CPUContext>::SetBody(TensorProto_DataType to) {
|
|
switch (to) {
|
|
case TensorProto_DataType_FLOAT:
|
|
// body_ = &CastOp::DoRunIncFp16WithDstType<float>;
|
|
body_ = &CastOp<CPUContext>::DoRunWithDstType<float>;
|
|
break;
|
|
case TensorProto_DataType_INT32:
|
|
body_ = &CastOp<CPUContext>::DoRunWithDstType<int>;
|
|
break;
|
|
case TensorProto_DataType_BYTE:
|
|
LOG(FATAL) << "BYTE is deprecated";
|
|
break;
|
|
case TensorProto_DataType_STRING:
|
|
body_ = &CastOp<CPUContext>::DoRunWithDstType<std::string>;
|
|
break;
|
|
case TensorProto_DataType_BOOL:
|
|
body_ = &CastOp<CPUContext>::DoRunWithDstType<bool>;
|
|
break;
|
|
case TensorProto_DataType_UINT8:
|
|
body_ = &CastOp<CPUContext>::DoRunWithDstType<uint8_t>;
|
|
break;
|
|
case TensorProto_DataType_INT8:
|
|
body_ = &CastOp<CPUContext>::DoRunWithDstType<int8_t>;
|
|
break;
|
|
case TensorProto_DataType_UINT16:
|
|
body_ = &CastOp<CPUContext>::DoRunWithDstType<uint16_t>;
|
|
break;
|
|
case TensorProto_DataType_INT16:
|
|
body_ = &CastOp<CPUContext>::DoRunWithDstType<int16_t>;
|
|
break;
|
|
case TensorProto_DataType_INT64:
|
|
body_ = &CastOp<CPUContext>::DoRunWithDstType<int64_t>;
|
|
break;
|
|
case TensorProto_DataType_FLOAT16:
|
|
CAFFE_THROW("Casting to and from at::Half on CPU is not supported yet");
|
|
// break;
|
|
case TensorProto_DataType_DOUBLE:
|
|
// body_ = &CastOp::DoRunIncFp16WithDstType<double>;
|
|
body_ = &CastOp<CPUContext>::DoRunWithDstType<double>;
|
|
break;
|
|
case TensorProto_DataType_UNDEFINED:
|
|
CAFFE_THROW("Cast op must have 'to' argument of type DataType");
|
|
// break;
|
|
default:
|
|
CAFFE_THROW("Unexpected 'to' argument value: ", to);
|
|
}
|
|
}
|
|
|
|
template <>
|
|
template <typename DstType>
|
|
bool CastOp<CPUContext>::DoRunWithDstType() {
|
|
return DispatchHelper<
|
|
TensorTypes<
|
|
float,
|
|
int32_t,
|
|
bool,
|
|
uint8_t,
|
|
int8_t,
|
|
uint16_t,
|
|
int16_t,
|
|
int64_t,
|
|
double>,
|
|
DstType>::call(this, Input(0));
|
|
}
|
|
|
|
// Register CastOp<CPUContext> as the CPU implementation of the "Cast" operator.
REGISTER_CPU_OPERATOR(Cast, CastOp<CPUContext>);
|
|
|
|
// Schema for Cast: exactly one input and one output. The output has the
// input's shape, and its data type comes from the 'to' argument (see the
// tensor inference function below).
OPERATOR_SCHEMA(Cast)
|
|
.NumInputs(1)
|
|
.NumOutputs(1)
|
|
// Shape inference: copy the input's TensorShape and override only its data
// type with the value resolved from the 'to' argument.
.TensorInferenceFunction([](const OperatorDef& def,
|
|
const vector<TensorShape>& in) {
|
|
ArgumentHelper helper(def);
|
|
vector<TensorShape> out;
|
|
out.push_back(in[0]);
|
|
// Same dims as the input; only the element type changes.
out[0].set_data_type(cast::GetCastDataType(helper, "to"));
|
|
return out;
|
|
})
|
|
// NOTE(review): the enum listing in this doc includes BYTE and FLOAT16, but
// the CPU SetBody rejects both as destination types; consider saying so here.
.SetDoc(R"DOC(
|
|
Casts the elements of a given input tensor to a data type specified by the `to`
|
|
argument and returns an output tensor of the same size in the converted type.
|
|
The `to` argument must be one of the data types specified in the *DataType*
|
|
enum field in the TensorProto message (see below). If the `to` argument is not
|
|
provided or is not one of the enumerated types in *DataType*, Caffe2 throws an
|
|
Enforce error.
|
|
|
|
NOTE: Casting from strings is not supported, and casting to strings is only
|
|
supported on CPU.
|
|
|
|
TensorProto *DataType* field:
|
|
```
|
|
message TensorProto {
|
|
...
|
|
enum DataType {
|
|
UNDEFINED = 0;
|
|
FLOAT = 1; // float
|
|
INT32 = 2; // int
|
|
BYTE = 3; // BYTE, when deserialized, is going to be restored as uint8.
|
|
STRING = 4; // string
|
|
BOOL = 5; // bool
|
|
UINT8 = 6; // uint8_t
|
|
INT8 = 7; // int8_t
|
|
UINT16 = 8; // uint16_t
|
|
INT16 = 9; // int16_t
|
|
INT64 = 10; // int64_t
|
|
FLOAT16 = 12; // at::Half
|
|
DOUBLE = 13; // double
|
|
}
|
|
```
|
|
|
|
Github Links:
|
|
|
|
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/cast_op.cc
|
|
|
|
<details>
|
|
|
|
<summary> <b>Example</b> </summary>
|
|
|
|
**Code**
|
|
|
|
```
|
|
workspace.ResetWorkspace()
|
|
|
|
op = core.CreateOperator(
|
|
"Cast",
|
|
["X"],
|
|
["Y"],
|
|
to=2
|
|
)
|
|
|
|
workspace.FeedBlob("X", (np.random.rand(3,3)).astype(np.float32)*10)
|
|
print("X:", workspace.FetchBlob("X"))
|
|
workspace.RunOperatorOnce(op)
|
|
print("Y:", workspace.FetchBlob("Y"))
|
|
```
|
|
|
|
**Result**
|
|
|
|
```
|
|
X: [[9.436466 5.8529844 0.54932857]
|
|
[1.1583444 2.9936118 0.22950427]
|
|
[3.9143739 3.4040766 8.905341 ]]
|
|
Y: [[9 5 0]
|
|
[1 2 0]
|
|
[3 3 8]]
|
|
```
|
|
|
|
</details>
|
|
|
|
)DOC")
|
|
.Arg(
|
|
"to",
|
|
"*(type: int)* Data type to which the elements of the input tensor are "
|
|
"cast. Strictly must be one of the types from *DataType* enum in "
|
|
"TensorProto.")
|
|
.Input(0, "X", "*(type: Tensor)* Input tensor to be cast.")
|
|
.Output(
|
|
0,
|
|
"Y",
|
|
"*(type: Tensor`<'to' type>`)* Output tensor with the same shape as "
|
|
"input with type specified by the `to` argument.")
|
|
.InheritOnnxSchema();
|
|
|
|
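// Cast is differentiated as the reverse cast: GetCastGradient below casts the
// output gradient back to the forward input's data type. This requires the
// forward Cast op to carry a 'from_type' argument (int or string).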
|
|
|
|
class GetCastGradient : public GradientMakerBase {
|
|
using GradientMakerBase::GradientMakerBase;
|
|
vector<OperatorDef> GetGradientDefs() override {
|
|
|
|
vector<OperatorDef> defs = SingleGradientDef("Cast", "", vector<string>{GO(0)}, vector<string>{GI(0)});
|
|
|
|
// now modify the arguments in defs[0]
|
|
ArgumentHelper argsHelper(def_);
|
|
|
|
auto to_name = cast::GetCastDataType(argsHelper, "to");
|
|
|
|
CAFFE_ENFORCE(
|
|
argsHelper.HasSingleArgumentOfType<string>("from_type") ||
|
|
argsHelper.HasSingleArgumentOfType<int>("from_type"),
|
|
"Argument 'from_type' of type int or string"
|
|
" is required to get the gradient of CastOp");
|
|
|
|
auto from_name = cast::GetCastDataType(argsHelper, "from_type");
|
|
Argument *to = defs[0].add_arg();
|
|
to->set_name("to");
|
|
to->set_i(from_name);
|
|
|
|
Argument *from = defs[0].add_arg();
|
|
from->set_name("from_type");
|
|
from->set_i(to_name);
|
|
|
|
return defs;
|
|
}
|
|
|
|
bool CopyArguments() const override {
|
|
return false;
|
|
}
|
|
};
|
|
|
|
// Register GetCastGradient as the gradient maker for the "Cast" operator.
REGISTER_GRADIENT(Cast, GetCastGradient);
|
|
|
|
|
|
|
|
|
|
} // namespace caffe2
|