pytorch/caffe2/sgd/math_lp.cc

43 lines
845 B
C++

#include "caffe2/sgd/math_lp.h"

#include <vector>
namespace caffe2 {
namespace internal {
// Specialization of dot for float x, float y and float z.
//
// This is a pure forwarding wrapper: all five arguments are handed unchanged
// to math::Dot<float, CPUContext>, which is responsible for writing through z.
// NOTE(review): presumably z receives the length-N dot product of x and y —
// confirm against math::Dot's documentation.
template <>
void dot<float, float, float>(
const int N,
const float* x,
const float* y,
float* z,
CPUContext* ctx) {
// No type conversion is needed here since both inputs are already float.
math::Dot<float, CPUContext>(N, x, y, z, ctx);
}
// Specialization of dot for float x, at::Half y and float z.
//
// The N elements of y are first widened to float in a temporary buffer; the
// float buffer is then forwarded, together with N, x, z and ctx, to
// math::Dot<float, CPUContext>, which is responsible for writing through z.
//
// N    number of elements read from x and y.
// x    float input.
// y    half-precision input, converted element by element to float.
// z    output pointer passed through to math::Dot.
// ctx  CPU context passed through to math::Dot.
template <>
void dot<float, at::Half, float>(
    const int N,
    const float* x,
    const at::Half* y,
    float* z,
    CPUContext* ctx) {
  // Heap-backed scratch buffer on every platform. A variable-length array is
  // a compiler extension rather than standard C++ and can overflow the stack
  // for large N. A non-positive N simply yields an empty buffer.
  using BufferSize = std::vector<float>::size_type;
  std::vector<float> y_as_float(N > 0 ? static_cast<BufferSize>(N) : 0);

  for (int i = 0; i < N; ++i) {
    // Rely on at::Half's own conversion to float. The F16C intrinsic
    // _cvtss_sh converts in the opposite direction (float -> half bit
    // pattern), so using it here would store raw half bits as float values.
    y_as_float[i] = static_cast<float>(y[i]);
  }

  math::Dot<float, CPUContext>(N, x, y_as_float.data(), z, ctx);
}
} // namespace internal
} // namespace caffe2