52 lines
1.1 KiB
Common Lisp
52 lines
1.1 KiB
Common Lisp
// https://github.com/moskewcz/boda/issues/13
|
|
|
|
#define USE_FP16
|
|
|
|
#ifdef USE_FP16
|
|
#define up(x) x
|
|
#define down(x) x
|
|
#define xtype half8
|
|
#define skip 128
|
|
#else
|
|
#define up(x) convert_float8(x)
|
|
#define down(x) convert_half8(x)
|
|
#define xtype float8
|
|
#define skip 128
|
|
#endif
|
|
|
|
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
__kernel void gemm(const int M, const int N, const int K,
|
|
global const half8* a, global const half8* b, global half8* c )
|
|
{
|
|
xtype c_r[8] = {0,0,0,0,0,0,0,0};
|
|
|
|
int const a_off_thr = get_global_id(0);
|
|
int const b_off_thr = get_global_id(1);
|
|
|
|
int a_off = a_off_thr;
|
|
int b_off = b_off_thr;
|
|
for( int k = 0; k < 1024; k += 1 ) {
|
|
xtype a_r = up(a[a_off]);
|
|
xtype b_r = up(b[b_off]);
|
|
|
|
c_r[0] += a_r.s0*b_r;
|
|
c_r[1] += a_r.s1*b_r;
|
|
c_r[2] += a_r.s2*b_r;
|
|
c_r[3] += a_r.s3*b_r;
|
|
c_r[4] += a_r.s4*b_r;
|
|
c_r[5] += a_r.s5*b_r;
|
|
c_r[6] += a_r.s6*b_r;
|
|
c_r[7] += a_r.s7*b_r;
|
|
|
|
a_off += skip;
|
|
b_off += skip;
|
|
}
|
|
|
|
int c_off = a_off_thr*1024 + b_off_thr;
|
|
for (int i = 0; i < 8; i++) {
|
|
c[c_off] = down(c_r[i]);
|
|
c_off += skip;
|
|
}
|
|
}
|
|
|