diff --git a/.gitignore b/.gitignore
index c42d6ff5a..b9f188d5e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,4 @@ notebooks
 build
 dist
 *.egg-info
+/env
\ No newline at end of file
diff --git a/test/test_ops.py b/test/test_ops.py
index dfe798696..f216608b1 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -57,6 +57,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, gpu=self.gpu)
   def test_dot(self):
     helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5, gpu=self.gpu)
+  def test_sum(self):
+    helper_test_op([(45,1)], lambda x: x.sum(), Tensor.sum, atol=1e-5, gpu=self.gpu)
 
   def test_conv2d(self):
     for bs in [1,8]:
diff --git a/tinygrad/opsgpu.py b/tinygrad/opsgpu.py
index 20aa6bbc0..d19f45e89 100644
--- a/tinygrad/opsgpu.py
+++ b/tinygrad/opsgpu.py
@@ -1,6 +1,8 @@
 import numpy as np
 from .tensor import Function, register, Tensor
 import pyopencl as cl
+import pyopencl.array as pycl_array
+from pyopencl.reduction import ReductionKernel
 import functools
 
 def buffer_new(ctx, shape):
@@ -16,6 +18,10 @@ def buffer_like(ctx, x):
 def clbuild(cl_ctx, prg):
   return cl.Program(cl_ctx, prg).build()
 
+@functools.lru_cache
+def cl_reduct_krnl_build(cl_ctx, *args, **kwargs):
+  return ReductionKernel(cl_ctx, *args, **kwargs)
+
 def binary_op(ctx, code, x, y):
   ret = buffer_like(ctx, x)
   prg = clbuild(ctx.cl_ctx, """
@@ -105,16 +111,12 @@ class Sum(Function):
   @staticmethod
   def forward(ctx, input):
     ctx.save_for_backward(input)
-    ret = buffer_new(ctx, (1,))
-    prg = clbuild(ctx.cl_ctx, """
-    __kernel void sum(
-        __global const float *a_g, __global float *res_g)
-    {
-      int gid = get_global_id(0);
-      res_g[0] += a_g[gid];
-    }
-    """)
-    prg.sum(ctx.cl_queue, [input.size//4], None, input, ret)
+    krnl = cl_reduct_krnl_build(ctx.cl_ctx, np.float32, neutral="0", reduce_expr="a+b",
+                                map_expr="x[i]", arguments="__global float *x")
+    # input.size is the buffer size in bytes; the Array shape must be the float32 element count
+    ret = krnl(pycl_array.Array(ctx.cl_queue, input.size//4, dtype=np.float32, data=input)).data
+    ret.shape = (1,)
+    ret.dtype = np.float32
     return ret
 
   @staticmethod