// openpilot/selfdrive/modeld/thneed/debug/microbenchmark/gemm.cl
//
// 52 lines
// 1.1 KiB
// OpenCL C (the file viewer labelled this "Common Lisp", presumably because of the .cl extension)

// https://github.com/moskewcz/boda/issues/13
//
// Precision configuration for the gemm kernel below.
//   USE_FP16 defined:   accumulators are half8; up()/down() are identity.
//   USE_FP16 undefined: accumulators are float8; up() converts loaded half8
//                       values to float8 and down() converts results back to
//                       half8 before they are stored.
#define USE_FP16
#ifdef USE_FP16
// Identity conversions. The expansion is parenthesized so that an expression
// argument such as up(a + b) keeps its grouping at the use site.
#define up(x) (x)
#define down(x) (x)
#define xtype half8
#else
#define up(x) convert_float8(x)
#define down(x) convert_half8(x)
#define xtype float8
#endif
// Step, in half8 elements, that the kernel adds to its a/b read offsets on
// every loop iteration and to its c write offset after every stored row.
// The value is the same for both precisions, so it is defined once.
#define skip 128
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
/*
 * gemm microbenchmark kernel.
 *
 * Each work-item keeps eight vector accumulators. On each of 1024 iterations
 * it loads one half8 from `a` and one half8 from `b`, then adds to
 * accumulator j the product of lane j of the `a` vector with the whole `b`
 * vector. The eight accumulators are then written to `c`.
 *
 * M, N, K: accepted but not read by this kernel; the trip count (1024) and
 *          the strides are hard-coded. Presumably the caller is expected to
 *          launch it for a 1024x1024x1024 problem -- TODO confirm.
 * a, b:    input buffers of half8, read at (global id + k * skip).
 * c:       output buffer of half8, written at
 *          (global_id(0) * 1024 + global_id(1) + row * skip), row = 0..7.
 *
 * xtype, up(), down() and skip are macros defined earlier in this file.
 * No bounds checks are performed on any buffer access.
 */
__kernel void gemm(const int M, const int N, const int K,
                   global const half8* a, global const half8* b, global half8* c)
{
  const int gid_a = get_global_id(0);
  const int gid_b = get_global_id(1);

  // One accumulator per lane of the vector loaded from `a`.
  xtype acc[8] = {0,0,0,0,0,0,0,0};

  // Both read cursors start at the work-item's own id and advance by `skip`
  // elements per iteration.
  int a_idx = gid_a;
  int b_idx = gid_b;
  for (int step = 0; step < 1024; ++step, a_idx += skip, b_idx += skip) {
    const xtype a_vec = up(a[a_idx]);
    const xtype b_vec = up(b[b_idx]);

    // Scalar lane of a_vec times the full b_vec, one accumulator per lane.
    acc[0] += a_vec.s0 * b_vec;
    acc[1] += a_vec.s1 * b_vec;
    acc[2] += a_vec.s2 * b_vec;
    acc[3] += a_vec.s3 * b_vec;
    acc[4] += a_vec.s4 * b_vec;
    acc[5] += a_vec.s5 * b_vec;
    acc[6] += a_vec.s6 * b_vec;
    acc[7] += a_vec.s7 * b_vec;
  }

  // Store the eight accumulators `skip` elements apart, starting at this
  // work-item's base offset in `c`.
  const int c_base = gid_a * 1024 + gid_b;
  for (int row = 0; row < 8; ++row) {
    c[c_base + row * skip] = down(acc[row]);
  }
}