import time
import numpy as np
from tinygrad.helpers import dtypes, getenv, prod, flat_mv
from tinygrad.runtime.ops_hip import HIPAllocator, HIPProgram, compile_hip
# Baseline: MIOpen's GEMM on the same problem size
# AMD_LOG_LEVEL=3 ./MIOpenDriver gemm --iter 1000 --time 1 --a_w 2048 --a_h 2048 --b_w 2048
# gets ~100 TFLOPS:
# hipExtModuleLaunchKernel ( 0x0x16ccde0, 2048, 16, 1, 128, 1, 1,
# 161.60 us = 106.31 TFLOPS
# with --batch_count 8: 1.258128 ms -> (8*2048*2048*2048*2)/(1.258128)*1e-9 = 109.24 TFLOPS
# This kernel only gets ~53 TFLOPS:
# KY=2 KX=2 N=2048 python3 extra/gemm/hip_matmul.py
# 4194304 324.76 us, would be 52899.88 GFLOPS matmul, 154.98 GB/s
# Problem size and per-wave register tiling, overridable from the environment
# (e.g. `KY=2 KX=2 N=2048 python3 extra/gemm/hip_matmul.py`).
N = getenv("N", 2048)
KX = getenv("KX", 4)  # 16x16 WMMA tiles per wave along the row (gx) axis
KY = getenv("KY", 4)  # 16x16 WMMA tiles per wave along the column (gy) axis
# Each workgroup holds 2 waves per axis (local_size [32, 2, 2], selected in the
# kernel by threadIdx.y / threadIdx.z) and each wave covers K*16 elements, so the
# launch grid N // (K*16*2) only tiles the whole matrix when N is a multiple of
# 32*K. Checking only 16*K would silently leave part of the output unwritten.
assert N % (32 * KX) == 0, f"N must be multiple of {32*KX}"
assert N % (32 * KY) == 0, f"N must be multiple of {32*KY}"
FLOPS = N * N * N * 2  # one multiply + one add per (i, j, k)
# Minimum bytes moved: the fp32 output (4 B/elem) plus the two fp16 inputs
# (2 B/elem each), matching the N*N*4 / N*N*2 / N*N*2 device allocations.
BW = N * N * (4 + 2 + 2)
# TODO: can HIPAllocator be initialized with device=0 by default? Until that is
# confirmed, pass the device ordinal explicitly.
device = 0
hipallocator = HIPAllocator(device)
# The kernel is launched as prog(a, b, c) against test(float* c, __half* a, __half* b):
# the Python buffer `a` receives the fp32 result (4 B/elem), while `b` and `c`
# hold the fp16 inputs (2 B/elem).
a = hipallocator.alloc(N * N * 4)
b = hipallocator.alloc(N * N * 2)
c = hipallocator.alloc(N * N * 2)
na = np.empty(N * N, np.float32)  # host buffer for reading back the fp32 result
# Inputs must be cast to fp16: the device buffers are N*N*2 bytes and the kernel
# reads them as __half, so a float32 array would be twice the buffer size.
# `dtype=` on standard_normal is a numpy Generator feature, hence default_rng().
rng = np.random.default_rng()
nb = rng.standard_normal(size=(N, N), dtype=np.float32).astype(np.float16)
nc = rng.standard_normal(size=(N, N), dtype=np.float32).astype(np.float16)
hipallocator.copyin(b, bytearray(nb))
hipallocator.copyin(c, bytearray(nc))
# WMMA GEMM kernel. The source is an f-string: `{{`/`}}` are literal C braces and
# {KX}, {KY}, {N} (and the products of them) are baked in as constants.
#
# - `#define F32` selects fp32 accumulation (float8 accumulators with
#   __builtin_amdgcn_wmma_f32_16x16x16_f16_w32); without it the fp16 builtin is
#   used and every other element of the half16 accumulator is stored.
# - Kernel argument order is (c, a, b): the launch passes the Python buffers
#   (a, b, c), so the fp32 buffer `a` is the output and `b`/`c` are the inputs.
# - gx = blockIdx.x*2 + threadIdx.y picks a KX*16-row tile and
#   gy = blockIdx.y*2 + threadIdx.z a KY*16-column tile; each 32-lane wave
#   (threadIdx.x) accumulates a KX x KY grid of 16x16 fragments over k.
# - Fragment loads use lane = lIdx%16, so lanes 16-31 load the same A/B
#   fragments as lanes 0-15. There is no __shared__ staging; every element is
#   read straight through the a/b pointers.
# - Stores: element `ele` of a lane's accumulator goes to row 2*ele + lIdx/16 and
#   column `lane` of its 16x16 tile (via the (lIdx/16)*N offset and ele*2N stride).
lib = compile_hip(f"""
#define F32
typedef float float8 __attribute__((ext_vector_type(8)));
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
extern "C" __global__ void __launch_bounds__ (128, 1) test(float* c, __half* a, __half* b) {{
  const int gx = blockIdx.x*2 + threadIdx.y;
  const int gy = blockIdx.y*2 + threadIdx.z;

  const int lIdx = threadIdx.x;
  const int lane = lIdx%16;

  c += gx*{KX*16}*{N} + gy*{KY*16} + (lIdx/16)*{N} + lane;
  a += gx*{KX*16}*{N};
  b += gy*{KY*16};

  half16 a_frag[{KX}];
  half16 b_frag[{KY}];
#ifdef F32
  float8 c_frag[{KY}][{KX}] = {{}};
#else
  half16 c_frag[{KY}][{KX}] = {{}};
#endif

  for (int k = 0; k < {N}; k += 16) {{
    for (int ele = 0; ele < 16; ++ele) {{
      for (int x = 0; x < {KX}; x++) {{
        a_frag[x][ele] = a[(k+ele) + x*{16*N} + {N}*lane];
      }}
    }}
    for (int ele = 0; ele < 16; ++ele) {{
      for (int y = 0; y < {KY}; y++) {{
        b_frag[y][ele] = b[(k+ele)*{N} + y*16 + lane];
      }}
    }}
    for (int y = 0; y < {KY}; y++) {{
      for (int x = 0; x < {KX}; x++) {{
#ifdef F32
        c_frag[y][x] = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag[x], b_frag[y], c_frag[y][x]);
#else
        c_frag[y][x] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a_frag[x], b_frag[y], c_frag[y][x], false);
#endif
      }}
    }}
  }}

  for (int ele = 0; ele < 8; ++ele) {{
    for (int y = 0; y < {KY}; y++) {{
      for (int x = 0; x < {KX}; x++) {{
#ifdef F32
        c[ele*{2*N} + y*16 + x*{16*N}] = c_frag[y][x][ele];
#else
        c[ele*{2*N} + y*16 + x*{16*N}] = c_frag[y][x][ele*2];
#endif
      }}
    }}
  }}
}}""")
# Load the compiled code object and bind its `test` entry point on `device`.
prog = HIPProgram(device, "test", lib)
def timeit(fxn, *, verbose=False):
    """Call ``fxn`` once and return whatever it returns.

    ``fxn`` is expected to report its own elapsed time in seconds (the benchmark
    takes the ``min`` of these values as the kernel time). That value does not
    contain the host-side launch overhead. The host wall-clock time of the call,
    which does include it, is also measured and is printed in microseconds when
    ``verbose`` is true.

    Args:
        fxn: zero-argument callable to invoke exactly once.
        verbose: keyword-only; print the wall-clock duration of the call.

    Returns:
        The return value of ``fxn()``, unchanged.
    """
    start = time.perf_counter()
    elapsed = fxn()
    # NOTE: `elapsed` comes from fxn and doesn't contain the launch overhead;
    # `wall` does.
    wall = time.perf_counter() - start
    if verbose:
        print(f"{wall*1e6:.2f} us")
    return elapsed
# local_size [32, 2, 2] = 4 waves of 32 per workgroup, arranged 2x2 via
# threadIdx.y/z. The kernel derives its tile from blockIdx*2 + threadIdx.{y,z},
# so global_size counts workgroups: N // (K*16*2) per axis covers all N rows/cols
# (each wave handles K*16 of them).
global_size, local_size = [N // (KX * 16 * 2), N // (KY * 16 * 2), 1], [32, 2, 2]
print(
    "global/local size",
    global_size,
    local_size,
    f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}",
)
# Best of 1000 synchronous launches (wait=True).
tm = min(
    timeit(lambda: prog(a, b, c, global_size=global_size, local_size=local_size, wait=True))
    for _ in range(1000)
)
# Read back the fp32 result and compare against an fp32 numpy matmul of the same
# fp16 inputs.
hipallocator.copyout(flat_mv(na.data), a)
na = na.reshape(N, N)
comp = nb.astype(np.float32) @ nc.astype(np.float32)
print(
    f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
)
np.testing.assert_allclose(na, comp, atol=1e-2, rtol=1e-2)