
Update cuda_matmul.py (#2495)

pull/2484/head
Jake 2023-11-28 22:46:01 -05:00 committed by GitHub
parent cdc3b95729
commit 5588922884
1 changed file with 5 additions and 4 deletions


@@ -1,7 +1,7 @@
 import os
 import numpy as np
 os.environ["CUDA"] = "1"
-from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram
+from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram, compile_cuda
 FLOAT16 = True
 ACC_FLOAT16 = False
@@ -21,7 +21,7 @@ c = RawCUDABuffer.fromCPU(np.ones((N,N),dtype=np.float32))
 FLOPS = N*N*N*2
 BW = N*N*3*4
-prog = CUDAProgram("wmma_example", f"""
+prog = CUDAProgram("wmma_example", compile_cuda(f"""
 #include <mma.h>
 using namespace nvcuda;
@@ -88,9 +88,10 @@ __global__ void wmma_example({'half' if FLOAT16 else 'float'} *a, {'half' if FLO
 }}
 }}
 }}
-""")
+"""))
-tm = min([prog([(N//16*32)//4, (N//16)//4], [32, 1], a, b, c, wait=True) for _ in range(20)])
+global_size, local_size = [(N//16)//4, (N//16)//4], [32, 1, 1]
+tm = min([prog(a, b, c, global_size=global_size, local_size=local_size, wait=True) for _ in range(20)])
 print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
 np.testing.assert_allclose(na.T.astype(np.float32) @ nb.T.astype(np.float32), c.toCPU().reshape((N,N)).T, atol=1e-2)
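
For context, here is a minimal sketch of the call pattern this change moves to: CUDA C source is compiled explicitly with compile_cuda before being passed to CUDAProgram, and launch dimensions are supplied as global_size/local_size keyword arguments after the buffers. It assumes the tinygrad ops_cuda API behaves as shown in the diff above; the fill_ones kernel and buffer shape are made up for illustration, and a CUDA device plus a matching tinygrad checkout are required to run it.

# Sketch only: assumes RawCUDABuffer, CUDAProgram, and compile_cuda work as in the diff above.
import os
import numpy as np
os.environ["CUDA"] = "1"
from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram, compile_cuda

N = 32
out = RawCUDABuffer.fromCPU(np.zeros(N, dtype=np.float32))

# Source is compiled with compile_cuda first, instead of handing CUDAProgram the raw string.
prog = CUDAProgram("fill_ones", compile_cuda("""
__global__ void fill_ones(float *out) { out[threadIdx.x] = 1.0f; }
"""))

# Launch dimensions are keyword arguments following the buffers, not positional lists before them.
prog(out, global_size=[1, 1, 1], local_size=[N, 1, 1], wait=True)
print(out.toCPU())  # expected: an array of ones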