From 5068e99d18f1800a5462724e2d8f75edfad2e069 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 2 Dec 2023 00:32:25 -0800 Subject: [PATCH] refactor to remove extra kernel params (#2563) * refactor to have compiled kernel * bugfixes * docs/beautiful.py * revert that * fix tests --- .github/workflows/test.yml | 5 +- .pre-commit-config.yaml | 10 ++- docs/abstractions.py | 2 +- docs/beautiful.py | 94 ++++++++++++++++++++++ test/external/external_test_speed_llama.py | 2 +- test/test_custom_function.py | 2 +- test/test_uops.py | 8 +- tinygrad/device.py | 6 +- tinygrad/runtime/ops_clang.py | 5 +- tinygrad/runtime/ops_cuda.py | 13 +-- tinygrad/runtime/ops_gpu.py | 11 ++- tinygrad/runtime/ops_hip.py | 8 +- tinygrad/runtime/ops_llvm.py | 8 +- tinygrad/runtime/ops_metal.py | 4 +- tinygrad/runtime/ops_webgpu.py | 2 +- 15 files changed, 143 insertions(+), 37 deletions(-) create mode 100644 docs/beautiful.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2a813b04e..6fc8a5f59 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,6 +16,7 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 20 + # TODO: run the pre-commit hook to replace a lot of this steps: - name: Checkout Code uses: actions/checkout@v3 @@ -47,7 +48,9 @@ jobs: - name: Check <5000 lines run: sloccount tinygrad test examples extra; if [ $(sloccount tinygrad | sed -n 's/.*Total Physical Source Lines of Code (SLOC)[ ]*= \([^ ]*\).*/\1/p' | tr -d ',') -gt 5000 ]; then exit 1; fi - name: Test Docs - run: python docs/abstractions.py + run: | + python docs/abstractions.py + python docs/beautiful.py - name: Test Quickstart run: awk '/```python/{flag=1;next}/```/{flag=0}flag' docs/quickstart.md > quickstart.py && PYTHONPATH=. python quickstart.py - name: Fuzz Test symbolic diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 06753d413..c1a8072fe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,15 @@ repos: pass_filenames: false - id: docs name: docs - entry: python3 docs/abstractions.py + entry: | + python3 docs/abstractions.py + python3 docs/beautiful.py + language: system + always_run: true + pass_filenames: false + - id: devicetests + name: select GPU tests + entry: env GPU=1 PYTHONPATH="." 
pytest test/test_uops.py test/test_custom_function.py language: system always_run: true pass_filenames: false diff --git a/docs/abstractions.py b/docs/abstractions.py index 5892a28f4..decb5d521 100644 --- a/docs/abstractions.py +++ b/docs/abstractions.py @@ -214,7 +214,7 @@ MallocAllocator.copyin(input_a, numpy_a.data.cast("B")) MallocAllocator.copyin(input_b, numpy_b.data.cast("B")) # compile the program, run it, and 2+3 does indeed equal 5 -program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}"), bufs=3) +program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}")) program(output, input_a, input_b) numpy_out = np.empty(1, dtype=np.float32) MallocAllocator.copyout(numpy_out.data.cast("B"), output) diff --git a/docs/beautiful.py b/docs/beautiful.py new file mode 100644 index 000000000..48fd9a49f --- /dev/null +++ b/docs/beautiful.py @@ -0,0 +1,94 @@ +# in tinygrad, things are easy +# because if things were hard, tinygrad would not be tiny +# come on a journey where we add 2+3 + + +# ******** first, in the raw *********** + +from tinygrad import dtypes +from tinygrad.device import MallocAllocator +from tinygrad.runtime.ops_clang import ClangProgram, compile_clang + +# allocate some buffers +# TODO: remove dtypes from allocators +out = MallocAllocator.alloc(1, dtypes.int32) +a = MallocAllocator.alloc(1, dtypes.int32) +b = MallocAllocator.alloc(1, dtypes.int32) + +# load in some values +MallocAllocator.copyin(a, bytearray([2,0,0,0])) +MallocAllocator.copyin(b, bytearray([3,0,0,0])) + +# compile a program +fxn = ClangProgram("add", compile_clang("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }")) + +# run the program +fxn(out, a, b) +outb = bytearray(MallocAllocator.as_buffer(out)) +assert outb == bytearray([5,0,0,0]) +print(outb) + + +# ******** second, one layer higher *********** + +import numpy as np +from tinygrad.device import Buffer, Device +from tinygrad.ops import LazyOp, BufferOps, MemBuffer, BinaryOps +from tinygrad.shape.shapetracker import ShapeTracker + +# allocate some buffers + load in values +out = Buffer("CLANG", 1, dtypes.int32) +a = Buffer("CLANG", 1, dtypes.int32).copyin(np.array([2], np.int32).data) +b = Buffer("CLANG", 1, dtypes.int32).copyin(np.array([3], np.int32).data) + +# describe the computation +ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,)))) +ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,)))) +alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2)) +st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,)))) + +# compile a program +fxn = Device["CLANG"].get_runner(st_0) + +# run the program +fxn.exec([out, a, b]) +assert out.toCPU().item() == 5 +print(out.toCPU()) + + +# ******** third, one layer higher *********** + +from tinygrad.lazy import LazyBuffer +from tinygrad.graph import print_tree +from tinygrad.realize import run_schedule + +# allocate some values + load in values +a = LazyBuffer.fromCPU(np.array([2], np.int32)) +b = LazyBuffer.fromCPU(np.array([3], np.int32)) + +# describe the computation +out = a.e(BinaryOps.ADD, b) + +# schedule the computation (print it) +sched = out.schedule() +print_tree(sched[0].ast) + +# run that schedule +run_schedule(sched) +assert out.realized.toCPU().item() == 5 +print(out.realized.toCPU()) + + +# ******** fourth, the top layer *********** + +from tinygrad import Tensor + +a = Tensor([2], 
dtype=dtypes.int32) +b = Tensor([3], dtype=dtypes.int32) +out = (a+b).item() +assert out == 5 +print(out) + + + + diff --git a/test/external/external_test_speed_llama.py b/test/external/external_test_speed_llama.py index 712c7faa2..caa995c2d 100644 --- a/test/external/external_test_speed_llama.py +++ b/test/external/external_test_speed_llama.py @@ -8,7 +8,7 @@ from tinygrad.device import Compiled, Allocator from tinygrad.helpers import Profiling class FakeProgram: - def __init__(self, name:str, prg:bytes, bufs:int, vars:int=0): pass + def __init__(self, name:str, prg:bytes): pass def __call__(self, *bufs, global_size, local_size, wait=False): pass class FakeAllocator(Allocator): diff --git a/test/test_custom_function.py b/test/test_custom_function.py index df2f40959..659a85e4d 100644 --- a/test/test_custom_function.py +++ b/test/test_custom_function.py @@ -22,7 +22,7 @@ def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer): __kernel void atan2_gpu(global float *c, global float *a, global float *b) { int idx = get_global_id(0); c[idx] = atan2(a[idx], b[idx]); - }""", global_size=[ret.size], bufcount=3).build(Device[ret.device].compiler, Device[ret.device].runtime).exec([ret, a, b]) + }""", global_size=[ret.size]).build(Device[ret.device].compiler, Device[ret.device].runtime).exec([ret, a, b]) def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data) diff --git a/test/test_uops.py b/test/test_uops.py index 9a28d5c40..1ba35c547 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -7,11 +7,11 @@ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps from tinygrad.device import CompiledASTRunner, Compiled from tinygrad.codegen.linearizer import UOps, UOp -def _uops_to_prg(uops, bufcount): +def _uops_to_prg(uops): src, runtime_args = Device[Device.DEFAULT].renderer("test", uops) return CompiledASTRunner(None, "test", src, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, - runtime_args=runtime_args, bufcount=bufcount).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime) + runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime) def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp: uops.append(UOp(uop, dtype, tuple(vin), arg)) @@ -26,7 +26,7 @@ def _test_single_value(vals, op, dtype): uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) buf = Buffer(Device.DEFAULT, 1, dtype) buf2 = [Buffer.fromCPU(Device.DEFAULT, np.array([a], dtype=dtype.np)) for a in vals] - prg = _uops_to_prg(uops, 1+len(buf2)) + prg = _uops_to_prg(uops) prg.exec([buf]+buf2) return buf.toCPU()[0] @@ -37,7 +37,7 @@ def _test_single_value_const(vals, op, dtype): alu = uop(uops, UOps.ALU, dtype, loads, op) uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) buf = Buffer(Device.DEFAULT, 1, dtype) - prg = _uops_to_prg(uops, 1) + prg = _uops_to_prg(uops) prg.exec([buf]) return buf.toCPU()[0] diff --git a/tinygrad/device.py b/tinygrad/device.py index cb1e99d94..72a83b6d1 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -213,18 +213,16 @@ def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> Interpret # **************** for Compiled Devices **************** class CompiledASTRunner(JITRunner): - def __init__(self, ast:Optional[LazyOp], name:str, prg:str, 
global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, runtime_args:Optional[dict]=None, bufcount:int=0): + def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, runtime_args:Optional[dict]=None): super().__init__() if DEBUG >= 4: print(prg) if global_size is not None: global_size = global_size + [1]*(3-len(global_size)) if local_size is not None: local_size = local_size + [1]*(3-len(local_size)) self.name, self.display_name, self.prg, self.global_size, self.local_size, self.runtime_args = \ to_function_name(name), name, prg, global_size, local_size, runtime_args if runtime_args is not None else {} - self.bufcount = bufcount self.vars: List[Variable] = [] if ast: info = get_lazyop_info(ast) - self.bufcount = len(info.mem) self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate from tinygrad.lazy import vars_from_ast self.vars = vars_from_ast(ast) @@ -232,7 +230,7 @@ class CompiledASTRunner(JITRunner): def build(self, compiler, runtime): self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg) - self.clprg = runtime(self.name, self.lib, self.bufcount, len(self.vars)) + self.clprg = runtime(self.name, self.lib) return self def launch_dims(self, var_vals): diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 6c7329873..611f22983 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -20,10 +20,11 @@ def compile_clang(prg:str, header:str=CLANG_PROGRAM_HEADER) -> bytes: return pathlib.Path(output_file.name).read_bytes() class ClangProgram: - def __init__(self, name:str, prg:bytes, bufs:int, vars:int=0): + def __init__(self, name:str, lib:bytes): + self.name, self.lib = name, lib # write to disk so we can load it with tempfile.NamedTemporaryFile(delete=True) as cached_file_path: - pathlib.Path(cached_file_path.name).write_bytes(prg) + pathlib.Path(cached_file_path.name).write_bytes(lib) self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name] def __call__(self, *args, wait=False): return cpu_time_execution(lambda: self.fxn(*args), enable=wait) diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 5a08eca0e..b7726ef15 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -23,20 +23,21 @@ def cu_time_execution(cb, enable=False) -> Optional[float]: return time_executio def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architecture={CUDADevice.default_arch_name}', CUDA_INCLUDE_PATH], cuda.nvrtcProgram, cuda.nvrtcCreateProgram, cuda.nvrtcCompileProgram, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check) class CUDAProgram: - def __init__(self, name:str, prg:bytes, bufs:int, vars:int=0): - if DEBUG >= 5: print(pretty_ptx(prg.decode('utf-8'))) + def __init__(self, name:str, lib:bytes): + self.name, self.lib = name, lib + if DEBUG >= 5: print(pretty_ptx(lib.decode('utf-8'))) if DEBUG >= 6: try: - fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(prg).hexdigest()}").as_posix() - with open(fn + ".ptx", "wb") as f: f.write(prg) + fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix() + with open(fn + ".ptx", "wb") as f: f.write(lib) subprocess.run(["ptxas", f"-arch={CUDADevice.default_arch_name}", "-o", fn, fn+".ptx"], check=True) print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8')) except Exception as e: 
print("failed to generate SASS", str(e)) if not CUDACPU: - self.module = init_c_var(cuda.CUmodule(), lambda x: check(cuda.cuModuleLoadData(ctypes.byref(x), prg))) + self.module = init_c_var(cuda.CUmodule(), lambda x: check(cuda.cuModuleLoadData(ctypes.byref(x), lib))) check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8"))) - self.prg = prg + self.prg = prg if not CUDACPU else lib def __del__(self): if not CUDACPU: check(cuda.cuModuleUnload(self.module)) diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 08d02afd2..746055ffc 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -34,15 +34,14 @@ def compile_cl(prg:str) -> bytes: return bytes(binary) class CLProgram: - def __init__(self, device:CLDevice, name:str, prg:bytes, bufs:int=0, vars:int=0): - self.device = device - self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, ctypes.byref(device.device_id), (ctypes.c_size_t * 1)(len(prg)), - to_char_p_p([prg], ctypes.c_ubyte), + def __init__(self, device:CLDevice, name:str, lib:bytes): + self.device, self.name, self.lib = device, name, lib + self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, ctypes.byref(device.device_id), (ctypes.c_size_t * 1)(len(lib)), + to_char_p_p([lib], ctypes.c_ubyte), ctypes.byref(binary_status := ctypes.c_int32()), ctypes.byref(errcode_ret := ctypes.c_int32())), errcode_ret) check(binary_status.value) check(cl.clBuildProgram(self.program, 1, ctypes.byref(device.device_id), None, cl.clBuildProgram.argtypes[4](), None)) # NOTE: OSX requires this self.kernel = checked(cl.clCreateKernel(self.program, name.encode(), ctypes.byref(status := ctypes.c_int32())), status) - self.vars = vars def __del__(self): check(cl.clReleaseKernel(self.kernel)) @@ -50,7 +49,7 @@ class CLProgram: def __call__(self, *bufs:Union[cl.cl_mem, int], global_size:Tuple[int,...], local_size:Optional[Tuple[int,...]]=None, wait=False) -> Optional[float]: for i,b in enumerate(bufs): - bc = ctypes.c_int32(b) if i >= (len(bufs)-self.vars) else cast(cl.cl_mem, b) + bc = ctypes.c_int32(b) if isinstance(b, int) else cast(cl.cl_mem, b) cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(bc), ctypes.byref(bc)) if local_size is not None: global_size = tuple(int(g*l) for g,l in zip(global_size, local_size)) event = cl.cl_event() if wait else None diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index b5aca40d5..642ec274b 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -23,16 +23,16 @@ def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, h def compile_hip(prg) -> bytes: return compile_cuda_style(prg, [f'--offload-arch={HIPDevice.default_arch_name}'], hip.hiprtcProgram, hip.hiprtcCreateProgram, hip.hiprtcCompileProgram, hip.hiprtcGetCode, hip.hiprtcGetCodeSize, hip.hiprtcGetProgramLog, hip.hiprtcGetProgramLogSize, check) class HIPProgram: - def __init__(self, device:int, name:str, prg:bytes, bufs:int, vars:int=0): - self.device = device + def __init__(self, device:int, name:str, lib:bytes): + self.device, self.name, self.lib = device, name, lib if DEBUG >= 6: - asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg)) + asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], lib)) print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x])) if MOCKHIP: return check(hip.hipSetDevice(self.device)) - self.module = init_c_var(hip.hipModule_t(), lambda x: 
check(hip.hipModuleLoadData(ctypes.byref(x), prg))) + self.module = init_c_var(hip.hipModule_t(), lambda x: check(hip.hipModuleLoadData(ctypes.byref(x), lib))) self.prg = init_c_var(hip.hipFunction_t(), lambda x: check(hip.hipModuleGetFunction(ctypes.byref(x), self.module, name.encode("utf-8")))) def __del__(self): diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index ad2aa1dc2..af0d9e4a8 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -54,11 +54,13 @@ def compile_llvm(prg, llvmopt=LLVMOPT) -> bytes: return LLVM.target_machine.emit_object(mod) class LLVMProgram: - def __init__(self, name:str, lib:bytes, bufs:int, vars:int=0): + def __init__(self, name:str, lib:bytes): + self.name, self.lib = name, lib LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib)) self.fxn = LLVM.engine.get_function_address(name) - self.cfunc = CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*bufs), *([ctypes.c_int]*vars))(self.fxn) - def __call__(self, *bufs, wait=False): return cpu_time_execution(lambda: self.cfunc(*bufs), enable=wait) + def __call__(self, *bufs, wait=False): + self.cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_int32 if isinstance(b, int) else ctypes.c_void_p for b in bufs])(self.fxn) + return cpu_time_execution(lambda: self.cfunc(*bufs), enable=wait) LLVMDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, compile_llvm, LLVMProgram) diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index d6dc53046..1a72e61a5 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -19,8 +19,8 @@ def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes: return library.libraryDataContents().bytes().tobytes() class MetalProgram: - def __init__(self, device:MetalDevice, name:str, lib:bytes, bufs:int, vars:int=0): - self.device = device + def __init__(self, device:MetalDevice, name:str, lib:bytes): + self.device, self.name, self.lib = device, name, lib data = libdispatch.dispatch_data_create(lib, len(lib), None, None) self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None)) self.fxn = self.library.newFunctionWithName_(name) diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py index d30febef9..6c098eb82 100644 --- a/tinygrad/runtime/ops_webgpu.py +++ b/tinygrad/runtime/ops_webgpu.py @@ -10,7 +10,7 @@ import wgpu wgpu_device = get_default_device() class WebGPUProgram: - def __init__(self, name: str, prg: str, bufs:int=0, vars:int=0): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg) + def __init__(self, name:str, lib:bytes): self.name, self.lib, self.prg = name, lib, wgpu_device.create_shader_module(code=lib) # NOTE: this is the compiler def __call__(self, *bufs, global_size, local_size, wait=False): assert len(bufs) <= 8, "WEBGPU only supports 8 buffers" binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]
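
Note: below is a minimal usage sketch of the post-refactor runtime API, adapted from the docs/beautiful.py file added in this patch. After this change, the backend Program constructors take only a kernel name and the compiled library bytes (the bufs/vars counts are gone). The sketch assumes the CLANG backend is available on the local machine.

from tinygrad import dtypes
from tinygrad.device import MallocAllocator
from tinygrad.runtime.ops_clang import ClangProgram, compile_clang

# allocate raw buffers and load the values 2 and 3 into them
out = MallocAllocator.alloc(1, dtypes.int32)
a = MallocAllocator.alloc(1, dtypes.int32)
b = MallocAllocator.alloc(1, dtypes.int32)
MallocAllocator.copyin(a, bytearray([2,0,0,0]))
MallocAllocator.copyin(b, bytearray([3,0,0,0]))

# compile and construct the program: note there is no buffer-count argument anymore
fxn = ClangProgram("add", compile_clang("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }"))

# run it and check that 2 + 3 == 5
fxn(out, a, b)
assert bytearray(MallocAllocator.as_buffer(out)) == bytearray([5,0,0,0])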