1
0
Fork 0

metal_matmul: bw and torch sync

pull/726/head
George Hotz 2023-03-23 08:02:04 -07:00
parent bd6c3c31a9
commit 68e45fca18
1 changed files with 4 additions and 3 deletions

View File

@ -13,6 +13,7 @@ b = RawMetalBuffer.fromCPU(nb)
c = RawMetalBuffer.fromCPU(nc)
FLOPS = N*N*N*2
BW = N*N*3
prog = MetalProgram("test", f"""
#include <metal_stdlib>
@ -89,17 +90,17 @@ comp = nb@nc
if N <= 32:
print(na)
print(comp)
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:.2f} GFLOPS matmul")
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
np.testing.assert_allclose(na, comp, atol=1e-3)
import time, torch
import time, torch, torch.mps
b = torch.from_numpy(nb).to('mps')
c = torch.from_numpy(nc).to('mps')
def torch_prog(b, c):
st = time.perf_counter()
a = b@c
a.cpu()
torch.mps.synchronize()
return time.perf_counter() - st
tm = min([torch_prog(b, c) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:.2f} GFLOPS matmul in torch")