1
0
Fork 0

8.46 TFLOPS

pull/570/head
George Hotz 2023-02-19 13:21:25 -08:00
parent 1ba847963d
commit bbfec2fde7
1 changed file with 24 additions and 35 deletions

View File

@@ -27,10 +27,26 @@ prog = CLProgram("test", f"""
#include <metal_stdlib>
#include <metal_simdgroup_matrix>
using namespace metal;
kernel void test(device float *a, device const float *data1, device const float *data2, uint3 gid [[thread_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
uint idx = gid.x/32;
uint pos_x = (idx%{N//32}) * 32;
uint pos_y = (idx/{N//32}) * 32;
kernel void test(device float *a, device const float *data1, device const float *data2, uint3 gid [[thread_position_in_grid]], uint3 xid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint sidx [[simdgroup_index_in_threadgroup]]) {{
// 1-2 simd groups
//uint idx = gid.x/32;
//uint pos_x = (idx%{N//32}) * 32;
//uint pos_y = (idx/{N//32}) * 32;
// 4 simd groups
uint idx = gid.x/128;
uint pos_x = (idx%{N//64}) * 64;
uint pos_y = (idx/{N//64}) * 64;
pos_x += (sidx%2) * 32;
pos_y += (sidx/2) * 32;
// 16 simd groups (slow)
/*uint idx = gid.x/512;
uint pos_x = (idx%{N//128}) * 128;
uint pos_y = (idx/{N//128}) * 128;
pos_x += (sidx%4) * 32;
pos_y += (sidx/4) * 32;*/
simdgroup_float8x8 acc[4][4];
for (uint i = 0; i < 4; i++) {{
for (uint j = 0; j < 4; j++) {{
@@ -41,12 +57,11 @@ kernel void test(device float *a, device const float *data1, device const float
simdgroup_float8x8 B[4];
data1 += pos_x * {N};
data2 += pos_y;
//__metal_get_null_simdgroup_event
//__metal_simdgroup_async_copy_2d
for (uint k = 0; k < {N}; k+=16) {{
for (uint k = 0; k < {N}; k+=8) {{
threadgroup_barrier(mem_flags::mem_threadgroup);
simdgroup_load(A[0], data1, {N}, ulong2(k, 0));
simdgroup_load(A[1], data1, {N}, ulong2(k, 8));
threadgroup_barrier(mem_flags::mem_threadgroup);
simdgroup_load(A[2], data1, {N}, ulong2(k, 16));
simdgroup_load(A[3], data1, {N}, ulong2(k, 24));
simdgroup_load(B[0], data2, {N}, ulong2(0, k));
@@ -70,32 +85,6 @@ kernel void test(device float *a, device const float *data1, device const float
simdgroup_multiply_accumulate(acc[3][1], A[1], B[3], acc[3][1]);
simdgroup_multiply_accumulate(acc[3][2], A[2], B[3], acc[3][2]);
simdgroup_multiply_accumulate(acc[3][3], A[3], B[3], acc[3][3]);
simdgroup_load(A[0], data1, {N}, ulong2(k+8, 0));
simdgroup_load(A[1], data1, {N}, ulong2(k+8, 8));
simdgroup_load(A[2], data1, {N}, ulong2(k+8, 16));
simdgroup_load(A[3], data1, {N}, ulong2(k+8, 24));
simdgroup_load(B[0], data2, {N}, ulong2(0, k+8));
simdgroup_load(B[1], data2, {N}, ulong2(8, k+8));
simdgroup_load(B[2], data2, {N}, ulong2(16, k+8));
simdgroup_load(B[3], data2, {N}, ulong2(24, k+8));
simdgroup_multiply_accumulate(acc[0][0], A[0], B[0], acc[0][0]);
simdgroup_multiply_accumulate(acc[0][1], A[1], B[0], acc[0][1]);
simdgroup_multiply_accumulate(acc[0][2], A[2], B[0], acc[0][2]);
simdgroup_multiply_accumulate(acc[0][3], A[3], B[0], acc[0][3]);
simdgroup_multiply_accumulate(acc[1][0], A[0], B[1], acc[1][0]);
simdgroup_multiply_accumulate(acc[1][1], A[1], B[1], acc[1][1]);
simdgroup_multiply_accumulate(acc[1][2], A[2], B[1], acc[1][2]);
simdgroup_multiply_accumulate(acc[1][3], A[3], B[1], acc[1][3]);
simdgroup_multiply_accumulate(acc[2][0], A[0], B[2], acc[2][0]);
simdgroup_multiply_accumulate(acc[2][1], A[1], B[2], acc[2][1]);
simdgroup_multiply_accumulate(acc[2][2], A[2], B[2], acc[2][2]);
simdgroup_multiply_accumulate(acc[2][3], A[3], B[2], acc[2][3]);
simdgroup_multiply_accumulate(acc[3][0], A[0], B[3], acc[3][0]);
simdgroup_multiply_accumulate(acc[3][1], A[1], B[3], acc[3][1]);
simdgroup_multiply_accumulate(acc[3][2], A[2], B[3], acc[3][2]);
simdgroup_multiply_accumulate(acc[3][3], A[3], B[3], acc[3][3]);
}}
for (uint i = 0; i < 4; i++) {{
for (uint j = 0; j < 4; j++) {{
@@ -103,7 +92,7 @@ kernel void test(device float *a, device const float *data1, device const float
}}
}}
}}""")
tm = mb(lambda: prog([N*N//(2*4*4)], [2*32], a._cl, b._cl, c._cl))
tm = mb(lambda: prog([N*N//(2*4*4)], [4*32], a._cl, b._cl, c._cl))
na = a.toCPU().reshape(N,N)
comp = nb@nc
if N <= 32: