
GPU llops

George Hotz 2022-06-05 13:49:39 -07:00
parent f0fe37bd34
commit 7a3fe34db1
3 changed files with 158 additions and 142 deletions


@@ -1,18 +1,18 @@
Getting the core instruction set correct is the value of tinygrad
Tensors are at most 6-D, as required by pool2d
Unary Ops
===
These are the simplest to reason about, and have pointwise memory access.
A and B are always the same size (see the sketch below).
Forward: A -> B
Backward (binary): (B', A) -> A'
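A minimal numpy sketch of this contract, with ReLU standing in for a generic unary op (relu_forward / relu_backward are illustrative names, not tinygrad APIs):

import numpy as np

def relu_forward(A):            # Forward: A -> B, same shape, pointwise
  return np.maximum(A, 0.0)

def relu_backward(dB, A):       # Backward (binary): (B', A) -> A', also pointwise
  return dB * (A > 0.0).astype(A.dtype)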
Reduce Ops (with axis)
===

tinygrad/llops/gpu.py (new file, 141 lines)

@@ -0,0 +1,141 @@
# llops don't know about derivatives
import functools
import numpy as np
import pyopencl as cl
from tinygrad.helpers import binary_broadcast

i32 = np.int32

cl_ctx, cl_queue = None, None
def require_init_gpu():
  global cl_ctx, cl_queue
  if cl_ctx is None:
    devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.GPU)
    if len(devices) == 0:
      devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.CPU)
    cl_ctx = cl.Context(devices=devices)
    # this is an in-order command queue
    cl_queue = cl.CommandQueue(cl_ctx)

class GPUBuffer:
  def __init__(self, shape, hostbuf=None):
    require_init_gpu()
    self.shape, self.dtype = tuple(shape), np.float32
    self.cl = hostbuf.cl if isinstance(hostbuf, GPUBuffer) else \
      cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | (cl.mem_flags.COPY_HOST_PTR if hostbuf is not None else 0), 4*np.prod(shape),
                hostbuf=hostbuf.astype(np.float32).ravel() if hostbuf is not None else None)

  def __repr__(self):
    return f"<GPUBuffer with shape {self.shape!r}>"

  @staticmethod
  def fromCPU(x):
    return GPUBuffer(x.shape, x.view(np.ndarray))

  def toCPU(self):
    data = np.empty(self.shape, dtype=np.float32)
    cl_queue.finish()
    cl.enqueue_copy(cl_queue, data, self.cl, is_blocking=True)
    return data
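A usage sketch, assuming an OpenCL device (GPU, or the CPU fallback) is available: fromCPU uploads a numpy array and toCPU blocks and reads it back.

import numpy as np
from tinygrad.llops.gpu import GPUBuffer   # module path introduced by this commit

x = np.arange(6, dtype=np.float32).reshape(2, 3)
buf = GPUBuffer.fromCPU(x)                  # allocates a cl.Buffer and copies the host data in
print(buf)                                  # <GPUBuffer with shape (2, 3)>
np.testing.assert_allclose(buf.toCPU(), x)  # blocking copy back to a host numpy array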
def buffer_new(ctx, shape, zero=False):
  return GPUBuffer(shape, hostbuf=None if not zero else np.zeros(shape, dtype=np.float32))

def buffer_np(ctx, x):
  return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x)

def clbuffer(hostbuf, shape):
  return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | (cl.mem_flags.COPY_HOST_PTR if hostbuf is not None else 0),
                   4*np.prod(shape),
                   hostbuf=hostbuf.astype(np.float32).ravel() if hostbuf is not None else None)

@functools.lru_cache
def clbuild(name, prg):
  clprg = cl.Program(cl_ctx, prg).build().__getattr__(name)
  def run(*args):
    clprg(cl_queue, *args)
  return run
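Because clbuild is wrapped in functools.lru_cache keyed on (name, prg), a given kernel source is compiled once per process and later calls reuse the cached launcher; the returned run closure always enqueues on the module-level cl_queue. A small illustrative check (the "copy" kernel here is an assumption, not part of the commit):

require_init_gpu()   # cl_ctx / cl_queue must exist before any kernel is built
src = "__kernel void copy(__global const float *a, __global float *b) { b[get_global_id(0)] = a[get_global_id(0)]; }"
copy_a = clbuild("copy", src)
copy_b = clbuild("copy", src)
assert copy_a is copy_b   # identical (name, source) arguments hit the cache, no recompile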
# x -> ret
def unary_op(ctx, code, x):
  ret = buffer_new(ctx, x.shape)
  unop = clbuild("unop", """
  __kernel void unop(__global const float *a_g, __global float *res_g) {
    int gid = get_global_id(0);
    float a = a_g[gid];
    res_g[gid] = """+code+""";
  }""")
  unop([np.prod(ret.shape)], None, x.cl, ret.cl)
  return ret
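The code argument is an OpenCL C expression evaluated per element, with a naming the input value; ctx is unused by the llops. A usage sketch (ReLU-style clamp):

x = GPUBuffer.fromCPU(np.random.randn(8).astype(np.float32))
y = unary_op(None, "max(a, (float)0.)", x)   # elementwise ReLU; y has the same shape as x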
@functools.lru_cache
def get_binop_prg(cl_ctx, code, complist):
  ndims = len(complist)
  args = "".join([f", int d{i}" for i in range(ndims)] + [f", int p{i}" for i in range(ndims-1)])
  compute_idx_rets = "".join([f"\n int idx_ret{i} = (gid0 / {f'p{i}' if i < ndims-1 else '1'}) % d{i};" for i in range(ndims)])

  idx_exprs = ["0", "0"] # [idx_x, idx_y]
  for i in range(ndims):
    for j in range(2):
      if complist[i][j]:
        idx_exprs[j] = "idx_ret%d + d%d*(%s)" % (i, i, idx_exprs[j])

  return cl.Program(cl_ctx, """__kernel void binop(__global const float *x_g, __global const float *y_g, __global float *res_g"""+args+""") {
    int gid0 = get_global_id(0);"""+compute_idx_rets+"""
    float a = x_g["""+idx_exprs[0]+"""];
    float b = y_g["""+idx_exprs[1]+"""];
    res_g[gid0] = """+code+""";\n}""").build()

def binary_op(ctx, code, x, y):
  shape_ret, dimlist, complist = binary_broadcast(x.shape, y.shape)
  prod_list = np.array(dimlist, dtype=i32)[-1::-1].cumprod(dtype=i32)[-1::-1] # take cumprod from back to front
  prg = get_binop_prg(cl_ctx, code, tuple(complist))
  ret = buffer_new(ctx, shape_ret, zero=True)
  prg.binop(cl_queue, [prod_list[0]] if len(dimlist) > 0 else [1], None, x.cl, y.cl, ret.cl, *dimlist, *(prod_list[1:]))
  return ret
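binary_op broadcasts the two inputs with binary_broadcast and caches one compiled kernel per distinct (code, complist) pair via get_binop_prg. A usage sketch, with a and b naming the operands inside the generated kernel:

a = GPUBuffer.fromCPU(np.ones((4, 1), dtype=np.float32))
b = GPUBuffer.fromCPU(np.arange(3, dtype=np.float32).reshape(1, 3))
c = binary_op(None, "a+b", a, b)   # broadcasts (4,1) and (1,3) to an output of shape (4, 3)
print(c.shape, c.toCPU()[0])       # (4, 3) [1. 2. 3.]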
def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
  if axis is None:
    # full reduce
    osize = [1]*len(inp.shape)
  else:
    osize = np.array(inp.shape)
    osize[list(axis)] = 1
  ret = buffer_new(ctx, osize)
  if axis is None:
    ret.shape = (1,)

  # TODO: this is insanely slow
  reduce = clbuild("reduce", """
  __kernel void reduce(__global const float *a_g, int sz, __global float *res_g, int prod, int n_dims,
                       __global const int *shape_x, __global const int *shape_ret) {
    int gid = get_global_id(0);

    float out = """+start+""";
    for (int x = 0; x < sz; x++) {
      int idx = 0;  // compute index into a_g
      int tprod = prod;
      int tsz = sz;
      for (int dim = 0; dim < n_dims; dim++) {
        idx *= shape_x[dim];
        if (shape_x[dim] == shape_ret[dim]) {   // dim from gid, don't reduce
          tprod /= shape_x[dim];
          idx += (gid / tprod) % shape_x[dim];
        } else {  // dim from x
          tsz /= shape_x[dim];
          idx += (x / tsz) % shape_x[dim];
        }
      }
      float a = a_g[idx];
      """+code+""";
    }
    res_g[gid] = """+code2+""";
  }""")
  reduce([np.prod(osize)], None, inp.cl,
         i32(np.prod(inp.shape)//np.prod(osize)), ret.cl,
         i32(np.prod(osize)), i32(len(osize)),
         buffer_np(ctx, np.array(inp.shape, dtype=np.int32)),
         buffer_np(ctx, np.array(osize, dtype=np.int32)))
  return ret
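reduce_op launches one work-item per output element and loops over the reduced extent sz inside the kernel, which is why the TODO flags it as slow. The Sum and Max Functions in the ops layer boil down to calls of this shape (sketch):

x = GPUBuffer.fromCPU(np.arange(12, dtype=np.float32).reshape(3, 4))
row_sums = reduce_op(None, "out += a", "out", x, axis=(1,))                   # shape (3, 1)
full_max = reduce_op(None, "out = max(a,out)", "out", x, start="-INFINITY")   # full reduce, shape (1,)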


@@ -1,50 +1,7 @@
import functools
import pyopencl as cl
import numpy as np
from tinygrad.helpers import binary_broadcast
from ..tensor import Function
cl_ctx, cl_queue = None, None
def require_init_gpu():
  global cl_ctx, cl_queue
  if cl_queue is None:
    devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.GPU)
    if len(devices) == 0:
      devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.CPU)
    cl_ctx = cl.Context(devices=devices)
    # this is an in-order command queue
    cl_queue = cl.CommandQueue(cl_ctx)

class GPUBuffer:
  def __init__(self, shape, hostbuf=None):
    require_init_gpu()
    self.shape, self.dtype = tuple(shape), np.float32
    self.cl = hostbuf.cl if isinstance(hostbuf, GPUBuffer) else \
      cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | (cl.mem_flags.COPY_HOST_PTR if hostbuf is not None else 0), 4*np.prod(shape),
                hostbuf=hostbuf.astype(np.float32).ravel() if hostbuf is not None else None)

  def __repr__(self):
    return f"<GPUBuffer with shape {self.shape!r}>"

  @staticmethod
  def fromCPU(x):
    return GPUBuffer(x.shape, x.view(np.ndarray))

  def toCPU(self):
    data = np.empty(self.shape, dtype=np.float32)
    cl_queue.finish()
    cl.enqueue_copy(cl_queue, data, self.cl, is_blocking=True)
    return data

def buffer_new(ctx, shape, zero=False):
  return GPUBuffer(shape, hostbuf=None if not zero else np.zeros(shape, dtype=np.float32))

def buffer_np(ctx, x):
  return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x)

@functools.lru_cache
def clbuild(cl_ctx, name, prg):
  return cl.Program(cl_ctx, prg).build().__getattr__(name)
from ..llops.gpu import GPUBuffer, clbuild, buffer_new, buffer_np, unary_op, binary_op, reduce_op
def uint2(x, y):
  return np.array((x,y), dtype=cl.cltypes.uint2)
@@ -52,17 +9,6 @@ i32 = np.int32
# ************* unary ops *************

def unary_op(ctx, code, x):
  ret = buffer_new(ctx, x.shape)
  unop = clbuild(cl_ctx, "unop", """
  __kernel void unop(__global const float *a_g, __global float *res_g) {
    int gid = get_global_id(0);
    float a = a_g[gid];
    res_g[gid] = """+code+""";
  }""")
  unop(cl_queue, [np.prod(ret.shape)], None, x.cl, ret.cl)
  return ret

class ReLU(Function):
  def forward(ctx, input):
    ctx.save_for_backward(input)
@@ -93,50 +39,6 @@ class Exp(Function):
# ************* reduce ops *************

def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
  if axis is None:
    # full reduce
    osize = [1]*len(inp.shape)
  else:
    osize = np.array(inp.shape)
    osize[list(axis)] = 1
  ret = buffer_new(ctx, osize)
  if axis is None:
    ret.shape = (1,)

  # TODO: this is insanely slow
  reduce = clbuild(cl_ctx, "reduce", """
  __kernel void reduce(__global const float *a_g, int sz, __global float *res_g, int prod, int n_dims,
                       __global const int *shape_x, __global const int *shape_ret) {
    int gid = get_global_id(0);

    float out = """+start+""";
    for (int x = 0; x < sz; x++) {
      int idx = 0;  // compute index into a_g
      int tprod = prod;
      int tsz = sz;
      for (int dim = 0; dim < n_dims; dim++) {
        idx *= shape_x[dim];
        if (shape_x[dim] == shape_ret[dim]) {   // dim from gid, don't reduce
          tprod /= shape_x[dim];
          idx += (gid / tprod) % shape_x[dim];
        } else {  // dim from x
          tsz /= shape_x[dim];
          idx += (x / tsz) % shape_x[dim];
        }
      }
      float a = a_g[idx];
      """+code+""";
    }
    res_g[gid] = """+code2+""";
  }""")
  reduce(cl_queue, [np.prod(osize)], None, inp.cl,
         i32(np.prod(inp.shape)//np.prod(osize)), ret.cl,
         i32(np.prod(osize)), i32(len(osize)),
         buffer_np(ctx, np.array(inp.shape, dtype=np.int32)),
         buffer_np(ctx, np.array(osize, dtype=np.int32)))
  return ret

class Sum(Function):
  def forward(ctx, input, axis=None):
    ctx.save_for_backward(input.shape)
@@ -162,33 +64,6 @@ class Max(Function):
# ************* binary ops *************

@functools.lru_cache
def get_binop_prg(cl_ctx, code, complist):
  ndims = len(complist)
  args = "".join([f", int d{i}" for i in range(ndims)] + [f", int p{i}" for i in range(ndims-1)])
  compute_idx_rets = "".join([f"\n int idx_ret{i} = (gid0 / {f'p{i}' if i < ndims-1 else '1'}) % d{i};" for i in range(ndims)])

  idx_exprs = ["0", "0"] # [idx_x, idx_y]
  for i in range(ndims):
    for j in range(2):
      if complist[i][j]:
        idx_exprs[j] = "idx_ret%d + d%d*(%s)" % (i, i, idx_exprs[j])

  return cl.Program(cl_ctx, """__kernel void binop(__global const float *x_g, __global const float *y_g, __global float *res_g"""+args+""") {
    int gid0 = get_global_id(0);"""+compute_idx_rets+"""
    float a = x_g["""+idx_exprs[0]+"""];
    float b = y_g["""+idx_exprs[1]+"""];
    res_g[gid0] = """+code+""";\n}""").build()

def binary_op(ctx, code, x, y):
  shape_ret, dimlist, complist = binary_broadcast(x.shape, y.shape)
  prod_list = np.array(dimlist, dtype=i32)[-1::-1].cumprod(dtype=i32)[-1::-1] # take cumprod from back to front
  prg = get_binop_prg(cl_ctx, code, tuple(complist))
  ret = buffer_new(ctx, shape_ret, zero=True)
  prg.binop(cl_queue, [prod_list[0]] if len(dimlist) > 0 else [1], None, x.cl, y.cl, ret.cl, *dimlist, *(prod_list[1:]))
  return ret

def unbroadcast(ctx, out, in_sh):
  sum_axis = [i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1] if in_sh != (1,) else None
  return reduce_op(ctx, "out += a", "out", out, sum_axis)
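unbroadcast sums a gradient back down to the shape of a broadcast input: any axis where the input had size 1 but the output did not is reduced with "out += a". For instance (sketch):

grad_out = GPUBuffer.fromCPU(np.ones((4, 3), dtype=np.float32))  # gradient of a broadcast result
grad_bias = unbroadcast(None, grad_out, (1, 3))                  # sums over axis 0 -> shape (1, 3)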
@@ -254,7 +129,7 @@ class Reshape(Function):
def perm_axis(ctx, inp, order):
  osize = np.array(inp.shape)[list(order)]
  ret = buffer_new(ctx, osize)
  perm = clbuild(cl_ctx, "perm", """
  perm = clbuild("perm", """
  __kernel void perm(__global const float *a_g, __global float *res_g, int n_axis,
                     __global const int *shape, __global const int *order) {
    int gid = get_global_id(0);
@@ -268,7 +143,7 @@ def perm_axis(ctx, inp, order):
    }
    res_g[gid] = a_g[idx];
  }""")
  perm(cl_queue, [np.prod(osize)], None, inp.cl, ret.cl, i32(len(osize)),
  perm([np.prod(osize)], None, inp.cl, ret.cl, i32(len(osize)),
       buffer_np(ctx, np.array(inp.shape, dtype=np.int32)),
       buffer_np(ctx, np.array(order, dtype=np.int32)))
  return ret
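perm_axis maps each output index back through the permuted shape, and is what the Transpose Function later in the file builds on; for example (sketch):

x = GPUBuffer.fromCPU(np.arange(6, dtype=np.float32).reshape(2, 3))
xt = perm_axis(None, x, (1, 0))   # transpose: result shape is (3, 2)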
@@ -286,7 +161,7 @@ def inner_slice(ctx, x, arg):
  shift = [y[0] for y in arg]
  oshape = [y[1]-y[0] for y in arg]
  ret = buffer_new(ctx, oshape)
  gslice = clbuild(cl_ctx, "gslice", """
  gslice = clbuild("gslice", """
  __kernel void gslice(__global const float *input, __global float *output, int prod, int n_dims,
                       __global const int *shape_x, __global const int *shape_ret,
                       __global const int *shift) {
@@ -301,7 +176,7 @@ def inner_slice(ctx, x, arg):
    }
    output[gid] = zero ? input[iptr] : 0.0;
  }""")
  gslice(cl_queue, [np.prod(ret.shape)], None,
  gslice([np.prod(ret.shape)], None,
         x.cl, ret.cl, i32(np.prod(ret.shape)), i32(len(ret.shape)),
         buffer_np(ctx, np.array(x.shape, dtype=np.int32)),
         buffer_np(ctx, np.array(ret.shape, dtype=np.int32)),
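Part of the gslice kernel is elided by the hunks above, but from the visible code inner_slice crops and/or zero-pads each dimension: arg is a list of (start, end) pairs in input coordinates, and positions outside the input are written as 0.0. A hedged usage sketch under that reading:

x = GPUBuffer.fromCPU(np.ones((3, 3), dtype=np.float32))
y = inner_slice(None, x, [(-1, 4), (0, 3)])   # one zero row of padding above and below -> shape (5, 3)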
@@ -327,7 +202,7 @@ class Matmul(Function):
    isize, msize, osize = i32(input.shape[-2]), i32(input.shape[-1]), i32(weight.shape[-1])
    ret = buffer_new(ctx, list(input.shape[0:-2])+[isize, osize])

    matmul = clbuild(cl_ctx, "matmul", """
    matmul = clbuild("matmul", """
    __kernel void matmul(
      __global const float *input, __global const float *weight, __global float *res,
      int isize, int is0, int is1, int msize, int ws0, int ws1, int osize
@@ -348,7 +223,7 @@ class Matmul(Function):
    ctx.save_for_backward(input, weight, matmul, cnt)

    # (isize,msize) x (msize,osize) = (isize,osize)
    matmul(cl_queue, [isize, osize, cnt], None,
    matmul([isize, osize, cnt], None,
           input.cl, weight.cl, ret.cl, isize,
           msize, i32(1), msize, i32(1), osize, osize)
    return ret
@@ -361,12 +236,12 @@ class Matmul(Function):
    grad_weight = buffer_new(ctx, weight.shape)

    # (isize,osize) x (msize,osize) = (isize,msize)
    matmul(cl_queue, [isize, msize, cnt], None,
    matmul([isize, msize, cnt], None,
           grad_output.cl, weight.cl, grad_input.cl, isize,
           osize, i32(1), osize, osize, i32(1), msize)

    # (isize,msize) x (isize,osize) = (msize,osize)
    matmul(cl_queue, [msize, osize, cnt], None,
    matmul([msize, osize, cnt], None,
           input.cl, grad_output.cl, grad_weight.cl, msize,
           i32(1), msize, isize, i32(1), osize, osize)
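The stride arguments (is0, is1, ws0, ws1) let the same matmul kernel read either operand transposed, so both backward products reuse it. In numpy terms, the two launches above compute (sketch):

inp  = np.random.randn(4, 5).astype(np.float32)   # (isize, msize)
w    = np.random.randn(5, 6).astype(np.float32)   # (msize, osize)
gout = np.random.randn(4, 6).astype(np.float32)   # (isize, osize)
grad_input  = gout @ w.T      # (isize, osize) x (msize, osize)^T = (isize, msize)
grad_weight = inp.T @ gout    # (isize, msize)^T x (isize, osize) = (msize, osize)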
@@ -392,7 +267,7 @@ class Conv2D(Function):
    # weight = (groups, rcout, cin, H, W)
    # output = (bs, groups, rcout, oy, ox)
    conv = clbuild(cl_ctx, "conv", """
    conv = clbuild("conv", """
    __kernel void conv(__global const float *input, __global const float *weight, __global float *output,
      int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs) {
@@ -417,7 +292,7 @@ class Conv2D(Function):
      output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
    }""")

    conv(cl_queue, [bs*groups*rcout, oy, ox], None,
    conv([bs*groups*rcout, oy, ox], None,
         x.cl, w.cl, ret.cl,
         i32(H), i32(W), i32(groups), i32(rcout), i32(cin),
         i32(oy), i32(ox), i32(iy), i32(ix), i32(ys), i32(xs)
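The conv kernel's oy, ox, iy, ix, ys, xs arguments are the output and input spatial sizes and the strides; the setup elided by the hunk presumably derives them with the usual unpadded-convolution formula (a sketch under that assumption, not shown in this diff):

def conv_out_dim(i, k, s):
  # valid (unpadded) convolution: output size for input size i, kernel size k, stride s
  return (i - k) // s + 1

oy, ox = conv_out_dim(32, 3, 1), conv_out_dim(32, 3, 1)   # e.g. iy=ix=32, H=W=3, ys=xs=1 -> (30, 30)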
@@ -442,7 +317,7 @@ class Conv2D(Function):
    # tensw = (groups*rcout, cin, H, W)
    # ggg = (bs, groups*rout, oy, ox)
    convw = clbuild(cl_ctx, "convw", """
    convw = clbuild("convw", """
    __kernel void convw(__global const float *tensx, __global const float *ggg, __global float *dw,
      int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
@@ -463,7 +338,7 @@ class Conv2D(Function):
      }
      dw[get_global_id(0)*H*W + y*W + x] = acc;
    }""")
    convx = clbuild(cl_ctx, "convx", """
    convx = clbuild("convx", """
    __kernel void convx(__global const float *tensw, __global const float *ggg, __global float *dx,
      int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
@@ -489,6 +364,6 @@ class Conv2D(Function):
    """)
    conv_args = i32(H), i32(W), i32(ctx.groups), i32(rcout), i32(cin), i32(oy), i32(ox), i32(iy), i32(ix), i32(ys), i32(xs), i32(bs)
    convw(cl_queue, [ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *conv_args)
    convx(cl_queue, [bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *conv_args)
    convw([ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *conv_args)
    convx([bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *conv_args)
    return dx, dw