diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index cea8efe63..0dd5cca0f 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -27,7 +27,7 @@ jobs:
         ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
         ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
     - name: Run model inference benchmark
-      run: python3 test/external/external_model_benchmark.py
+      run: METAL=1 python3 test/external/external_model_benchmark.py
     - name: Test speed vs torch
       run: BIG=2 MPS=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
     - name: Run Tensor Core GEMM
@@ -73,8 +73,8 @@ jobs:
     steps:
     - name: Checkout Code
       uses: actions/checkout@v3
-    #- name: Run model inference benchmark
-    #  run: CUDA=1 python3 test/external/external_model_benchmark.py
+    - name: Run model inference benchmark
+      run: CUDA=1 python3 test/external/external_model_benchmark.py
    - name: Test speed vs torch
      run: CUDA=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
    - name: Run GPT2
@@ -109,7 +109,7 @@ jobs:
         ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
         ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
     - name: Run model inference benchmark
-      run: python3 test/external/external_model_benchmark.py
+      run: GPU=1 python3 test/external/external_model_benchmark.py
     - name: Test speed vs torch
       run: BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
     - name: Run Tensor Core GEMM
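The three workflow hunks above pin each self-hosted runner to its backend explicitly (METAL=1 on the Mac runner, CUDA=1 on the NVIDIA box, GPU=1 for OpenCL) instead of relying on the script's old hardcoded device list. The same environment variables select the backend anywhere tinygrad runs; a minimal sketch of reproducing this locally, assuming a machine where the CUDA backend is available and that `Device` is importable from `tinygrad.device` at this revision:

```python
# Sketch: selecting the backend the way the workflow does (METAL=1 / CUDA=1 / GPU=1).
# The variable has to be set before Device.DEFAULT is first read.
import os
os.environ["CUDA"] = "1"  # assumption: an NVIDIA machine, like the CI runner above

from tinygrad.device import Device
from tinygrad.tensor import Tensor

print(Device.DEFAULT)                          # expected: "CUDA"
print((Tensor([1.0, 2.0, 3.0]) * 2).numpy())   # kernel compiled and run on that backend
```

diff --git a/test/external/external_cl_half_max.py b/test/external/external_cl_half_max.py
new file mode 100644
index 000000000..7cd6b0c50
--- /dev/null
+++ b/test/external/external_cl_half_max.py
@@ -0,0 +1,13 @@
+from tinygrad.runtime.ops_gpu import CLDevice, CLProgram, compile_cl
+
+if __name__ == "__main__":
+  dev = CLDevice()
+  lib = compile_cl("""
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+__kernel void test(__global half *out, __global half *a, __global half *b) {
+  int gid = get_global_id(0);
+  out[gid] = max(a[gid], b[gid]);
+}
+""")
+  prg = CLProgram(dev, "test", lib)
+
diff --git a/test/external/external_model_benchmark.py b/test/external/external_model_benchmark.py
index 4c9511389..01742126d 100644
--- a/test/external/external_model_benchmark.py
+++ b/test/external/external_model_benchmark.py
@@ -43,7 +43,7 @@ def benchmark(mnm, nm, fxn):
 #BASE = pathlib.Path(__file__).parents[2] / "weights" / "onnx"
 BASE = pathlib.Path("/tmp/onnx")
 
-def benchmark_model(m, validate_outs=False):
+def benchmark_model(m, devices, validate_outs=False):
   torch.manual_seed(1)
   global open_csv, CSV
   CSV = {"model": m}
@@ -61,7 +61,7 @@ def benchmark_model(m, validate_outs=False):
   # print input names
   if DEBUG >= 2: print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded])
 
-  for device in ["METAL" if OSX else "GPU", "CLANG"]: # + (["CUDA"] if torch.cuda.is_available() else []):
+  for device in devices:
     Device.DEFAULT = device
     inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
     tinygrad_model = get_run_onnx(onnx_model)
@@ -92,11 +92,14 @@
       provider = backend+"ExecutionProvider"
       if provider not in ort.get_available_providers(): continue
       ort_sess = ort.InferenceSession(str(fn), ort_options, [provider])
-      benchmark(m, f"onnxruntime_{backend.lower()}", lambda: ort_sess.run(output_names, np_inputs))
+      try:
+        benchmark(m, f"onnxruntime_{backend.lower()}", lambda: ort_sess.run(output_names, np_inputs))
+      except Exception as e: print(f"{m:16s}onnxruntime_{backend.lower()} {type(e).__name__:>25}")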
print(f"{m:16s}onnxruntime_{backend.lower()} {type(e).__name__:>25}") del ort_sess if validate_outs: rtol, atol = 2e-3, 2e-3 # tolerance for fp16 models + if m == "openpilot" and 'CUDA' in devices: rtol, atol = 0.1, 0.1 # TODO: why is this broken? inputs = {k:Tensor(inp) for k,inp in np_inputs.items()} tinygrad_model = get_run_onnx(onnx_model) tinygrad_out = tinygrad_model(inputs) @@ -121,6 +124,7 @@ def assert_allclose(tiny_out:dict, onnx_out:dict, rtol=1e-5, atol=1e-5): else: np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tiny_out.keys()}") if __name__ == "__main__": - if getenv("MODEL", "") != "": benchmark_model(getenv("MODEL", ""), True) + devices = [Device.DEFAULT] if getenv("NOCLANG") else [Device.DEFAULT, "CLANG"] + if getenv("MODEL", "") != "": benchmark_model(getenv("MODEL", ""), devices, True) else: - for m in MODELS: benchmark_model(m, True) + for m in MODELS: benchmark_model(m, devices, True) diff --git a/tinygrad/device.py b/tinygrad/device.py index 37bbff460..8885ab604 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -141,15 +141,17 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method def alloc(self, size:int): if len(c := self.cache[size]): return c.pop() try: - return self._alloc(size) + return super().alloc(size) except MemoryError: self.free_cache() - return self._alloc(size) + return super().alloc(size) def free_cache(self): for opaques in self.cache.values(): for opaque in opaques: self._free(opaque) opaques.clear() - def free(self, opaque:Any, size:int): self.cache[size].append(opaque) + def free(self, opaque:Any, size:int): + if getenv("LRU", 1): self.cache[size].append(opaque) + else: self._free(opaque) class _MallocAllocator(LRUAllocator): def _alloc(self, size:int): return (ctypes.c_uint8 * size)() diff --git a/tinygrad/jit.py b/tinygrad/jit.py index a68617a06..1e2337a20 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -74,14 +74,16 @@ class TinyJit(Generic[ReturnType]): self.ret = self.fxn(*args, **kwargs) self.jit_cache = CacheCollector.finish() assert len(self.jit_cache) != 0, "didn't JIT anything!" 
- if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs") # if your Device supports it, condense the items into a graph executor if (make_graph := Device[Device.DEFAULT].graph) and getenv("JIT") != 2: try: + if DEBUG >= 1: print(f"JIT GRAPHing {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs") self.jit_cache = [JitItem(make_graph(self.jit_cache, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))] except GraphException as e: if DEBUG >= 1: print(f"graph create failed {e}") + else: + if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs") self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers) elif self.cnt == 0: diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index f359e9d12..c2d4043c6 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -249,16 +249,11 @@ class CUDALanguage(CStyleLanguage): gid = [f'blockIdx.{chr(120+i)}' for i in range(3)] lid = [f'threadIdx.{chr(120+i)}' for i in range(3)] xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)] + code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"__hmax({a},{b})"} half_prekernel = """ #include - #include - using namespace nvcuda; - struct __align__(8) half4 { - half2 x, y; - __device__ __forceinline__ explicit half4(const float4& a): x(make_half2(__float2half(a.x), __float2half(a.y))), y(make_half2(__float2half(a.z),__float2half(a.w))) {} - __device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); } - }; - """ + struct half4 { half x, y, z, w; }; + """ CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage()) class HIPLanguage(CStyleLanguage): diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 740e8509f..22927b0e5 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -1,4 +1,5 @@ -import subprocess, hashlib, tempfile, ctypes, ctypes.util +from __future__ import annotations +import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools from pathlib import Path from typing import Tuple, Optional import gpuctypes.cuda as cuda @@ -7,7 +8,6 @@ from tinygrad.device import Compiled, LRUAllocator, MallocAllocator from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import CUDARenderer -CUDA_INCLUDE_PATH = getenv("CUDA_INCLUDE_PATH", default="-I/usr/local/cuda/include") CUDACPU = getenv("CUDACPU") == 1 if CUDACPU: gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot")) @@ -20,11 +20,11 @@ def check(status): def cu_time_execution(cb, enable=False) -> Optional[float]: return time_execution_cuda_style(cb, cuda.CUevent, cuda.cuEventCreate, cuda.cuEventRecord, cuda.cuEventSynchronize, cuda.cuEventDestroy_v2, cuda.cuEventElapsedTime, enable=enable) if not CUDACPU else cpu_time_execution(cb, enable=enable) @diskcache -def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architecture={CUDADevice.default_arch_name}', CUDA_INCLUDE_PATH], cuda.nvrtcProgram, cuda.nvrtcCreateProgram, cuda.nvrtcCompileProgram, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check) +def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, 
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index f359e9d12..c2d4043c6 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -249,16 +249,11 @@ class CUDALanguage(CStyleLanguage):
   gid = [f'blockIdx.{chr(120+i)}' for i in range(3)]
   lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
   xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)]
+  code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"__hmax({a},{b})"}
   half_prekernel = """
     #include <cuda_fp16.h>
-    #include <mma.h>
-    using namespace nvcuda;
-    struct __align__(8) half4 {
-      half2 x, y;
-      __device__ __forceinline__ explicit half4(const float4& a): x(make_half2(__float2half(a.x), __float2half(a.y))), y(make_half2(__float2half(a.z),__float2half(a.w))) {}
-      __device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); }
-    };
-    """
+    struct half4 { half x, y, z, w; };
+    """
 CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())

 class HIPLanguage(CStyleLanguage):
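The code_for_op override above makes the CUDA renderer emit __hmax for BinaryOps.MAX when the dtype is half, instead of the generic max(), and the wmma-flavored half4 definition is replaced by a plain four-element struct. A minimal sketch of an op that exercises this path, assuming a CUDA device with fp16 support and that dtypes still lives in tinygrad.helpers at this revision:

```python
# Sketch: elementwise max on half tensors; with the change above the CUDA
# renderer generates __hmax(a, b) for dtypes.half rather than max(a, b).
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

a = Tensor([1.0, 5.0, -2.0]).cast(dtypes.half)
b = Tensor([3.0, 4.0, -1.0]).cast(dtypes.half)
print(a.maximum(b).numpy())  # [ 3.  5. -1.], computed in fp16
```

diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 740e8509f..22927b0e5 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -1,4 +1,5 @@
-import subprocess, hashlib, tempfile, ctypes, ctypes.util
+from __future__ import annotations
+import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools
 from pathlib import Path
 from typing import Tuple, Optional
 import gpuctypes.cuda as cuda
@@ -7,7 +8,6 @@
 from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.renderer.cstyle import CUDARenderer
-CUDA_INCLUDE_PATH = getenv("CUDA_INCLUDE_PATH", default="-I/usr/local/cuda/include")
 CUDACPU = getenv("CUDACPU") == 1
 if CUDACPU:
   gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
@@ -20,11 +20,11 @@ def check(status):
 def cu_time_execution(cb, enable=False) -> Optional[float]: return time_execution_cuda_style(cb, cuda.CUevent, cuda.cuEventCreate, cuda.cuEventRecord, cuda.cuEventSynchronize, cuda.cuEventDestroy_v2, cuda.cuEventElapsedTime, enable=enable) if not CUDACPU else cpu_time_execution(cb, enable=enable)

 @diskcache
-def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architecture={CUDADevice.default_arch_name}', CUDA_INCLUDE_PATH], cuda.nvrtcProgram, cuda.nvrtcCreateProgram, cuda.nvrtcCompileProgram, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check)
+def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architecture={CUDADevice.default_arch_name}', "-I/usr/local/cuda/include", "-I/usr/include"], cuda.nvrtcProgram, cuda.nvrtcCreateProgram, cuda.nvrtcCompileProgram, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check)

 class CUDAProgram:
-  def __init__(self, name:str, lib:bytes):
-    self.name, self.lib = name, lib
+  def __init__(self, device:CUDADevice, name:str, lib:bytes):
+    self.device, self.name, self.lib = device, name, lib
     if DEBUG >= 5: print(pretty_ptx(lib.decode('utf-8')))
     if DEBUG >= 6:
       try:
@@ -35,6 +35,7 @@ class CUDAProgram:
       except Exception as e: print("failed to generate SASS", str(e))
     if not CUDACPU:
+      check(cuda.cuCtxSetCurrent(self.device.context))
       self.module = init_c_var(cuda.CUmodule(), lambda x: check(cuda.cuModuleLoadData(ctypes.byref(x), lib)))
       check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
     self.prg = prg if not CUDACPU else lib
@@ -43,28 +44,39 @@ class CUDAProgram:
     if not CUDACPU: check(cuda.cuModuleUnload(self.module))

   def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
+    if not CUDACPU: check(cuda.cuCtxSetCurrent(self.device.context))
     c_kernel_input_config = encode_args_cuda_style(args, cuda.CUdeviceptr_v2, (1,2,0))[0] if not CUDACPU else args
     return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait)

 class CUDAAllocator(LRUAllocator):
-  def _alloc(self, size): return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size)))
+  def __init__(self, device:CUDADevice):
+    self.device = device
+    super().__init__()
+  def _alloc(self, size):
+    check(cuda.cuCtxSetCurrent(self.device.context))
+    return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size)))
   def _free(self, opaque): check(cuda.cuMemFree_v2(opaque))
-  def copyin(self, dest, src:memoryview): check(cuda.cuMemcpyHtoD_v2(dest, from_mv(src), len(src), None))
-  def copyout(self, dest:memoryview, src): check(cuda.cuMemcpyDtoH_v2(from_mv(dest), src, len(dest)))
+  def copyin(self, dest, src:memoryview):
+    check(cuda.cuCtxSetCurrent(self.device.context))
+    check(cuda.cuMemcpyHtoD_v2(dest, from_mv(src), len(src), None))
+  def copyout(self, dest:memoryview, src):
+    check(cuda.cuCtxSetCurrent(self.device.context))
+    check(cuda.cuMemcpyDtoH_v2(from_mv(dest), src, len(dest)))

 class CUDADevice(Compiled):
   default_arch_name = "sm_35"
   def __init__(self, device:str):
-    self.device = int(device.split(":")[1]) if ":" in device else 0
+    device_id = int(device.split(":")[1]) if ":" in device else 0
     if not CUDACPU:
       check(cuda.cuInit(0))
-      check(cuda.cuDeviceGet(ctypes.byref(device := cuda.CUdevice()), self.device))
-      check(cuda.cuCtxCreate_v2(ctypes.byref(_ := cuda.CUcontext()), 0, device))
-      check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), self.device))
-      if self.device == 0: CUDADevice.default_arch_name = f"sm_{major.value}{minor.value}"
+      check(cuda.cuDeviceGet(ctypes.byref(device := cuda.CUdevice()), device_id))
+      check(cuda.cuCtxCreate_v2(ctypes.byref(context := cuda.CUcontext()), 0, device))
+      self.context = context
+      check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), device_id))
+      if device_id == 0: CUDADevice.default_arch_name = f"sm_{major.value}{minor.value}"

     from tinygrad.features.graph.cuda import CUDAGraph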
-    super().__init__(CUDAAllocator() if not CUDACPU else MallocAllocator,
+    super().__init__(CUDAAllocator(self) if not CUDACPU else MallocAllocator,
                      LinearizerOptions(supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
-                     CUDARenderer, compile_cuda, CUDAProgram, graph=CUDAGraph if not CUDACPU else None)
+                     CUDARenderer, compile_cuda, functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None)
   def synchronize(self): return check(cuda.cuCtxSynchronize()) if not CUDACPU else None
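With these changes each CUDADevice owns its own CUcontext, and every driver entry point (module load, kernel launch, alloc, copyin/copyout) binds that context with cuCtxSetCurrent first, while CUDAProgram and CUDAAllocator carry a back-reference to their device. A minimal sketch of what this enables, assuming a machine with at least two CUDA GPUs; the shapes and device indices are just examples:

```python
# Sketch: tensors on two CUDA devices in one process. Each CUDADevice keeps its
# own context, so the copy and the kernel below run under the correct one.
from tinygrad.tensor import Tensor

a = Tensor.rand(64, 64, device="CUDA:0")
b = a.to("CUDA:1")            # data moved to the second GPU
print((b @ b).numpy().shape)  # (64, 64), computed under CUDA:1's context
```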