43 lines
845 B
C++
43 lines
845 B
C++
#include "caffe2/sgd/math_lp.h"

#include <cstddef>
#include <vector>
|
|
|
|
namespace caffe2 {
|
|
|
|
namespace internal {
|
|
|
|
template <>
|
|
void dot<float, float, float>(
|
|
const int N,
|
|
const float* x,
|
|
const float* y,
|
|
float* z,
|
|
CPUContext* ctx) {
|
|
math::Dot<float, CPUContext>(N, x, y, z, ctx);
|
|
}
|
|
|
|
template <>
|
|
void dot<float, at::Half, float>(
|
|
const int N,
|
|
const float* x,
|
|
const at::Half* y,
|
|
float* z,
|
|
CPUContext* ctx) {
|
|
#ifdef _MSC_VER
|
|
std::vector<float> tmp_y_vec(N);
|
|
float* tmp_y = tmp_y_vec.data();
|
|
#else
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
|
float tmp_y[N];
|
|
#endif
|
|
for (int i = 0; i < N; i++) {
|
|
#ifdef __F16C__
|
|
tmp_y[i] = _cvtss_sh(y[i], 0); // TODO: vectorize
|
|
#else
|
|
tmp_y[i] = y[i];
|
|
#endif
|
|
}
|
|
math::Dot<float, CPUContext>(N, x, tmp_y, z, ctx);
|
|
}
|
|
|
|
} // namespace internal
|
|
} // namespace caffe2
|