metal_matmul: bw and torch sync
parent
bd6c3c31a9
commit
68e45fca18
|
@ -13,6 +13,7 @@ b = RawMetalBuffer.fromCPU(nb)
|
|||
c = RawMetalBuffer.fromCPU(nc)
|
||||
|
||||
FLOPS = N*N*N*2
|
||||
BW = N*N*3
|
||||
|
||||
prog = MetalProgram("test", f"""
|
||||
#include <metal_stdlib>
|
||||
|
@ -89,17 +90,17 @@ comp = nb@nc
|
|||
if N <= 32:
|
||||
print(na)
|
||||
print(comp)
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:.2f} GFLOPS matmul")
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
|
||||
np.testing.assert_allclose(na, comp, atol=1e-3)
|
||||
|
||||
import time, torch
|
||||
import time, torch, torch.mps
|
||||
b = torch.from_numpy(nb).to('mps')
|
||||
c = torch.from_numpy(nc).to('mps')
|
||||
|
||||
def torch_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = b@c
|
||||
a.cpu()
|
||||
torch.mps.synchronize()
|
||||
return time.perf_counter() - st
|
||||
tm = min([torch_prog(b, c) for _ in range(20)])
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:.2f} GFLOPS matmul in torch")
|
||||
|
|
Loading…
Reference in New Issue