# tinygrab/extra/thneed.py

# this can be constructed from a cl_cache or loaded from a thneed file
import time
import struct
import json
import traceback
import numpy as np
from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu
from tinygrad.device import Device
from tinygrad.helpers import DEBUG, getenv
from collections import defaultdict
import pyopencl as cl
from tinygrad.runtime.ops_gpu import OSX_TIMING_RATIO
# GPU device handle from the Device registry; its .ctx and .queue are used for the OpenCL calls in Thneed
CL = Device["GPU"]
# DEBUGCL env var (default 0): in Thneed.run, >=1 prints per-kernel runtimes and >=2 also prints event timestamps
DEBUGCL = getenv("DEBUGCL", 0)
# FLOAT16 env var (default 0): when set, Thneed.load/save use half-float RGBA images instead of float32 ones
FLOAT16 = getenv("FLOAT16", 0)
class Thneed:
    """A recorded list of OpenCL kernel launches that can be saved, loaded and replayed.

    ``cl_cache`` is a list of ``(program, args)`` pairs where ``args`` is
    ``[global_work_size, local_work_size, *kernel_args]``; the first kernel
    argument (``args[2]``) is treated as the kernel's output.

    On-disk layout (see ``save`` / ``load``): a 4-byte unsigned length in native
    byte order, that many bytes of latin-1 encoded JSON, the raw data of every
    object flagged ``needs_load`` (in ``objects`` order), then the program
    binaries (in ``binaries`` order).
    """

    # OpenCL source used by ``save`` to copy an image into a row-pitched float4 buffer.
    _FROM_IMAGE_STRIDED_SRC = """
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 l;
l.y = get_global_id(1);
l.x = get_global_id(0);
out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
}
"""

    def __init__(self, cl_cache=None, inputs=None):
        """Build the buffer dependency graph for a recorded kernel list.

        Args:
            cl_cache: list of ``(program, args)`` launches; it is shallow-copied.
                Defaults to an empty list.
            inputs: dict mapping input name -> buffer. When given, the dict is
                stored by reference and entries whose buffer is never consumed
                by a kernel are deleted from it (with a printed warning).
                Defaults to a fresh empty dict.

        Sets ``self.buffers_to_save`` (buffers that are read but never written
        and are not inputs) and ``self.outputs`` (buffers that are written but
        never read).
        """
        # None sentinels instead of mutable defaults: the old shared `{}` default
        # was mutated by __init__/load and leaked state between instances.
        self.cl_cache = [] if cl_cache is None else cl_cache[:]
        self.inputs = {} if inputs is None else inputs
        self.gobj = 0  # next id handed out by save() to a not-yet-seen object

        # build graph
        # NOTE: if CLCACHE=1, this is wrong!
        nodes = defaultdict(lambda: {"in_edges": [], "out_edges": []})
        for _, args in self.cl_cache:
            # output is always the first kernel parameter (args[0:2] are the work sizes)
            out = args[2]
            for a in args[3:]:
                nodes[a]["out_edges"].append(out)
                nodes[out]["in_edges"].append(a)

        # get buffers to save: nothing writes them, so their contents must come from the file
        self.buffers_to_save = set()
        self.outputs = []
        for n, edges in nodes.items():
            if not edges["in_edges"]:
                self.buffers_to_save.add(n)
            if not edges["out_edges"]:
                self.outputs.append(n)

        # inputs are supplied by the caller at run time, so they are not saved
        fake_inputs = []
        for k, n in self.inputs.items():
            if n in self.buffers_to_save:
                self.buffers_to_save.remove(n)
            else:
                print(f"WARNING: {k} was not a used input, removing it")
                fake_inputs.append(k)
        # deleted in a second pass so the dict is not mutated while iterating it
        for k in fake_inputs:
            del self.inputs[k]

    def load(self, input_fn):
        """Load a thneed file and append its contents to this instance.

        Recreates every object as a ``cl.Buffer`` / ``cl.Image``, builds a
        ``CLProgram`` for each binary, appends one ``cl_cache`` entry per
        kernel and fills ``self.inputs`` / ``self.outputs``.

        Args:
            input_fn: path of the thneed file to read.

        Raises:
            ValueError: if a kernel argument has a length other than 0, 2, 4 or 8.
        """
        float32 = not FLOAT16
        mf = cl.mem_flags
        image_fmt = cl.ImageFormat(
            cl.channel_order.RGBA,
            cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT,
        )
        image_fmt_32 = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT)

        with open(input_fn, "rb") as f:
            json_len = struct.unpack("I", f.read(4))[0]
            jdat = json.loads(f.read(json_len).decode("latin_1"))
            weights = f.read()  # object data followed by program binaries

        # load in the buffers; the all-zero id maps to "no buffer"
        bufs = {"\x00\x00\x00\x00\x00\x00\x00\x00": None}
        bufs_loaded = {}  # object id -> whether the object carried data from the file
        ptr = 0  # read cursor into `weights`
        for o in jdat["objects"]:
            if o["needs_load"]:
                nptr = ptr + o["size"]
                o["data"] = weights[ptr:nptr]
                ptr = nptr

            if o["arg_type"] in ("image2d_t", "image1d_t"):
                # an object can force float32 regardless of the FLOAT16 setting
                tfmt = image_fmt_32 if o.get("float32") else image_fmt
                if o["arg_type"] == "image2d_t":
                    if (
                        "buffer_id" in o
                        and o["height"] == 1
                        and not bufs_loaded[o["buffer_id"]]
                    ):
                        # hack: use a image1d since we can back that with a buffer
                        buf = cl.Image(
                            CL.ctx,
                            mf.READ_WRITE,
                            tfmt,
                            shape=(o["width"],),
                            buffer=bufs[o["buffer_id"]],
                        )
                    elif "buffer_id" in o and bufs_loaded[o["buffer_id"]]:
                        # buffer isn't supported in image2d, copy buffer into image
                        # NOTE(review): the staging array is always float16, even when
                        # tfmt is the float32 format - confirm this is intended.
                        arr = np.zeros(
                            bufs[o["buffer_id"]].size // 2, dtype=np.float16
                        )
                        cl.enqueue_copy(CL.queue, arr, bufs[o["buffer_id"]])
                        buf = cl.Image(
                            CL.ctx,
                            mf.READ_WRITE | mf.COPY_HOST_PTR,
                            tfmt,
                            shape=(o["width"], o["height"]),
                            pitches=(o["row_pitch"],),
                            hostbuf=arr,
                        )
                    elif o["needs_load"]:
                        buf = cl.Image(
                            CL.ctx,
                            mf.READ_WRITE | mf.COPY_HOST_PTR,
                            tfmt,
                            shape=(o["width"], o["height"]),
                            pitches=(o["row_pitch"],),
                            hostbuf=o["data"],
                        )
                    else:
                        buf = cl.Image(
                            CL.ctx,
                            mf.READ_WRITE,
                            tfmt,
                            shape=(o["width"], o["height"]),
                        )
                else:
                    # image1d_t: always a view over an existing, not-loaded buffer
                    assert not o["needs_load"]
                    assert not bufs_loaded[o["buffer_id"]]
                    buf = cl.Image(
                        CL.ctx,
                        mf.READ_WRITE,
                        tfmt,
                        shape=(o["width"],),
                        buffer=bufs[o["buffer_id"]],
                    )
            elif "data" in o:
                buf = cl.Buffer(
                    CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o["data"]
                )
            else:
                # zero out buffers
                buf = cl.Buffer(
                    CL.ctx,
                    mf.READ_WRITE | mf.COPY_HOST_PTR,
                    hostbuf=b"\x00" * o["size"],
                )

            bufs[o["id"]] = buf
            bufs_loaded[o["id"]] = "data" in o
            # if it's loaded, it's saved
            if "data" in o:
                self.buffers_to_save.add(buf)

        # load binaries (stored right after the object data, so `ptr` keeps advancing)
        prgs = {}
        for o in jdat["binaries"]:
            nptr = ptr + o["length"]
            prgs[o["name"]] = CLProgram(Device["GPU"], o["name"], weights[ptr:nptr])
            ptr = nptr

        # populate the cl_cache
        for k in jdat["kernels"]:
            kernel = prgs[k["name"]]
            aaa = []
            for a, sz in zip(k["args"], k["args_size"]):
                # the encoded length of an argument identifies its kind
                if len(a) == 0:
                    aa = cl.LocalMemory(sz)
                elif len(a) == 4:
                    aa = np.uint32(struct.unpack("I", a.encode("latin_1"))[0])
                elif len(a) == 2:
                    aa = np.uint16(struct.unpack("H", a.encode("latin_1"))[0])
                elif len(a) == 8:
                    aa = bufs[a]  # 8 bytes: id of a previously created object
                else:
                    # previously this silently reused the preceding argument
                    raise ValueError(
                        f"unsupported kernel argument of length {len(a)} in kernel {k['name']}"
                    )
                aaa.append(aa)
            self.cl_cache.append(
                (kernel, [k["global_work_size"], k["local_work_size"], *aaa])
            )

        if DEBUG >= 1:
            print(f"thneed: total bufs loaded: {len(bufs.keys())}")

        # load inputs
        for k in jdat["inputs"]:
            self.inputs[k["name"]] = bufs[k["buffer_id"]]

        # load outputs
        for k in jdat["outputs"]:
            self.outputs.append(bufs[k["buffer_id"]])

    def save(self, output_fn):
        """Serialize the recorded kernels, their objects and binaries to a thneed file.

        Every buffer/image argument gets a ``global_id`` attribute on first
        sight (ids come from ``self.gobj``). Objects in ``self.buffers_to_save``
        are read back from the device and written into the weights blob.

        Args:
            output_fn: path of the file to write.

        Raises:
            Exception: if a kernel argument is neither a 2/4-byte numpy integer,
                a ``cl.LocalMemory``, a ``cl.Buffer`` nor a ``cl.Image``.
            AssertionError: if a program has more than one binary or an image
                does not have the expected RGBA format.
        """
        # this is the struct that will be saved
        jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []}

        # build the pieces of this struct
        weights = []
        binaries = []
        saved_objs = set()
        saved_binaries = set()
        channel_bytes = 2 if FLOAT16 else 4  # bytes per image channel
        strided_copy = None  # image->buffer kernel, compiled once on first use

        for prg, args in self.cl_cache:
            # get binaries for saving
            if prg.name not in saved_binaries:
                binary = prg.clprogram.get_info(cl.program_info.BINARIES)
                assert len(binary) == 1
                jdat["binaries"].append({"name": prg.name, "length": len(binary[0])})
                binaries.append(binary[0])
                saved_binaries.add(prg.name)

            # get the args from the kernel, some need the data saved
            targs, args_size = [], []
            for a in args[2:]:
                if isinstance(a, np.integer) and a.itemsize in (2, 4):
                    # scalar arguments (load() creates np.uint16 / np.uint32); the mask
                    # gives the two's-complement bit pattern for negative signed values
                    if a.itemsize == 2:
                        packed = struct.pack("H", int(a) & 0xFFFF)
                    else:
                        packed = struct.pack("I", int(a) & 0xFFFFFFFF)
                    targs.append(packed.decode("latin_1"))
                    args_size.append(a.itemsize)
                elif isinstance(a, cl.LocalMemory):
                    targs.append("")
                    args_size.append(a.size)
                else:
                    # memory object: referenced by an 8-byte id, described once in "objects"
                    if getattr(a, "global_id", None) is None:
                        setattr(a, "global_id", self.gobj)
                        self.gobj += 1
                    ptr = struct.pack("Q", a.global_id).decode("latin_1")
                    if ptr not in saved_objs:
                        if isinstance(a, cl.Buffer):
                            needs_load = a in self.buffers_to_save
                            jdat["objects"].append(
                                {
                                    "id": ptr,
                                    "arg_type": "float*",
                                    "needs_load": needs_load,
                                    "size": a.size,
                                }
                            )
                            if needs_load:
                                # raw byte readback so exactly a.size bytes are written,
                                # also when the size is not a multiple of 4
                                data = np.empty(a.size, dtype=np.uint8)
                                cl.enqueue_copy(CL.queue, data, a, is_blocking=True)
                                weights.append(data.tobytes())
                        elif isinstance(a, cl.Image):
                            assert a.format == cl.ImageFormat(
                                cl.channel_order.RGBA,
                                cl.channel_type.HALF_FLOAT
                                if FLOAT16
                                else cl.channel_type.FLOAT,
                            ), "wrong type"
                            needs_load = a in self.buffers_to_save
                            # bytes per row, rounded up to a multiple of 64
                            # (multiple of 32 isn't enough)
                            row_pitch = (
                                (a.shape[0] * 4 * channel_bytes + 63) // 64 * 64
                            )
                            size = row_pitch * a.shape[1]
                            jdat["objects"].append(
                                {
                                    "id": ptr,
                                    "needs_load": needs_load,
                                    "size": size,
                                    "arg_type": "image2d_t",
                                    "width": a.shape[0],
                                    "height": a.shape[1],
                                    "row_pitch": row_pitch,
                                    "float32": not FLOAT16,
                                }
                            )
                            if needs_load:
                                # the kernel writes float32 pixels, so the staging buffer
                                # is twice the saved size when images are half-float
                                buf = cl.Buffer(
                                    CL.ctx,
                                    cl.mem_flags.READ_WRITE,
                                    size=size * (2 if FLOAT16 else 1),
                                )
                                # zero out the buffer (row padding is never written by the kernel)
                                cl.enqueue_copy(
                                    CL.queue, buf, b"\x00" * buf.size, is_blocking=True
                                )
                                if strided_copy is None:
                                    strided_copy = CLProgram(
                                        CL,
                                        "from_image_strided",
                                        compile_gpu(self._FROM_IMAGE_STRIDED_SRC),
                                        bufs=2,
                                        vars=1,
                                    )
                                # the kernel indexes float4 pixels, so pass the pitch in pixels
                                strided_copy(
                                    a,
                                    buf,
                                    row_pitch // (4 * channel_bytes),
                                    global_size=a.shape,
                                )
                                data = np.empty(size // channel_bytes, dtype=np.float32)
                                cl.enqueue_copy(CL.queue, data, buf, is_blocking=True)
                                if FLOAT16:
                                    data = data.astype(np.float16)
                                weights.append(data.tobytes())
                        else:
                            raise Exception("unknown object", a)
                        saved_objs.add(ptr)
                    targs.append(ptr)
                    args_size.append(8)

            # save the kernel itself
            jdat["kernels"].append(
                {
                    "name": prg.name,
                    "work_dim": len(args[0]),
                    "global_work_size": args[0],
                    # TODO: C++ thneed requires a local_work_size, so we fill it with ones
                    "local_work_size": [1 for _ in args[0]]
                    if args[1] is None
                    else args[1],
                    "num_args": len(args) - 2,
                    "args": targs,
                    "args_size": args_size,
                }
            )

        jdat["outputs"] = [
            {
                "buffer_id": struct.pack("Q", x.global_id).decode("latin_1"),
                "size": x.size,
            }
            for x in self.outputs
        ]

        # inputs are written in reverse of the dict's iteration order
        jdat["inputs"] = [
            {
                "buffer_id": struct.pack("Q", v.global_id).decode("latin_1"),
                "size": v.size,
                "name": k,
            }
            for k, v in self.inputs.items()
        ][::-1]

        print(f"saving thneed to {output_fn}")
        with open(output_fn, "wb") as f:
            j = json.dumps(jdat, ensure_ascii=False).encode("latin_1")
            f.write(struct.pack("I", len(j)))
            f.write(j)
            f.write(b"".join(weights))
            f.write(b"".join(binaries))

    def run(self):
        """Enqueue every recorded kernel, wait for completion and report timing.

        Returns:
            With ``DEBUGCL >= 1``: the sum of the per-kernel profiled runtimes
            divided by 1e9 (seconds, assuming the scaled profile counters are
            nanoseconds). Otherwise: the wall-clock time in seconds measured
            with ``time.monotonic`` around submission and synchronization.
        """
        events = []
        st = time.monotonic()
        for prg, args in self.cl_cache:
            events.append(prg.clprg(CL.queue, *args))
        mt = time.monotonic()  # all kernels submitted
        Device["GPU"].synchronize()
        et = time.monotonic() - st
        print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")

        if DEBUGCL >= 2:
            # per-event timestamps, printed relative to the submission start time
            for i, ((prg, _), e) in enumerate(zip(self.cl_cache, events)):
                print(
                    f"{i:3d} {prg.name:25s} "
                    + "queued @ %5.2f ms, submit @ %5.2fms, start @ %5.2f ms, end @ %5.2f ms"
                    % tuple(
                        (x * OSX_TIMING_RATIO - st * 1e9) / 1e6
                        for x in [
                            e.profile.queued,
                            e.profile.submit,
                            e.profile.start,
                            e.profile.end,
                        ]
                    )
                )

        if DEBUGCL >= 1:
            total_runtime = 0
            for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
                runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
                print(
                    f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:25s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}"
                )
                # optionally dump the kernel: always at DEBUGCL >= 3, or only the
                # kernel selected by PRINT_KERNEL at DEBUGCL >= 2
                if hasattr(prg, "prg") and (
                    (DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3
                ):
                    print(prg.prg)
                total_runtime += runtime
            print(
                f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms"
            )
            return total_runtime / 1e9
        return et