diff --git a/extra/gemm/cuda_matmul.py b/extra/gemm/cuda_matmul.py
index 84590f83b..1aa089635 100644
--- a/extra/gemm/cuda_matmul.py
+++ b/extra/gemm/cuda_matmul.py
@@ -1,7 +1,7 @@
 import os
 import numpy as np
 os.environ["CUDA"] = "1"
-from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram
+from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram, compile_cuda
 
 FLOAT16 = True
 ACC_FLOAT16 = False
@@ -21,7 +21,7 @@ c = RawCUDABuffer.fromCPU(np.ones((N,N),dtype=np.float32))
 FLOPS = N*N*N*2
 BW = N*N*3*4
 
-prog = CUDAProgram("wmma_example", f"""
+prog = CUDAProgram("wmma_example", compile_cuda(f"""
 #include <mma.h>
 using namespace nvcuda;
 
@@ -88,9 +88,10 @@ __global__ void wmma_example({'half' if FLOAT16 else 'float'} *a, {'half' if FLO
     }}
   }}
 }}
-""")
+"""))
 
-tm = min([prog([(N//16*32)//4, (N//16)//4], [32, 1], a, b, c, wait=True) for _ in range(20)])
+global_size, local_size = [(N//16)//4, (N//16)//4], [32, 1, 1]
+tm = min([prog(a, b, c, global_size=global_size, local_size=local_size, wait=True) for _ in range(20)])
 print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
 
 np.testing.assert_allclose(na.T.astype(np.float32) @ nb.T.astype(np.float32), c.toCPU().reshape((N,N)).T, atol=1e-2)
\ No newline at end of file
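
Note: a minimal sketch of the calling convention this diff moves to, where source is precompiled with compile_cuda and launch dimensions are passed as global_size/local_size keyword arguments. The vecadd kernel, the 1-D buffers, and the sizes are illustrative assumptions, not part of this PR; only the API shape (compile_cuda, CUDAProgram, RawCUDABuffer.fromCPU/toCPU, and the keyword launch args) mirrors the diff above.

import os
import numpy as np
os.environ["CUDA"] = "1"
from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram, compile_cuda

N = 1024  # hypothetical problem size for this sketch
a = RawCUDABuffer.fromCPU(np.random.rand(N).astype(np.float32))
b = RawCUDABuffer.fromCPU(np.random.rand(N).astype(np.float32))
c = RawCUDABuffer.fromCPU(np.zeros(N, dtype=np.float32))

# CUDAProgram now takes precompiled code: compile_cuda(src) instead of raw source
prog = CUDAProgram("vecadd", compile_cuda("""
__global__ void vecadd(float *a, float *b, float *c) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  c[i] = a[i] + b[i];
}
"""))

# launch dimensions are keyword arguments now, not positional lists before the buffers
global_size, local_size = [N // 32, 1, 1], [32, 1, 1]
tm = prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)
print(f"vecadd took {tm * 1e6:.2f} us")
np.testing.assert_allclose(a.toCPU() + b.toCPU(), c.toCPU(), atol=1e-6)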