
new style device (#2530)

* cpu tests pass

* torch works

* works

* metal works

* fix ops_disk

* metal jit works

* fix openpilot

* llvm and clang work

* fix webgpu

* docs are really broken

* LRU works on metal

* delete comment

* revert name to ._buf. LRU only on Compiled

* changes

* allocator

* allocator, getting closer

* lru alloc

* LRUAllocator

* all pass

* metal

* cuda

* test examples

* linearizer

* test fixes

* fix custom + clean realize

* fix hip

* skip tests

* fix tests

* fix size=0

* fix MOCKHIP

* fix thneed

* copy better

* simple

* old style metal copy

* fix thneed

* np reshape

* give cuda a device
pull/2531/head
George Hotz 2023-11-30 17:07:16 -08:00 committed by GitHub
parent e56511b59a
commit 2c363b5f0b
38 changed files with 572 additions and 1039 deletions

View File

@ -197,11 +197,11 @@ jobs:
- if: ${{ matrix.task == 'openpilot' }}
name: Test openpilot fastvits model correctness (float32)
run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'openpilot' }}
name: Test multigpu
run: |
PYTHONPATH="." python test/external/dist/test_world.py
PYTHONPATH="." python test/external/dist/test_collectives.py
#- if: ${{ matrix.task == 'openpilot' }}
# name: Test multigpu
# run: |
# PYTHONPATH="." python test/external/dist/test_world.py
# PYTHONPATH="." python test/external/dist/test_collectives.py
- if: ${{ matrix.task == 'onnx' }}
name: Test ONNX (CPU)
run: CPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20

.gitignore vendored
View File

@ -45,3 +45,5 @@ coverage.xml
htmlcov
outputs_yolov8
wandb
model.safetensors
quickstart.py

View File

@ -83,9 +83,9 @@ class LazyBuffer:
# we'll come back to this later
st: ShapeTracker
# if the LazyBuffer is realized, it has a RawBuffer
# we will come back to RawBuffers later
realized: Optional[RawBuffer]
# if the LazyBuffer is realized, it has a Buffer
# we will come back to Buffer later
realized: Optional[Buffer]
# if the lazybuffer is unrealized, it has a LazyOp
# this LazyOp describes the computation needed to realize this LazyBuffer
@ -142,7 +142,7 @@ assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
result.realize()
assert result.lazydata.realized is not None, "the LazyBuffer is realized!"
# this brings us nicely to DeviceBuffer, of which the realized ClangBuffer is a subclass
assert 'RawMallocBuffer' in str(type(result.lazydata.realized))
#assert 'RawMallocBuffer' in str(type(result.lazydata.realized))
# getting ahead of ourselves, but we can copy the DeviceBuffer toCPU
assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU, it's 5"
@ -153,9 +153,6 @@ assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU,
# Interpreted backends are very simple (example: CPU and TORCH)
class Interpreted:
# they have a backing RawBuffer
buffer: Type[RawBuffer]
# and they have a lookup table to functions for the Ops
fxn_for_op: Dict[Op, Callable] = {
UnaryOps.EXP2: lambda x: np.exp2(x),
@ -163,9 +160,6 @@ class Interpreted:
# Compiled backends take a little more (example: GPU and LLVM)
class Compiled:
# they also have a backing RawBuffer
buffer: Type[RawBuffer]
# a code generator, which compiles the AST
codegen: Type[Linearizer]
@ -178,41 +172,28 @@ class Runtime(ABC):
# the constructor compiles the code
def __init__(self, name:str, prg:str): pass
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
def __call__(self, global_size:Optional[List[int]], local_size:Optional[List[int]], *bufs:List[RawBuffer]): pass
def __call__(self, *bufs:List[Buffer], global_size:Optional[List[int]], local_size:Optional[List[int]]): pass
# %%
# == RawBuffer (in tinygrad/runtime/lib.py, code 5/10) ==
# == Buffer (in tinygrad/device.py, code 6/10) ==
import numpy as np
# RawBuffer is where the data is actually held. it's pretty close to just memory
class RawBuffer(ABC):
# Buffer is where the data is actually held. it's pretty close to just memory
class Buffer(ABC):
# create an empty rawbuffer that holds `size` elements of type `dtype`
# `buf` is an opaque container class
def __init__(self, size:int, dtype:DType, buf:Any): raise NotImplementedError("must be implemented")
# `opaque` is an opaque container class
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None): pass
# fromCPU is classmethod that creates a RawBuffer, it's a classmethod since some runtimes are 0 copy
@classmethod
def fromCPU(cls:RawBuffer, x:np.ndarray) -> RawBuffer: raise NotImplementedError("must be implemented")
# toCPU converts the RawBuffer to a numpy array with shape (size,). many backends are 0 copy here
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
# RawNumpyBuffer is a RawBuffer example for numpy. It's very simple
class RawNumpyBuffer(RawBuffer):
# NOTE: the "np.ndarray" is stored in the opaque container
def __init__(self, buf:np.ndarray):
super().__init__(buf.size, dtypes.from_np(buf.dtype), buf)
@classmethod
def fromCPU(cls, x): return cls(x)
def toCPU(self): return self._buf
# toCPU converts the RawBuffer to a numpy array with shape (size,)
def toCPU(self) -> np.ndarray: pass
# %%
# == Example: 2+3 in raw clang ==
# RawMallocBuffer is the simplest concrete version of RawBuffer (in tinygrad/ops.py)
# MallocAllocator is the simplest concrete version of Allocator (in tinygrad/device.py)
# it's used for the CLANG and LLVM backends
# it's just malloc(size * dtype.itemsize)
from tinygrad.runtime.lib import RawMallocBuffer
from tinygrad.device import MallocAllocator
# ClangProgram is the simplest runtime (in tinygrad/runtime/ops_clang.py, code 7/10)
# __init__ calls clang, and __call__ calls the function in the *.so outputted by clang
@ -224,16 +205,21 @@ from tinygrad.runtime.ops_clang import ClangProgram, compile_clang
# then we copy the numpy in to RawMallocBuffers
# last, we create an empty output buffer
from tinygrad.helpers import dtypes
input_a, input_b = MallocAllocator.alloc(1, dtypes.float32), MallocAllocator.alloc(1, dtypes.float32)
output = MallocAllocator.alloc(1, dtypes.float32)
# now we copy in the values
numpy_a, numpy_b = np.array([2], dtype=np.float32), np.array([3], dtype=np.float32)
input_a, input_b = RawMallocBuffer.fromCPU(numpy_a), RawMallocBuffer.fromCPU(numpy_b)
output = RawMallocBuffer(1, dtypes.float32)
MallocAllocator.copyin(input_a, numpy_a.data.cast("B"))
MallocAllocator.copyin(input_b, numpy_b.data.cast("B"))
# compile the program, run it, and 2+3 does indeed equal 5
program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}"))
program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}"), bufs=3)
program(output, input_a, input_b)
print(output.toCPU())
assert output.toCPU()[0] == 5, "it's still 5"
np.testing.assert_allclose(output.toCPU(), numpy_a+numpy_b)
numpy_out = np.empty(1, dtype=np.float32)
MallocAllocator.copyout(numpy_out.data.cast("B"), output)
assert numpy_out[0] == 5, "it's still 5"
np.testing.assert_allclose(numpy_out, numpy_a+numpy_b)
# %%
# == Linearizer (in tinygrad/codegen/linearizer.py, code 4/10) ==
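
As a quick sanity check of the Buffer interface described above, here is a minimal sketch (assuming the post-#2530 API and that the CLANG backend is available) of the fromCPU/toCPU round trip, i.e. the allocator's copyin and copyout without running any program:

import numpy as np
from tinygrad.device import Buffer
from tinygrad.helpers import dtypes

x = np.array([2.0, 3.0], dtype=np.float32)
buf = Buffer.fromCPU("CLANG", x)             # Buffer(device, size, dtype) + copyin(x.data)
assert buf.size == 2 and buf.dtype == dtypes.float32
np.testing.assert_allclose(buf.toCPU(), x)   # copyout into a fresh numpy array

Every backend now only has to supply an Allocator with _alloc, copyin and copyout for this to work; the tinygrad/device.py diff further down shows the base classes.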

View File

@ -5,6 +5,7 @@ from tinygrad.helpers import getenv, dtypes
if __name__ == "__main__":
if getenv("DIST"):
dist.preinit()
from extra.dist import collectives
# tinygrad implementation of https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
# https://myrtle.ai/learn/how-to-train-your-resnet-8-bag-of-tricks/
@ -22,7 +23,6 @@ from tinygrad.helpers import GlobalCounters
from tinygrad.shape.symbolic import Node
from extra.lr_scheduler import OneCycleLR
from tinygrad.jit import TinyJit
from extra.dist import collectives
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)

View File

@ -17,9 +17,9 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]
key = id(arg)
if key not in bufs:
if key in special_names:
bufs[key] = (special_names[key], arg._memsz, arg.dtype, key)
bufs[key] = (special_names[key], arg.size*arg.dtype.itemsize, arg.dtype, key)
else:
bufs[key] = (f"buf_{bufnum}", arg._memsz, arg.dtype, key)
bufs[key] = (f"buf_{bufnum}", arg.size*arg.dtype.itemsize, arg.dtype, key)
bufnum += 1
if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
cargs.append(bufs[key][0])

View File

@ -6,8 +6,8 @@ from typing import Any, Dict, List, Tuple
from dataclasses import dataclass
try:
_libhip = ctypes.cdll.LoadLibrary("libamdhip64.so")
_libhiprtc = ctypes.cdll.LoadLibrary("libhiprtc.so")
_libhip = ctypes.cdll.LoadLibrary("/opt/rocm/lib/libamdhip64.so")
_libhiprtc = ctypes.cdll.LoadLibrary("/opt/rocm/lib/libhiprtc.so")
_libhip.hipGetErrorString.restype = ctypes.c_char_p
_libhip.hipGetErrorString.argtypes = [ctypes.c_int]

View File

@ -5,10 +5,12 @@ import json
import traceback
import numpy as np
from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu
from tinygrad.device import Device
from tinygrad.helpers import DEBUG, getenv
from collections import defaultdict
import pyopencl as cl
from tinygrad.runtime.ops_gpu import CL, OSX_TIMING_RATIO
from tinygrad.runtime.ops_gpu import OSX_TIMING_RATIO
CL = Device["GPU"]
DEBUGCL = getenv("DEBUGCL", 0)
FLOAT16 = getenv("FLOAT16", 0)
@ -74,29 +76,29 @@ class Thneed:
if o['arg_type'] == "image2d_t":
if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
# hack: use a image1d since we can back that with a buffer
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
else:
# buffer isn't supported in image2d, copy buffer into image
if 'buffer_id' in o and bufs_loaded[o['buffer_id']]:
arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16)
cl.enqueue_copy(CL.cl_queue[0], arr, bufs[o['buffer_id']])
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
cl.enqueue_copy(CL.queue, arr, bufs[o['buffer_id']])
buf = cl.Image(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
elif o['needs_load']:
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
buf = cl.Image(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
else:
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
if o['arg_type'] == "image1d_t":
assert not o['needs_load']
assert not bufs_loaded[o['buffer_id']]
buf = cl.Image(CL.cl_ctxs[0], mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
else:
if 'data' in o:
buf = cl.Buffer(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
buf = cl.Buffer(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
else:
# zero out buffers
buf = cl.Buffer(CL.cl_ctxs[0], mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
buf = cl.Buffer(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
bufs[o['id']] = buf
bufs_loaded[o['id']] = 'data' in o
@ -108,7 +110,7 @@ class Thneed:
prgs = {}
for o in jdat['binaries']:
nptr = ptr + o['length']
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr])
prgs[o['name']] = CLProgram(Device["GPU"], o['name'], weights[ptr:nptr])
ptr = nptr
# populate the cl_cache
@ -153,7 +155,7 @@ class Thneed:
for prg, args in self.cl_cache:
# get binaries for saving
if prg.name not in saved_binaries:
binary = prg.clprograms[0].get_info(cl.program_info.BINARIES)
binary = prg.clprogram.get_info(cl.program_info.BINARIES)
assert len(binary) == 1
jdat['binaries'].append({"name":prg.name, "length":len(binary[0])})
binaries.append(binary[0])
@ -161,7 +163,7 @@ class Thneed:
# get the args from the kernel, some need the data saved
targs, args_size = [], []
argdtypes = prg.argdtypes if prg.argdtypes is not None else [None]*(len(args)-2)
argdtypes = [None]*(len(args)-2)
for a,d in zip(args[2:], argdtypes):
if d == np.int16:
targs.append(struct.pack("H", a).decode("latin_1"))
@ -185,7 +187,7 @@ class Thneed:
})
if needs_load:
data = np.empty(a.size//4, dtype=np.float32)
cl.enqueue_copy(CL.cl_queue[0], data, a, is_blocking=True)
cl.enqueue_copy(CL.queue, data, a, is_blocking=True)
weights.append(data.tobytes())
elif isinstance(a, cl.Image):
assert a.format == cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT), "wrong type"
@ -193,12 +195,12 @@ class Thneed:
row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64
size = row_pitch * a.shape[1]
# this is *2 if float16 and *4 if float32
buf = cl.Buffer(CL.cl_ctxs[0], cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
buf = cl.Buffer(CL.ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
# zero out the buffer
cl.enqueue_copy(CL.cl_queue[0], buf, b'\x00'*buf.size, is_blocking=True)
cl.enqueue_copy(CL.queue, buf, b'\x00'*buf.size, is_blocking=True)
CLProgram("from_image_strided", compile_gpu("""
CLProgram(CL, "from_image_strided", compile_gpu("""
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 l;
@ -206,7 +208,7 @@ class Thneed:
l.x = get_global_id(0);
out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
}
"""), argdtypes=(None, None, np.int32))(a, buf, row_pitch//(4*(2 if FLOAT16 else 4)), global_size=a.shape)
"""), bufs=2, vars=1)(a, buf, row_pitch//(4*(2 if FLOAT16 else 4)), global_size=a.shape)
# multiple of 32 isn't enough
jdat['objects'].append({
@ -216,7 +218,7 @@ class Thneed:
if needs_load:
data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32)
cl.enqueue_copy(CL.cl_queue[0], data, buf, is_blocking=True)
cl.enqueue_copy(CL.queue, data, buf, is_blocking=True)
if FLOAT16: data = data.astype(np.float16)
weights.append(data.tobytes())
else:
@ -263,9 +265,9 @@ class Thneed:
events = []
st = time.monotonic()
for prg, args in self.cl_cache:
events.append(prg.clprgs[0](CL.cl_queue[0], *args))
events.append(prg.clprg(CL.queue, *args))
mt = time.monotonic()
CL.synchronize()
Device["GPU"].synchronize()
et = time.monotonic() - st
print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")

View File

@ -85,7 +85,6 @@ def schedule_to_thneed(schedule, output_fn):
def thneed_test_onnx(onnx_data, output_fn):
import onnx
import pyopencl as cl
from tinygrad.runtime.ops_gpu import CL
import numpy as np
from extra.thneed import Thneed
onnx_model = onnx.load(io.BytesIO(onnx_data))
@ -118,11 +117,11 @@ def thneed_test_onnx(onnx_data, output_fn):
# inputs
for k,v in nt.inputs.items():
cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True)
cl.enqueue_copy(Device["GPU"].queue, v, new_np_inputs[k], is_blocking=True)
nt.run()
new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(new_torch_out.shape)
cl.enqueue_copy(CL.cl_queue[0], new_thneed_out, nt.outputs[0], is_blocking=True)
cl.enqueue_copy(Device["GPU"].queue, new_thneed_out, nt.outputs[0], is_blocking=True)
# compare torch to thneed
np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2)

View File

@ -1,125 +0,0 @@
#!/usr/bin/env python
import unittest, gc
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_state_dict
from tinygrad.helpers import GlobalCounters
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad import Device
from test.helpers import derandomize_model
from examples.llama import Transformer
ALLOCATED_DEV_BUFS = 0
class FakeDeviceBuffer:
def __init__(self, sz, dt, device):
self.id = 1
self.size = sz
self.dtype = dt
self.device = device
global ALLOCATED_DEV_BUFS
ALLOCATED_DEV_BUFS += 1
class FakeAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return FakeDeviceBuffer(size, dtype, device)
def _do_free(self, buf):
buf.id -= 1
assert buf.id == 0, f"Free should be called once, but {buf.id}"
def __del__(self): # Fake allocator should clear all buffers after each test.
for v in self.cached_buffers.values():
for buf, _ in v: self._free_buffer(buf)
FAKE_GLOBAL_ALLOCATOR = None
class FakeBuffer(RawBuffer):
def __init__(self, size, dtype, device='0'):
global FAKE_GLOBAL_ALLOCATOR
super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device})
assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires 100% match of dtype and size."
@classmethod
def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
class FakeProgram:
def __init__(self, name:str, prg:str): pass
def __call__(self, *bufs, global_size, local_size, wait=False): pass
def helper_test_correctness(gen, train):
from tinygrad.runtime.ops_gpu import CL, CLAllocator
old_alloc = CL.cl_allocator
CL.cl_allocator = CLAllocator(0)
no_alloc_result = train(*gen()).numpy()
Device[Device.DEFAULT].synchronize()
CL.cl_allocator = CLAllocator(512<<30) # Test cache correctness, so cache as much as possible, 512gb
for _ in range(4):
GlobalCounters.reset()
np.testing.assert_allclose(train(*gen()).numpy(), no_alloc_result, rtol=1e-3, atol=1e-5)
Device[Device.DEFAULT].synchronize()
assert len(CL.cl_allocator.cached_buffers) != 0, "Cache must be used"
CL.cl_allocator = old_alloc
def __helper_test_alloc_count(gen, train):
was_alloc = ALLOCATED_DEV_BUFS
for _ in range(2):
train(*gen())
return ALLOCATED_DEV_BUFS - was_alloc
def helper_test_alloc_count(mm, gen, train):
global FAKE_GLOBAL_ALLOCATOR
backup_program = Device[Device.DEFAULT].runtime
backup_buffer = Device[Device.DEFAULT].buffer
Device[Device.DEFAULT].runtime = FakeProgram
Device[Device.DEFAULT].buffer = FakeBuffer
Device[Device.DEFAULT].get_runner.cache_clear()
FAKE_GLOBAL_ALLOCATOR = FakeAllocator(16<<30)
new_allocs = __helper_test_alloc_count(gen, train)
Device[Device.DEFAULT].get_runner.cache_clear()
FAKE_GLOBAL_ALLOCATOR = FakeAllocator(0)
old_allocs = __helper_test_alloc_count(gen, train)
print(f"{mm}: llama: old allocs count {old_allocs}, new allocs count {new_allocs}")
assert new_allocs < old_allocs, "Hmm, doesn't cache work any more?"
Device[Device.DEFAULT].runtime = backup_program
Device[Device.DEFAULT].buffer = backup_buffer
FAKE_GLOBAL_ALLOCATOR = None
def check_gc():
if Device.DEFAULT == "GPU":
gc.collect() # Need to collect Tensors.
from extra.introspection import print_objects
assert print_objects() == 0
class TestAllocators(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
def test_lru_allocator_tiny_llama(self):
old_type = Tensor.default_type
Tensor.default_type = dtypes.float16
args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
def __test():
model = Transformer(**args_tiny)
derandomize_model(model)
def test(t): return model(t, 0).realize()
helper_test_correctness(lambda: (Tensor([[1,]]),), test)
__test()
Tensor.default_type = old_type
check_gc()
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
def test_lru_allocator_tiny_llama_alloc_counts(self):
args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
def test_alloc_count(t):
model = Transformer(**args_tiny)
for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))
return model(t, 0).realize()
helper_test_alloc_count("llama", lambda: (Tensor([[2,]]),), test_alloc_count)
check_gc()
@unittest.skip("huge for CI")
def test_stable_diffusion(self):
from examples.stable_diffusion import UNetModel
model = UNetModel()
derandomize_model(model)
def test(t, t2): return model(t, 801, t2).realize()
helper_test_correctness(lambda: (Tensor.randn(1, 4, 16, 16),Tensor.randn(1, 77, 768)), test)
if __name__ == "__main__":
unittest.main()

View File

@ -1,29 +1,27 @@
# NOTE: this only tests the speed of the LLaMA codegen, it doesn't actually run the net
import unittest, time
import numpy as np
from examples.llama import Transformer, MODEL_PARAMS
from tinygrad.tensor import Tensor
from tinygrad import Device
from tinygrad.nn.state import get_state_dict
from tinygrad.device import Compiled
from tinygrad.device import Compiled, Allocator
from tinygrad.helpers import Profiling
from tinygrad.runtime.lib import RawBuffer
class FakeProgram:
def __init__(self, name:str, prg:str): pass
def __init__(self, name:str, prg:bytes, bufs:int, vars:int=0): pass
def __call__(self, *bufs, global_size, local_size, wait=False): pass
class RawFakeBuffer(RawBuffer):
def _copyin(self, x:np.ndarray): pass
def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
class FakeAllocator(Allocator):
def _alloc(self, sz, dtype): return None
def copyin(self, dest, src:memoryview): pass
class TestLLaMASpeed(unittest.TestCase):
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
def test_llama_compile(self):
backup_program = Device[Device.DEFAULT].runtime
backup_buffer = Device[Device.DEFAULT].buffer
backup_allocator = Device[Device.DEFAULT].allocator
Device[Device.DEFAULT].runtime = FakeProgram
Device[Device.DEFAULT].buffer = RawFakeBuffer
Device[Device.DEFAULT].allocator = FakeAllocator()
print("testing llama python run time")
model = Transformer(**MODEL_PARAMS["1"]["7B"]["args"])
@ -48,7 +46,7 @@ class TestLLaMASpeed(unittest.TestCase):
run_llama("profile")
Device[Device.DEFAULT].runtime = backup_program
Device[Device.DEFAULT].buffer = backup_buffer
Device[Device.DEFAULT].allocator = backup_allocator
if __name__ == '__main__':
unittest.main()

View File

@ -6,7 +6,7 @@ from tinygrad.nn.state import get_parameters
def derandomize(x):
if isinstance(x, LazyOp):
new_op = LoadOps.EMPTY if x.op == LoadOps.RAND else x.op
return LazyOp(new_op, tuple([derandomize(s) for s in x.src]), x.arg)
return LazyOp(new_op, tuple([derandomize(s) for s in x.src]), None if x.op == LoadOps.RAND else x.arg)
x.op = derandomize(x.op)
return x

View File

@ -1,188 +0,0 @@
#!/usr/bin/env python
import unittest
import pytest
import numpy as np
from weakref import ref
from tinygrad.helpers import GlobalCounters
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad import Device
from tinygrad.tensor import Tensor
def check_gc():
if Device.DEFAULT == "GPU":
from extra.introspection import print_objects
assert print_objects() == 0
class FakeDeviceBuffer:
def __init__(self, sz, dt, device):
self.id = 1
self.size = sz
self.dtype = dt
self.device = device
def __del__(self):
assert self.id == 0, "Should called _do_free() before"
class FakeAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs):
if size*dtype.itemsize > self._get_cur_free_space(device): raise Exception("OOM")
return FakeDeviceBuffer(size, dtype, device)
def _do_free(self, buf):
buf.id -= 1
assert buf.id == 0, f"Free should be called once, but {buf.id}"
def __del__(self): # Fake allocator should clear all buffers after each test.
for v in self.cached_buffers.values():
for buf, _ in v: self._free_buffer(buf)
FAKE_GLOBAL_ALLOCATOR = None
class FakeBuffer(RawBuffer):
def __init__(self, size, dtype, device='0'):
global FAKE_GLOBAL_ALLOCATOR
super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device})
assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires 100% match of dtype and size."
@classmethod
def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
def alloc(allocator, size, dtype, **kwargs):
global FAKE_GLOBAL_ALLOCATOR
FAKE_GLOBAL_ALLOCATOR = allocator
buf = FakeBuffer(size, dtype, **kwargs)
assert buf.dtype == dtype and buf.size == size
FAKE_GLOBAL_ALLOCATOR = None
return buf
def alloc_free_trace(allocator, size, dtype, **kwargs):
buf = alloc(allocator, size, dtype, **kwargs)
return ref(buf._buf)
def cmp_trace_and_buf(buf, trace_ref): return trace_ref and trace_ref() == buf._buf
class TestAllocators(unittest.TestCase):
def test_lru_allocator_reusage(self):
mc, mu = GlobalCounters.mem_cached, GlobalCounters.mem_used
def test():
lru_allocator = FakeAllocator(2048)
traced_buf = alloc_free_trace(lru_allocator, 16, dtypes.float32)
assert GlobalCounters.mem_cached - mc == 16*dtypes.float32.itemsize, "Buffer should be cached"
for _ in range(32):
def __test():
buf = alloc(lru_allocator, 16, dtypes.float32)
assert cmp_trace_and_buf(buf, traced_buf), "Buffer should be reused"
__test()
usedbuf = alloc(lru_allocator, 16, dtypes.float32)
for _ in range(32):
def __test():
buf = alloc(lru_allocator, 16, dtypes.float32)
assert usedbuf != buf, "Nobody should get used buffer"
__test()
assert GlobalCounters.mem_used - mu == 16*dtypes.float32.itemsize, "Only usedbuf is still allocated."
test()
check_gc()
def test_lru_allocator_cache_free(self):
mc, mu = GlobalCounters.mem_cached, GlobalCounters.mem_used
def test():
lru_allocator = FakeAllocator(128)
refs = []
for _ in range(32):
refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32))
for sz in range(1, 32):
alloc_free_trace(lru_allocator, sz, dtypes.float32)
assert GlobalCounters.mem_used + GlobalCounters.mem_cached - mc - mu <= 128, "Should not allocate on device more than allowed (128)"
for r in refs: assert r() is None, "All refs should be dead, since buffers were cleared from cache"
test()
check_gc()
def test_lru_allocator_multidevice(self):
def test():
lru_allocator = FakeAllocator(256)
refs=[]
for i in range(8):
refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32, device=str(i)))
for i in range(64):
def __test():
dev = str(i % 8)
buf = alloc(lru_allocator, 16, dtypes.float32, device=dev)
assert cmp_trace_and_buf(buf, refs[i%8]), "Buffer should be reused"
__test()
for r in refs: assert r() is not None, "All refs should be cached"
test()
check_gc()
def test_lru_allocator_failing_alloc_cleans_cache(self):
def test():
lru_allocator = FakeAllocator(128)
for size in range(1, 4):
alloc_free_trace(lru_allocator, size, dtypes.float32, device='0')
assert len(lru_allocator.aging_order['0']) == 3, "All buffers should be cached"
assert lru_allocator.free_space['0'] == 128 - 24, "24 bytes to be used by current cached buffers"
def always_raise_exception(*args, **kwargs):
raise MemoryError("OOM")
lru_allocator._do_alloc = always_raise_exception
with pytest.raises(Exception):
alloc(lru_allocator, 5, dtypes.float32, device='0')
assert len(lru_allocator.aging_order['0']) == 0, "All buffers should be freed from cache due to failing alloc"
test()
check_gc()
def test_lru_allocator_fail_first_alloc_pass_after_clear_cahce(self):
def test():
lru_allocator = FakeAllocator(128)
for size in range(1, 4):
alloc_free_trace(lru_allocator, size, dtypes.float32, device='0')
cache_length = 3
assert len(lru_allocator.aging_order['0']) == cache_length, "All buffers should be cached"
assert lru_allocator.free_space['0'] == 128 - 24, "24 bytes to be used by current cached buffers"
original_do_alloc = lru_allocator._do_alloc # save the original method
def single_fail_then_pass(*args, **kwargs):
lru_allocator._do_alloc = original_do_alloc # restore the original method
raise MemoryError("OOM")
lru_allocator._do_alloc = single_fail_then_pass
alloc(lru_allocator, 5, dtypes.float32, device='0')
assert len(lru_allocator.aging_order['0']) < cache_length, "Some buffers should be cleaned as first alloc failed"
test()
check_gc()
@unittest.skip("failing in CI")
def test_gpu_copyout(self):
def test():
from tinygrad.runtime.ops_gpu import CL
# Allocation to init the allocator.
tx = Tensor.rand(1)
tx.realize()
free_space = CL.cl_allocator.free_space[tx.lazydata.realized._device]
# Spawning 128mb objects to fill half of free_space
will_allocate = free_space // 3
trash_allocation_size = free_space // 2
def sp():
trash_buffer = Tensor.rand(trash_allocation_size // 4)
trash_buffer.realize()
sp()
xx = Tensor.rand(will_allocate // 4)
_ = xx.numpy()
test()
check_gc()
def test_lru_allocator_massive_buffer(self):
with self.assertRaises(AssertionError) as context: alloc(allocator := FakeAllocator(), size := 1e13, dtypes.int8)
self.assertEqual(str(context.exception), f"out of memory - requested: {size/1e9:5.2f} GB, available: {allocator._get_cur_free_space('0')/1e9:5.2f} GB")
@unittest.skipIf(Device.DEFAULT != "METAL", "only applies to Metal")
def test_lru_allocator_metal_max_buffer_length(self):
from tinygrad.runtime.ops_metal import METAL
with self.assertRaises(AssertionError) as context: METAL.allocator._do_alloc(buf_len := (max_buf_len := METAL.device.maxBufferLength()+1), dtypes.int8, '0')
self.assertEqual(str(context.exception), f"Buffer length of {buf_len/1e9:5.2f} GB exceeds Metal's max buffer length of {max_buf_len/1e9:5.2f} GB.")
if __name__ == "__main__":
unittest.main()

View File

@ -8,7 +8,7 @@ from tinygrad.helpers import prod, dtypes
# *** first, we implement the atan2 op at the lowest level ***
# `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
from tinygrad.lazy import LazyBuffer, create_lazybuffer
from tinygrad.lazy import Buffer, create_lazybuffer
from tinygrad.device import CompiledASTRunner, Device
from tinygrad.shape.shapetracker import ShapeTracker
import pytest
@ -16,17 +16,15 @@ import pytest
pytestmark = pytest.mark.webgpu
# we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer
def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
assert a.device == "GPU" and b.device == "GPU", "gpu function requires GPUBuffers"
def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
assert a.dtype == b.dtype and a.dtype == dtypes.float32, "gpu function only supports float32"
ret.realized = Device[ret.device].buffer(prod(ret.shape), ret.dtype)
CompiledASTRunner(None, "atan2_gpu", """
__kernel void atan2_gpu(global float *c, global float *a, global float *b) {
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].compiler, Device[ret.device].runtime).exec([ret.realized, a.realized, b.realized])
}""", global_size=[ret.size], bufcount=3).build(Device[ret.device].compiler, Device[ret.device].runtime).exec([ret, a, b])
def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer): ret.realized._copyin(np.arctan2(a.realized._buf, b.realized._buf))
def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data)
# *** second, we write the ATan2 mlop ***
# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative

View File

@ -7,6 +7,7 @@ from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector
class TestLazyBuffer(unittest.TestCase):
@unittest.skip("it doesn't work like this anymore")
def test_fromcpu_buffer_sharing(self):
a = np.arange(8)
assert LazyBuffer.fromCPU(a).realized._buf is a

View File

@ -3,7 +3,7 @@ import unittest, os
from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores
from tinygrad.codegen.linearizer import Linearizer, UOp, UOps
from tinygrad.device import Compiled, Device
from tinygrad.device import Compiled, Device, Buffer
from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@ -140,7 +140,7 @@ def helper_realized_ast(r:Tensor):
s = r.lazydata.schedule()
run_schedule(s[:-1]) # run all kernels except the last one
# now all input LazyBuffers buffers in s[-1] should be realized
output_buffer = Device[s[-1].out.device].buffer(prod((s if isinstance(s, int) else s.max for s in s[-1].out.shape)), s[-1].out.dtype, **s[-1].out._device_extra_args()) # allocate an output buffer
output_buffer = Buffer(s[-1].out.device, prod((s if isinstance(s, int) else s.max for s in s[-1].out.shape)), s[-1].out.dtype, **s[-1].out._device_extra_args()) # allocate an output buffer
return s[-1].ast, [output_buffer] + [l.realized for l in s[-1].inputs]
class TestFloat4(unittest.TestCase):
@ -367,7 +367,7 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
for opt in opts:
k.apply_opt(opt)
prg = to_prg(k)
real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled
real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
prg.exec(real_bufs)
np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
@ -381,7 +381,7 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False):
k = Linearizer(realized_ast)
k.hand_coded_optimizations()
prg = Device[Device.DEFAULT].to_program(k)
real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled
real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
prg.exec(real_bufs)
np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
for x in opts: # Check custom transformations if any.

View File

@ -2,7 +2,7 @@ import unittest
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import time_linearizer
from tinygrad.device import Compiled, Device
from tinygrad.device import Compiled, Device, Buffer
from tinygrad.ops import LoadOps
from tinygrad.tensor import Tensor
@ -12,7 +12,7 @@ class TestTimeLinearizer(unittest.TestCase):
def test_reasonable_time(self):
si = [si for si in Tensor([1,2,3,4]).add(1).lazydata.schedule() if si.ast.op not in LoadOps][0]
rawbufs = [Device[Device.DEFAULT].buffer(si.out.st.size(), si.out.dtype)] + [Device[Device.DEFAULT].buffer(x.st.size(), x.dtype) for x in si.inputs]
rawbufs = [Buffer(Device.DEFAULT, si.out.st.size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.size(), x.dtype) for x in si.inputs]
tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10)
assert tm > 0 and tm != float('inf')

View File

@ -2,16 +2,16 @@ from typing import Optional, Tuple, Any, List
import unittest, math
import numpy as np
from tinygrad.helpers import dtypes, getenv, DType, PtrDType
from tinygrad.tensor import Device
from tinygrad.device import Buffer, Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.device import CompiledASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, UOp
def _uops_to_prg(uops):
def _uops_to_prg(uops, bufcount):
src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
return CompiledASTRunner(None, "test", src,
[1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None,
runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
runtime_args=runtime_args, bufcount=bufcount).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(vin), arg))
@ -24,9 +24,9 @@ def _test_single_value(vals, op, dtype):
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i in range(len(vals)))
alu = uop(uops, UOps.ALU, dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Device[Device.DEFAULT].buffer(1, dtype)
buf2 = [Device[Device.DEFAULT].buffer.fromCPU(np.array([a], dtype=dtype.np)) for a in vals]
prg = _uops_to_prg(uops)
buf = Buffer(Device.DEFAULT, 1, dtype)
buf2 = [Buffer.fromCPU(Device.DEFAULT, np.array([a], dtype=dtype.np)) for a in vals]
prg = _uops_to_prg(uops, 1+len(buf2))
prg.exec([buf]+buf2)
return buf.toCPU()[0]
@ -36,8 +36,8 @@ def _test_single_value_const(vals, op, dtype):
loads = (uop(uops, UOps.CONST, dtype, [], a) for a in vals)
alu = uop(uops, UOps.ALU, dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Device[Device.DEFAULT].buffer(1, dtype)
prg = _uops_to_prg(uops)
buf = Buffer(Device.DEFAULT, 1, dtype)
prg = _uops_to_prg(uops, 1)
prg.exec([buf])
return buf.toCPU()[0]

View File

@ -3,8 +3,7 @@ import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import dtypes, fetch, temp
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.helpers import fetch, temp
from tinygrad.helpers import Timing
def compare_weights_both(url):
@ -40,11 +39,6 @@ class TestRawDiskBuffer(unittest.TestCase):
with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
f.readinto(tst)
def test_mmap_read_speed(self):
db = RawDiskBuffer(test_size, dtype=dtypes.uint8, device=test_fn)
tst = np.empty(test_size, np.uint8)
with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
np.copyto(tst, db.toCPU())
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype")
class TestSafetensors(unittest.TestCase):
def test_real_safetensors(self):

View File

@ -1,8 +1,9 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union, Type, Any, List, Optional, Dict, Callable
import numpy as np
from collections import defaultdict
from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable, Tuple
import importlib, inspect, functools, pathlib, time, re
from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name
from tinygrad.runtime.lib import RawBuffer
from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name, DType, from_mv, dtypes
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, Op
@ -16,9 +17,11 @@ class _Device:
def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]:
x = x.split(":")[0].upper()
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._buffers][0]
def __getitem__(self, ix:str) -> Union[Interpreted, Compiled]:
x = ix.split(":")[0].upper()
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._buffers][0]
if isinstance(ret, type): ret = ret(ix)
return ret
@functools.cached_property
def DEFAULT(self) -> str:
device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) # type: ignore
@ -30,18 +33,64 @@ class _Device:
return "CPU"
Device = _Device()
class Buffer:
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None):
assert isinstance(dtype, DType)
self.device, self.size, self.dtype = device, size, dtype
self._buf = opaque if opaque is not None else Device[self.device].allocator.alloc(size, dtype)
GlobalCounters.mem_used += self.size * self.dtype.itemsize
def __del__(self):
GlobalCounters.mem_used -= self.size * self.dtype.itemsize
Device[self.device].allocator.free(self._buf, self.size, self.dtype)
def __repr__(self): return f"<buf device:{self.device} size:{self.size}>"
def copyin(self, mv:memoryview):
mv = mv.cast("B", shape=[self.size*self.dtype.itemsize])
assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
Device[self.device].allocator.copyin(self._buf, mv)
return self
@staticmethod
def fromCPU(device:str, x:np.ndarray): return Buffer(device, x.size, dtypes.from_np(x.dtype)).copyin(x.data)
def toCPU(self) -> np.ndarray:
ret = np.empty(self.size, self.dtype.np)
if self.size > 0: Device[self.device].allocator.copyout(ret.data.cast("B", shape=[self.size*self.dtype.itemsize]), self._buf)
return ret
# TODO: size, dest, src are the same type. can we enforce this?
class Allocator:
def alloc(self, size:int, dtype:DType): return self._alloc(size, dtype)
def _alloc(self, size:int, dtype:DType): raise NotImplementedError("need alloc")
def free(self, opaque, size:int, dtype:DType): self._free(opaque) # if you are returning a Python object, you don't need a free
def _free(self, opaque): pass
def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
class LRUAllocator(Allocator): # pylint: disable=abstract-method
def __init__(self): self.cache: Dict[Tuple[int, DType], Any] = defaultdict(list)
def alloc(self, size:int, dtype:DType):
if len(c := self.cache[(size, dtype)]): return c.pop()
try:
return self._alloc(size, dtype)
except MemoryError:
self.free_cache()
return self._alloc(size, dtype)
def free_cache(self):
for opaques in self.cache.values():
for opaque in opaques: self._free(opaque)
opaques.clear()
def free(self, opaque:Any, size:int, dtype:DType): self.cache[(size, dtype)].append(opaque)
# **************** shared device helpers ****************
class JITRunner:
def __init__(self):
self.op_estimate, self.mem_estimate = 0, 0
def exec(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
var_vals = var_vals if var_vals is not None else {}
from tinygrad.jit import CacheCollector
et = self(rawbufs, var_vals)
CacheCollector.add(self, rawbufs, var_vals)
return et
def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
raise NotImplementedError("override this")
def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra: Optional[Dict]=None):
@ -64,18 +113,16 @@ class InterpretedASTRunner(JITRunner):
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float:
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float:
st = time.perf_counter()
ret: RawBuffer = self.fxn(rawbufs[1:], var_vals)
rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs], var_vals)
et = time.perf_counter() - st
update_stats(f"<interpreted {ret.size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit)
assert rawbufs[0].dtype == ret.dtype, f"dtype mismatch in Interpreted, {rawbufs[0].dtype=} != {ret.dtype=}"
rawbufs[0].dtype, rawbufs[0].size, rawbufs[0]._buf, rawbufs[0].offset = ret.dtype, ret.size, ret._buf, ret.offset
update_stats(f"<interpreted {rawbufs[0].size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit)
return et
class Interpreted:
def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable]):
self.buffer, self.fxn_for_op = buffer, fxn_for_op
def __init__(self, allocator: Allocator, fxn_for_op:Dict[Op, Callable]):
self.allocator, self.fxn_for_op = allocator, fxn_for_op
self.synchronize, self.codegen, self.graph = lambda: None, None, None
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
@ -86,7 +133,6 @@ def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> Interpret
from tinygrad.graph import print_tree
print_tree(ast)
tglob: Dict[str, Any] = {"Variable": Variable}
lines: List[str] = []
@functools.lru_cache(None)
def gstr(x:Any, nm=None) -> str:
@ -98,15 +144,16 @@ def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> Interpret
tglob[ret] = x
return ret
lines: List[str] = []
@functools.lru_cache(None)
def _interpret_ast(ast:LazyOp) -> str:
# TODO: shortcutted store won't work with strides
if ast.op == BufferOps.STORE: return _interpret_ast(ast.src[0])
if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
if ast.op is BufferOps.STORE:
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({_interpret_ast(ast.src[0])})"
elif ast.op in BufferOps:
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
if ast.op in BufferOps:
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"inputs[{ast.arg.idx}]"
for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
else:
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})"
@ -124,16 +171,18 @@ def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> Interpret
# **************** for Compiled Devices ****************
class CompiledASTRunner(JITRunner):
def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, runtime_args:Optional[dict]=None):
def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, runtime_args:Optional[dict]=None, bufcount:int=0):
super().__init__()
if DEBUG >= 4: print(prg)
if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
if local_size is not None: local_size = local_size + [1]*(3-len(local_size))
self.name, self.display_name, self.prg, self.global_size, self.local_size, self.runtime_args = \
to_function_name(name), name, prg, global_size, local_size, runtime_args if runtime_args is not None else {}
self.bufcount = bufcount
self.vars: List[Variable] = []
if ast:
info = get_lazyop_info(ast)
self.bufcount = len(info.mem)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
from tinygrad.lazy import vars_from_ast
self.vars = vars_from_ast(ast)
@ -141,7 +190,7 @@ class CompiledASTRunner(JITRunner):
def build(self, compiler, runtime):
self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
self.clprg = runtime(self.name, self.lib)
self.clprg = runtime(self.name, self.lib, self.bufcount, len(self.vars))
return self
def launch_dims(self, var_vals):
@ -149,7 +198,7 @@ class CompiledASTRunner(JITRunner):
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
return global_size, local_size
def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
global_size, local_size = self.launch_dims(var_vals)
if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
# TODO: this is copied from get_program
@ -159,13 +208,13 @@ class CompiledASTRunner(JITRunner):
lra = self.runtime_args.copy()
if global_size: lra['global_size'] = global_size
if local_size and 'local_size' not in lra: lra['local_size'] = local_size
et = self.clprg(*rawbufs, *[var_vals[k] for k in self.vars], **lra, wait=wait or DEBUG>=2)
et = self.clprg(*[x._buf for x in rawbufs], *[var_vals[k] for k in self.vars], **lra, wait=wait or DEBUG>=2)
update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)
return et
class Compiled:
def __init__(self, buffer: Type[RawBuffer], linearizer_opts:LinearizerOptions, renderer, compiler, runtime, synchronize=lambda: None, graph=None):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.graph = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, graph
def __init__(self, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, runtime, graph=None):
self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph = allocator, linearizer_opts, renderer, compiler, runtime, graph
def to_program(self, k:Linearizer) -> CompiledASTRunner:
k.linearize()
@ -174,6 +223,7 @@ class Compiled:
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def get_runner(self, ast:LazyOp) -> CompiledASTRunner: return self.to_program(_get_optimized_linearizer(self.linearizer_opts, ast))
def synchronize(self): pass
def _get_optimized_linearizer(linearizer_opts:LinearizerOptions, ast:LazyOp) -> Linearizer:
if DEBUG >= 3:
@ -196,4 +246,11 @@ def _get_optimized_linearizer(linearizer_opts:LinearizerOptions, ast:LazyOp) ->
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
k = timed[0][1]
return k
return k
import ctypes
class _MallocAllocator(LRUAllocator):
def _alloc(self, size:int, dtype:DType): return (ctypes.c_uint8 * (size*dtype.itemsize))()
def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
MallocAllocator = _MallocAllocator()
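
To make the LRU cache above concrete, here is a small sketch; CountingAllocator is a hypothetical name used only for illustration, and copyin/copyout are left out because the sketch only exercises the caching path:

from tinygrad.device import LRUAllocator
from tinygrad.helpers import dtypes

class CountingAllocator(LRUAllocator):  # hypothetical subclass, not part of tinygrad
  def __init__(self):
    super().__init__()                  # sets up the (size, dtype) -> [opaque] cache
    self.allocs = 0
  def _alloc(self, size, dtype):
    self.allocs += 1
    return bytearray(size*dtype.itemsize)  # any opaque object works as a "buffer"

a = CountingAllocator()
buf = a.alloc(16, dtypes.float32)
a.free(buf, 16, dtypes.float32)            # goes into the cache, _free is never called
assert a.alloc(16, dtypes.float32) is buf  # same (size, dtype) is served from the cache
assert a.allocs == 1

Note the failure path as well: if _alloc raises MemoryError, alloc empties the cache with free_cache and retries once.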

View File

@ -1,11 +1,10 @@
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, random, math, time
from tinygrad.lazy import vars_from_ast
from tinygrad.device import Device, Compiled
from tinygrad.device import Device, Compiled, Buffer
from tinygrad.ops import MemBuffer
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, all_int, colored, Timing
from tinygrad.codegen.linearizer import Linearizer, UOp
from tinygrad.runtime.lib import RawBuffer
from collections import defaultdict
from tinygrad.tensor import Tensor
@ -23,7 +22,7 @@ actions += [
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
# returns time in seconds
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT}
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
@ -62,7 +61,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
if clear_l2:
# TODO: this is too small for many L2 caches
with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
tms.append(prg.clprg(*rawbufs, *var_vals.values(), **lra, wait=True)*factor)
tms.append(prg.clprg(*[x._buf for x in rawbufs], *var_vals.values(), **lra, wait=True)*factor)
except Exception:
if DEBUG >= 4:
import traceback
@ -75,14 +74,14 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
return min(tms)
# get (scrap) buffers for timing the linearizer
def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
for x in lin.membufs: bufsts[x.idx].append(x)
rawbufs:List[Optional[RawBuffer]] = [None]*len(bufsts)
rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
for k,lx in bufsts.items():
rawbufs[k] = cast(Compiled, Device[Device.DEFAULT]).buffer(prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
rawbufs[k] = Buffer(Device.DEFAULT, prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
assert all(r is not None for r in rawbufs)
return cast(List[RawBuffer], rawbufs)
return cast(List[Buffer], rawbufs)
# get dictionary of all possible actions
def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
@ -148,14 +147,14 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
if DEBUG >= 3: print(beam[0][0].applied_opts)
return beam[0][0]
def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
test_rawbuffers = [type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
MAX_WORKGROUP = clprg.max_work_group_size() if hasattr(clprg, 'max_work_group_size') else 1024
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
def try_exec(local_size):
try:
return clprg(*test_rawbuffers, global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True)
return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True)
except Exception:
return float('inf')
ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])

View File

@ -1,5 +1,5 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes
import numpy as np
from urllib import request
from tqdm import tqdm
@ -40,6 +40,10 @@ def partition(lst:List[T], fxn:Callable[[T],bool]):
def unwrap(x:Optional[T]) -> T:
assert x is not None
return x
def unwrap2(x):
ret, err = x
assert err is None, str(err)
return ret
def get_child(obj, key):
for k in key.split('.'):
if k.isnumeric(): obj = obj[int(k)]
@ -52,6 +56,7 @@ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+str
@functools.lru_cache(maxsize=None)
def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
def from_mv(mv, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type))
class Context(contextlib.ContextDecorator):
stack: ClassVar[List[dict[str, int]]] = [{}]
@ -251,3 +256,15 @@ def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=n
if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
pathlib.Path(f.name).rename(fp)
return fp
# *** pretty PTX printer
def pretty_ptx(s):
# all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers
s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers
s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
return s
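# from_mv above turns a writable memoryview into a ctypes pointer so raw host memory can be handed
# to C APIs (the HIP allocator below uses it for hipMemcpy). a quick standalone check of what it does:
import ctypes
def from_mv(mv, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type))
mv = memoryview(bytearray(b"hello"))
ptr = from_mv(mv)
assert ptr[0] == b"h"         # the pointer aliases the same bytes
ptr[0] = b"J"
assert bytes(mv) == b"Jello"  # writes through the pointer are visible in the memoryview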

View File

@ -2,8 +2,7 @@ from __future__ import annotations
from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
import functools, itertools, operator
from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv, all_int
from tinygrad.device import Device, JITRunner, CompiledASTRunner
from tinygrad.runtime.lib import RawBuffer
from tinygrad.device import Device, JITRunner, CompiledASTRunner, Buffer
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node
@ -13,11 +12,11 @@ from dataclasses import dataclass
@dataclass(frozen=True)
class JitItem:
prg: JITRunner # or a graph executor like MetalGraph
rawbufs: List[Optional[RawBuffer]]
rawbufs: List[Optional[Buffer]]
def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[Node, Node]:
return functools.reduce(operator.__add__, [ji.prg.op_estimate for ji in jit_cache], NumNode(0)), functools.reduce(operator.__add__, [ji.prg.mem_estimate for ji in jit_cache], NumNode(0))
def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[RawBuffer]) -> Dict[Tuple[int, int], int]:
def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
input_replace: Dict[Tuple[int, int], int] = {}
for j,ji in enumerate(jit_cache):
for i,a in enumerate(ji.rawbufs):
@ -55,7 +54,7 @@ class TinyJit(Generic[ReturnType]):
expected_name_sts_dtype = tuple([(k, v.lazydata.st.unbind(), v.dtype) for k,v in input_tensors.items()])
# get rawbuffers
input_rawbuffers: List[RawBuffer] = [cast(RawBuffer, v.lazydata.realized) for v in input_tensors.values()]
input_rawbuffers: List[Buffer] = [cast(Buffer, v.lazydata.realized) for v in input_tensors.values()]
assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
# get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
@ -67,7 +66,7 @@ class TinyJit(Generic[ReturnType]):
assert self.expected_vals == expected_vals, "mismatch of var_vals"
assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}"
for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True)
for ji in self.jit_cache: ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True)
elif self.cnt == 1:
# jit capture
self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype
@ -80,7 +79,7 @@ class TinyJit(Generic[ReturnType]):
# if your Device supports it, condense the items into a graph executor
if (make_graph := Device[Device.DEFAULT].graph) and getenv("JIT") != 2:
try:
self.jit_cache = [JitItem(make_graph(self.jit_cache, input_rawbuffers, var_vals), cast(List[Optional[RawBuffer]], input_rawbuffers))]
self.jit_cache = [JitItem(make_graph(self.jit_cache, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))]
except GraphException as e:
if DEBUG >= 1: print(f"graph create failed {e}")
@ -96,34 +95,34 @@ class TinyJit(Generic[ReturnType]):
return cast(ReturnType, self.ret)
class PlaceHolder:
def __init__(self, buf:RawBuffer): self.size, self.dtype, self._device, self.ref, self.buftype, self.bufid = buf.size, buf.dtype, getattr(buf, '_device', None), ref(buf), type(buf), id(buf._buf)
def to_tuple(self): return (self.size, self.dtype, self._device, self.buftype, self.bufid)
def __init__(self, buf:Buffer): self.size, self.dtype, self.device, self.ref, self.bufid = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf)
def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid)
def __hash__(self): return hash(self.to_tuple())
def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()
def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, RawBuffer]) -> RawBuffer:
def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:
ret = self.ref()
if ret: return ret
if self not in buffer_cache: buffer_cache[self] = self.buftype(self.size, self.dtype, **({'device':self._device} if self._device is not None else dict()))
if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype)
return buffer_cache[self]
class _CacheCollector:
def __init__(self):
self.cache: Optional[List[Tuple[JITRunner, List[Union[RawBuffer, PlaceHolder]]]]] = None
self.cache: Optional[List[Tuple[JITRunner, List[Union[Buffer, PlaceHolder]]]]] = None
def start(self, var_vals:Optional[Dict[Variable, int]]=None):
self.cache = []
self.placeholders: WeakKeyDictionary[RawBuffer, PlaceHolder] = WeakKeyDictionary()
self.placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
self.var_vals = var_vals if var_vals is not None else {}
def add(self, prg, rawbufs, var_vals):
if self.cache is None: return
for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0]) # NOTE: this is making an assumption that 0 is special
self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, RawBuffer) else x for x in rawbufs]))
self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs]))
def finish(self) -> List[JitItem]:
if self.cache is None: return []
buffer_cache: Dict[PlaceHolder, RawBuffer] = {}
buffer_cache: Dict[PlaceHolder, Buffer] = {}
saved_cache, self.cache = self.cache, None
return [JitItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
CacheCollector = _CacheCollector()
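# PlaceHolder keeps only a weakref to each buffer captured during tracing and allocates a
# replacement at finish() time only if the original has been garbage collected. a toy version of
# that idea (ToyBuf and ToyPlaceHolder are illustrative names, not tinygrad classes):
import weakref
class ToyBuf:
  def __init__(self, size): self.size, self.data = size, bytearray(size)
class ToyPlaceHolder:
  def __init__(self, buf): self.size, self.ref, self.bufid = buf.size, weakref.ref(buf), id(buf)
  def alloc_if_needed(self, cache):
    if (live := self.ref()) is not None: return live        # original buffer still alive: reuse it
    if self.bufid not in cache: cache[self.bufid] = ToyBuf(self.size)
    return cache[self.bufid]
buf = ToyBuf(16)
ph = ToyPlaceHolder(buf)
assert ph.alloc_if_needed({}) is buf                        # no fresh allocation while buf lives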

View File

@ -8,9 +8,7 @@ from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps, get_lazyop_info
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.runtime.lib import RawBuffer
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
from tinygrad.device import Buffer
# lazy can recurse a lot
sys.setrecursionlimit(10000)
@ -100,11 +98,11 @@ UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2,
class LazyBuffer:
__deletable__ = ('op',)
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, src:Optional[Buffer]=None, base:Optional[LazyBuffer]=None):
self.st: ShapeTracker = st
self.device, self.shape, self.optype, self._dtype = device, self.st.shape, optype, dtype
self._realized: Optional[RawBuffer] = src
self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized
self._realized: Optional[Buffer] = src
self.output_buffer: Optional[Buffer] = None # TODO: do we really need this? or can we just use realized
# TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
self.children: WeakSet[LazyBuffer] = WeakSet()
self.views: WeakSet[LazyBuffer] = WeakSet()
@ -211,7 +209,7 @@ class LazyBuffer:
@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), Buffer("CPU", prod(x.shape), dtypes.from_np(x.dtype), x.flatten()))
def cast(self, dtype:DType, bitcast:bool=False):
return self.e(UnaryOps.CAST, arg=(dtype, bitcast))
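# fromCPU above now just hands the flattened numpy array to a Buffer as its opaque backing store.
# a toy stand-in for that CPU path using only numpy (ToyBuffer is an illustrative name, not tinygrad's Buffer):
import numpy as np
class ToyBuffer:
  def __init__(self, device, size, dtype, opaque=None):
    self.device, self.size, self.dtype = device, size, dtype
    self._buf = opaque if opaque is not None else np.empty(size, dtype)
  def toCPU(self): return self._buf
x = np.arange(6, dtype=np.float32).reshape(2, 3)
buf = ToyBuffer("CPU", x.size, x.dtype, x.flatten())
assert buf.toCPU()[4] == 4.0   # the realized buffer is just flat storage plus metadata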

View File

@ -96,7 +96,7 @@ class FlopCounter:
InterpretedFlopCounter: Dict[Op, Callable] = {
BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}), BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {arg.idx: arg.dtype.itemsize*arg.st.size()}), UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.size()}), UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
**{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST},
**{op:lambda self,y: FlopCounter(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps},
**{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps},

View File

@ -1,10 +1,9 @@
from typing import List, cast, Dict, Callable
import numpy as np
from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, BufferOps
from tinygrad.device import Device
from tinygrad.device import Device, Buffer
from tinygrad.graph import log_schedule_item, print_tree
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import DEBUG, prod, all_int, getenv
from tinygrad.helpers import DEBUG, prod
def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
# NOTE: if you for loop the schedule it's slow because nothing frees
@ -27,53 +26,55 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
break
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
Device[si.out.device].buffer(prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype, **si.out._device_extra_args())
if si.ast.op in LoadOps:
# confirm the LoadOps are contiguous and in order
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
else:
# TODO: should this be handled here? it probably just shouldn't be in the schedule
if not hasattr(si.out.realized, 'size') or si.out.realized.size != 0:
Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype)
#Device[si.out.device].buffer(prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype, **si.out._device_extra_args())
# TODO: size 0 should be removed from the schedule
if si.out.realized.size != 0:
if si.ast.op in LoadOps:
# confirm the LoadOps are contiguous and in order
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
kwargs = {"arg": si.ast.arg} if si.ast.arg is not None else {}
LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out.realized, *[x.realized for x in si.inputs], **kwargs)
else:
Device[si.out.device].get_runner(si.ast).exec([si.out.realized] + [x.realized for x in si.inputs], si.var_vals)
del si.out.op
for v in si.out.views: del v.op
assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
#assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
assert si.out.realized.dtype == si.out.dtype, f"realized dtype is incorrect, {si.out.realized.dtype} != {si.out.dtype}"
# *** zero op LoadOps ***
def _realize_empty(buffer: LazyBuffer) -> None:
if DEBUG >= 2: print(f"*** empty {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
def _realize_empty(buffer: Buffer) -> None:
if DEBUG >= 2: print(f"*** empty {buffer.device} shape {buffer.size:5d} dtype {buffer.dtype}")
# TODO: remove this and write the RNG in tinygrad
def _realize_rand(buffer: LazyBuffer) -> None:
assert all_int(buffer.shape), "rand doesn't support symbolic shape"
if DEBUG >= 2: print(f"*** rand {buffer.device} seed {buffer.op.arg:<10d} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
rng = np.random.default_rng(buffer.op.arg)
buffer.realized._copyin(rng.random(size=prod(buffer.shape), dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args())
def _realize_rand(buffer: Buffer, arg) -> None:
if DEBUG >= 2: print(f"*** rand {buffer.device} seed {arg:<10d} shape {buffer.size:5d} dtype {buffer.dtype}")
rng = np.random.default_rng(arg)
rng_np_buffer = rng.random(size=buffer.size, dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False)
buffer.copyin(rng_np_buffer.data)
# *** one op LoadOps ***
from tinygrad.runtime.lib import RawBufferMapped, RawBufferTransfer
from tinygrad.runtime.ops_disk import RawDiskBuffer
def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
assert src.realized.size == buffer.realized.size, f"size mismatch on FROM {src.realized.size=} != {buffer.realized.size=}"
assert src.st.contiguous and buffer.st.contiguous, "all must be contiguous for from"
if DEBUG >= 2: print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size:<16d} shape {str(buffer.shape):23s} dtype {src.realized.dtype}")
#from tinygrad.runtime.lib import RawBufferMapped, RawBufferTransfer
#from tinygrad.runtime.ops_disk import RawDiskBuffer
def _realize_from(buffer: Buffer, src: Buffer) -> None:
assert src.size == buffer.size, f"size mismatch on FROM {src.size=} != {buffer.size=}"
if DEBUG >= 2: print(f"*** copy {buffer.device} <- {src.device} size {src.size:<16d} shape {buffer.size:5d} dtype {src.dtype}")
buffer.copyin(src.toCPU().data)
# TODO: make this generic
if isinstance(src.realized, RawDiskBuffer) and isinstance(buffer.realized, RawBufferMapped):
src.realized.readinto(buffer.realized._buffer())
elif isinstance(src.realized, RawBufferTransfer) and isinstance(buffer.realized, RawBufferTransfer) and getenv("P2P", 0) >= 1:
buffer.realized._transfer(src.realized)
else:
buffer.realized._copyin(src.realized.toCPU())
#if isinstance(src.realized, RawDiskBuffer) and isinstance(buffer.realized, RawBufferMapped):
# src.realized.readinto(buffer.realized._buffer())
#elif isinstance(src.realized, RawBufferTransfer) and isinstance(buffer.realized, RawBufferTransfer) and getenv("P2P", 0) >= 1:
# buffer.realized._transfer(src.realized)
#else:
#buffer.realized._copyin(src.realized.toCPU())
# *** n op LoadOps ***
def _realize_custom(buffer: LazyBuffer, *inputs: LazyBuffer) -> None:
if DEBUG >= 2: print(f"*** custom {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")
buffer.op.arg(buffer, *inputs)
def _realize_custom(buffer: Buffer, *inputs: Buffer, arg) -> None:
if DEBUG >= 2: print(f"*** custom {buffer.device} shape {buffer.size:5d} dtype {buffer.dtype}")
arg(buffer, *inputs)
LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
LoadOps.EMPTY: _realize_empty,

View File

@ -0,0 +1,78 @@
from typing import List, Any, Dict, cast, Optional
import numpy as np
import Metal
from tinygrad.helpers import dtypes, dedup, unwrap2
from tinygrad.device import Buffer, CompiledASTRunner, update_stats
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, GraphException
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_metal import MetalDevice
class MetalGraph:
def __init__(self, device:MetalDevice, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
self.device: MetalDevice = device
# create metal batch exec
icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch))
icb_descriptor.setInheritBuffers_(False)
icb_descriptor.setInheritPipelineState_(False)
icb_descriptor.setMaxKernelBufferBindCount_(31)
self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0))
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
if len(var_vals): self.int_buf = self.device.allocator.alloc(len(var_vals), dtypes.int32)
read_resources, write_resources = [], []
for j,ji in enumerate(self.jit_cache):
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
descriptor = Metal.MTLComputePipelineDescriptor.new()
descriptor.setComputeFunction_(prg.clprg.fxn)
descriptor.setSupportIndirectCommandBuffers_(True)
pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None))
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
icb_command.setComputePipelineState_(pipeline_state)
for i,b in enumerate(ji.rawbufs):
if b is not None:
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
if i == 0: write_resources.append(b._buf)
else: read_resources.append(b._buf)
var_vals_keys = list(var_vals.keys())
for i,v in enumerate(prg.vars):
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
if j not in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = prg.launch_dims(var_vals)
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
icb_command.setBarrier()
self.read_resources, self.write_resources = dedup(read_resources), dedup(write_resources)
self.command_buffer: Any = None
if len(var_vals): self.int_buf_view = np.frombuffer(self.int_buf.contents().as_buffer(self.int_buf.length()), np.int32)
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# NOTE: you at least can't update the ints if this is running
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers]
for (j,i),input_idx in self.input_replace.items():
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
for j in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
if len(var_vals): self.int_buf_view[:] = list(var_vals.values())
command_buffer = self.device.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
encoder.useResources_count_usage_(all_read_resources, len(all_read_resources), Metal.MTLResourceUsageRead)
encoder.useResources_count_usage_(self.write_resources, len(self.write_resources), Metal.MTLResourceUsageWrite)
encoder.endEncoding()
command_buffer.commit()
self.command_buffer = command_buffer
if wait:
command_buffer.waitUntilCompleted()
et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
else:
self.device.mtl_buffers_in_flight.append(command_buffer)
et = None
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))
return et

View File

@ -1,105 +0,0 @@
from __future__ import annotations
import ctypes
import numpy as np
from collections import defaultdict, deque
from typing import Any, Dict, Deque, Tuple
from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, ImageDType
class RawBuffer: # pylint: disable=abstract-method
def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
self.size: int = size
self.dtype: DType = dtype
self.offset: int = 0 # TODO: this is very unsupported, only in disk
self._buf = buf if buf is not None else (allocator(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
self._memsz: int = size*dtype.itemsize
self._allocator = allocator
self._device = kwargs.get('device', None)
GlobalCounters.mem_used += self._memsz
def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf)
def __repr__(self): return f"buffer<{self.size}, {self.dtype}, {id(self)}>"
@classmethod
def fromCPU(cls, x:np.ndarray, **kwargs):
ret = cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
if x.size > 0: ret._copyin(x)
return ret
def _copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
class RawBufferMapped(RawBuffer):
def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
# NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size)
def _copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1))
# this one is simple enough that i moved it out of the runtimes
ctypes_map = {dtypes.float64:ctypes.c_double, dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.bfloat16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.uint32: ctypes.c_uint32, dtypes.int64: ctypes.c_int64, dtypes.uint64: ctypes.c_uint64, dtypes.int16: ctypes.c_int16, dtypes.uint16: ctypes.c_uint16}
class RawMallocBuffer(RawBufferMapped):
def __init__(self, size, dtype: DType): super().__init__(size, dtype, (ctypes_map[dtype] * size)())
def _buffer(self): return memoryview(self._buf)
class RawBufferCopyInOut(RawBuffer):
def _copyout(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
def toCPU(self) -> np.ndarray:
x: np.ndarray = np.empty(self.size, dtype=self.dtype.np)
if x.size > 0: self._copyout(x)
return x
class RawBufferTransfer(RawBuffer):
def _transfer(self, x:RawBuffer) -> None: raise NotImplementedError("must be implemented")
class LRUAllocator:
def __init__(self, dev_memsz=(4<<30)):
self.epoch = 0
self.free_space: Dict[Any, int] = defaultdict(lambda: dev_memsz)
self.buffer_info: Dict[Any, Tuple[int, DType, str]] = dict()
self.cached_buffers: Dict[Tuple[int, ...], Deque[Tuple[Any, int]]] = defaultdict(deque) # Cached buffer storage, split by type and size, newest first.
self.aging_order: Dict[Any, Deque[Tuple[Tuple[int, ...], int]]] = defaultdict(deque) # Keys of cached_buffers, ordered from oldest to newest updates.
def _cache_reuse_buffer(self, rawbufs: Deque[Tuple[Any, int]]): # The newest cached buffer is reused.
GlobalCounters.mem_cached -= self._underlying_buf_memsz(rawbufs[0][0])
return rawbufs.popleft()[0]
def ensure_has_free_space(self, space_to_free, device):
while len(self.aging_order[device]) and self._get_cur_free_space(device) < space_to_free: # When OOM, remove LRU buffers.
bucket, epoch = self.aging_order[device].popleft()
if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
assert (curr_free := self._get_cur_free_space(device)) > space_to_free, f"out of memory - requested: {space_to_free/1e9:5.2f} GB, available: {curr_free/1e9:5.2f} GB"
def _alloc_buffer(self, size, dtype, device, **kwargs):
self.ensure_has_free_space(size*dtype.itemsize, device)
while True:
try:
newbuf = self._do_alloc(max(1, size), dtype, device, **kwargs)
break
except Exception:
if len(self.aging_order[device]) == 0: raise
self.ensure_has_free_space(1.1*self._get_cur_free_space(device), device) # increase free space by 10% and try again.
self.free_space[device] -= size*dtype.itemsize
self.buffer_info[newbuf] = (size, dtype, device)
return newbuf
def _free_buffer(self, buf_to_free):
self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
self.buffer_info.pop(buf_to_free)
self._do_free(buf_to_free)
def __call__(self, size, dtype, device='0', **kwargs): # allocate
rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None)
return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs)
def free(self, buf): # free() just caches buffer. It might be freed later when OOM during allocation.
self.epoch += 1
size, dtype, device = self.buffer_info[buf]
self.cached_buffers[self._cached_bufkey(size, dtype, device)].appendleft((buf, self.epoch))
self.aging_order[device].append((self._cached_bufkey(size, dtype, device), self.epoch))
GlobalCounters.mem_cached += self._underlying_buf_memsz(buf)
def _underlying_buf_memsz(self, buf): return self.buffer_info[buf][0] * self.buffer_info[buf][1].itemsize
def _cached_bufkey(self, size, dtype, device) -> Tuple[int, ...]: return (device, size, dtype, dtype.shape) if isinstance(dtype, ImageDType) else (device, size, dtype) # Provides a key for reusing device buffers with identical keys.
def _do_alloc(self, size, dtype, device, **kwargs): raise NotImplementedError("must be implemented")
def _do_free(self, buf): pass
def _get_cur_free_space(self, device): return self.free_space[device]
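# the allocator above (like its replacement in tinygrad.device) never frees eagerly: free() parks
# the buffer in a per-size bucket and the next same-sized allocation reuses it. a toy sketch of
# that free-as-cache idea (ToyLRU is illustrative, not the real LRUAllocator API):
from collections import defaultdict, deque
class ToyLRU:
  def __init__(self): self.cache = defaultdict(deque)
  def alloc(self, size):
    bucket = self.cache[size]
    return bucket.popleft() if bucket else bytearray(size)   # newest cached buffer first
  def free(self, buf): self.cache[len(buf)].appendleft(buf)  # cache instead of really freeing
lru = ToyLRU()
a = lru.alloc(256); lru.free(a)
assert lru.alloc(256) is a   # the same backing buffer comes back from the cache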

View File

@ -1,8 +1,7 @@
import time, ctypes, subprocess, platform, functools, pathlib, tempfile
from typing import Any
from tinygrad.device import Compiled
from tinygrad.device import Compiled, MallocAllocator
from tinygrad.helpers import diskcache
from tinygrad.runtime.lib import RawMallocBuffer
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
@ -21,7 +20,7 @@ def compile_clang(prg:str, header:str=CLANG_PROGRAM_HEADER) -> bytes:
return pathlib.Path(output_file.name).read_bytes()
class ClangProgram:
def __init__(self, name:str, prg:bytes):
def __init__(self, name:str, prg:bytes, bufs:int, vars:int=0):
# write to disk so we can load it
with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
pathlib.Path(cached_file_path.name).write_bytes(prg)
@ -29,8 +28,8 @@ class ClangProgram:
def __call__(self, *args, wait=False):
if wait: st = time.perf_counter()
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
self.fxn(*args)
if wait: return time.perf_counter()-st
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict", arg_int_prefix="const int"))
ClangDevice = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram)
ClangDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram)
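# ClangProgram now receives plain ctypes-compatible args (MallocAllocator hands out ctypes arrays),
# so __call__ can just do self.fxn(*args). a minimal end-to-end sketch of that flow, assuming a
# working clang on PATH and generic -shared/-fPIC flags (not necessarily compile_clang's exact flags):
import ctypes, subprocess, tempfile, pathlib
src = "void add1(float* restrict out, const float* restrict in) { for (int i = 0; i < 4; i++) out[i] = in[i] + 1.0f; }"
lib_path = pathlib.Path(tempfile.gettempdir()) / "toy_add1.so"
subprocess.check_output(["clang", "-shared", "-O2", "-fPIC", "-x", "c", "-", "-o", str(lib_path)], input=src.encode())
fxn = ctypes.CDLL(str(lib_path))["add1"]
out, inp = (ctypes.c_float * 4)(), (ctypes.c_float * 4)(1, 2, 3, 4)   # like MallocAllocator buffers
fxn(out, inp)
assert list(out) == [2.0, 3.0, 4.0, 5.0]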

View File

@ -1,14 +1,8 @@
import numpy as np
from typing import Callable, Dict, Tuple, Optional
from typing import Callable, Dict, Tuple
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op
from tinygrad.device import Interpreted
from tinygrad.runtime.lib import RawBuffer
class RawNumpyBuffer(RawBuffer):
def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None): super().__init__(size, dtype, buf)
def _copyin(self, x): self.size, self.dtype, self._buf = x.size, dtypes.from_np(x.dtype), x
def toCPU(self): return self._buf if self._buf is not None else np.empty([self.size], self.dtype.np)
from tinygrad.device import Interpreted, Allocator
def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
@ -31,7 +25,7 @@ def einsum_mulacc(einsum, get_strides, expand):
return mulacc
numpy_fxn_for_op: Dict[Op, Callable] = {
BufferOps.LOAD: lambda x: x.toCPU(), BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np), BufferOps.STORE: RawNumpyBuffer.fromCPU,
BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
@ -45,4 +39,9 @@ numpy_fxn_for_op: Dict[Op, Callable] = {
TernaryOps.WHERE: np.where,
}
CPUDevice = Interpreted(RawNumpyBuffer, numpy_fxn_for_op)
class NumpyAllocator(Allocator):
def _alloc(self, size:int, dtype:DType): return np.empty(size, dtype.np)
def copyin(self, dest:np.ndarray, src:memoryview): np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape))
def copyout(self, dest:memoryview, src:np.ndarray): np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src)
CPUDevice = Interpreted(NumpyAllocator(), numpy_fxn_for_op)
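# NumpyAllocator's copyin/copyout above just reinterpret the host memoryview with np.frombuffer and
# copy with np.copyto. the same roundtrip standalone, using a plain numpy dtype instead of a tinygrad DType:
import numpy as np
dest = np.empty(4, np.float32)                                        # what _alloc returns
src = memoryview(np.arange(4, dtype=np.float32))                      # host data as a memoryview
np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape))   # copyin
out = memoryview(bytearray(dest.nbytes))                              # host destination
np.copyto(np.frombuffer(out, dest.dtype).reshape(dest.shape), dest)   # copyout
assert np.frombuffer(out, np.float32)[2] == 2.0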

View File

@ -1,26 +1,17 @@
import subprocess, time, re, hashlib, tempfile
import subprocess, time, hashlib, tempfile
from pathlib import Path
from typing import Optional, Tuple
from typing import Tuple
import numpy as np
from pycuda.compiler import compile as cuda_compile
from tinygrad.helpers import DEBUG, getenv, colored, diskcache
from tinygrad.device import Compiled
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
from tinygrad.helpers import DEBUG, getenv, pretty_ptx, diskcache
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cuda import CUDARenderer
def pretty_ptx(s):
# all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers
s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers
s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
return s
def arch(): return "sm_" + "".join([str(x) for x in pycuda.driver.Context.get_device().compute_capability()])
CUDACPU = getenv("CUDACPU") == 1
if getenv("CUDACPU", 0) == 1:
if CUDACPU:
import ctypes, ctypes.util
lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
@ -44,26 +35,21 @@ if getenv("CUDACPU", 0) == 1:
get_device = lambda: context.device # pylint: disable=unnecessary-lambda # noqa: E731
import pycuda.driver
pycuda.driver.Context = context
RawCUDABuffer = RawMallocBuffer
else:
import pycuda.autoprimaryctx # pylint: disable=unused-import # noqa: F401
import pycuda.driver as cuda # type: ignore
class CUDAAllocator(LRUAllocator):
def __init__(self): super().__init__(self._get_cur_free_space(None))
def _do_alloc(self, size, dtype, device, **kwargs): return cuda.mem_alloc(size * dtype.itemsize) # type: ignore
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
def _get_cur_free_space(self, device): return cuda.mem_get_info()[0] # type: ignore
CUDAAlloc = CUDAAllocator()
class RawCUDABuffer(RawBufferCopyInOut): # type: ignore
def __init__(self, size, dtype): super().__init__(size, dtype, allocator=CUDAAlloc)
def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x.ravel(), stream) # type: ignore
def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf) # type: ignore
def _alloc(self, size, dtype):
if size == 0: return None
return cuda.mem_alloc(size * dtype.itemsize) # type: ignore
def copyin(self, dest, src:memoryview): cuda.memcpy_htod_async(dest, src) # type: ignore
def copyout(self, dest:memoryview, src): cuda.memcpy_dtoh(dest, src) # type: ignore
@diskcache
def compile_cuda(prg) -> bytes: return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])
class CUDAProgram:
def __init__(self, name:str, _prg:bytes):
def __init__(self, name:str, _prg:bytes, bufs:int, vars:int=0):
prg = _prg.decode('utf-8')
if DEBUG >= 5: print(pretty_ptx(prg))
if DEBUG >= 6:
@ -80,11 +66,15 @@ class CUDAProgram:
if wait:
start, end = cuda.Event(), cuda.Event()
start.record()
self.prg(*[x._buf if isinstance(x, RawCUDABuffer) else np.int32(x) if (isinstance(x, int) and not getenv("CUDACPU")) else x for x in args], block=tuple(local_size), grid=tuple(global_size), shared=shared)
self.prg(*[np.int32(x) if (isinstance(x, int) and not CUDACPU) else x for x in args], block=tuple(local_size), grid=tuple(global_size), shared=shared)
if wait:
end.record()
end.synchronize()
return start.time_till(end)*1e-3
CUDADevice = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]),
CUDARenderer, compile_cuda, CUDAProgram, cuda.Context.synchronize)
class CUDADevice(Compiled):
def __init__(self, device:str):
super().__init__(MallocAllocator if CUDACPU else CUDAAllocator(),
LinearizerOptions(supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
CUDARenderer, compile_cuda, CUDAProgram)
def synchronize(self): return cuda.Context.synchronize()

View File

@ -1,54 +1,56 @@
import os, mmap
try: import _posixshmem
except Exception: pass
from typing import Optional
from typing import Callable, Dict, Tuple
from tinygrad.helpers import prod, DType, OSX
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.device import Interpreted
from tinygrad.ops import Op, MovementOps, UnaryOps, BufferOps
from tinygrad.device import Interpreted, Allocator
from tinygrad.ops import Op, MovementOps, UnaryOps
from tinygrad.shape.view import strides_for_shape
MAP_LOCKED, MAP_POPULATE = 0x2000, 0x008000
class UnderlyingDiskBuffer:
def __init__(self, fd, mem): self.fd, self.mem = fd, mem
def __del__(self):
if self.fd: self.fd.close()
class RawDiskBuffer(RawBufferMapped):
def __init__(self, size, dtype:DType, buf=None, device:Optional[str]=None, offset:int=0): # pylint: disable=super-init-not-called
assert device is not None or buf is not None, "disk tensor needs a path or a buf"
if device is not None:
if str(device).startswith("shm:"):
if OSX:
with open(f"/tmp/shm_{device[4:]}", "w+b") as f:
f.truncate(size * dtype.itemsize)
shm = mmap.mmap(f.fileno(), size * dtype.itemsize, flags=mmap.MAP_SHARED)
else:
fd = _posixshmem.shm_open(device[4:], os.O_RDWR, 0o600)
# TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need
shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | MAP_LOCKED | MAP_POPULATE)
shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore # not on OSX
os.close(fd)
buf = UnderlyingDiskBuffer(None, shm)
else:
f = open(device, "a+b")
if os.path.getsize(device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
buf = UnderlyingDiskBuffer(f, mmap.mmap(f.fileno(), size * dtype.itemsize))
# NOTE: we don't call super since disk tensors don't use RAM
self.size, self.dtype, self._buf, self.offset = size, dtype, buf, offset
def cast(self, arg:Tuple[DType, bool]):
return RawDiskBuffer(self.size, arg[0], self._buf, offset=self.offset)
class DiskBuffer:
def __init__(self, ud:UnderlyingDiskBuffer, size:int, dtype:DType, offset=0): self.ud, self.size, self.dtype, self.offset = ud, size, dtype, offset
def __repr__(self): return f"<DiskBuffer size={self.size} dtype={self.dtype} offset={self.offset}>"
def cast(self, arg:Tuple[DType, bool]): return DiskBuffer(self.ud, self.size, arg[0], offset=self.offset)
def as_strided(self, arg):
assert strides_for_shape(arg[0]) == arg[1], "disk tensors don't support strides"
return RawDiskBuffer(prod(arg[0]), self.dtype, self._buf, offset=self.offset+arg[2]*self.dtype.itemsize)
def _buffer(self): return memoryview(self._buf.mem)[self.offset:self.offset+self.size*self.dtype.itemsize]
def readinto(self, buf:memoryview):
if self._buf.fd is not None:
self._buf.fd.seek(self.offset)
self._buf.fd.readinto(buf)
else:
buf.cast('B')[:] = self._buffer()
return DiskBuffer(self.ud, prod(arg[0]), self.dtype, offset=self.offset+arg[2]*self.dtype.itemsize)
def _buf(self) -> memoryview: return memoryview(self.ud.mem).cast("B")[self.offset:self.offset+self.size*self.dtype.itemsize]
disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.LOAD: lambda x: x, BufferOps.STORE: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
DiskDevice = Interpreted(RawDiskBuffer, disk_fxn_for_op)
disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: DiskBuffer.cast, MovementOps.AS_STRIDED: DiskBuffer.as_strided }
MAP_LOCKED, MAP_POPULATE = 0x2000, 0x008000
class DiskAllocator(Allocator):
def __init__(self, device): self.device = device
def _alloc(self, size, dtype):
if str(self.device).startswith("shm:"):
if OSX:
with open(f"/tmp/shm_{self.device[4:]}", "w+b") as f:
f.truncate(size * dtype.itemsize)
shm = mmap.mmap(f.fileno(), size * dtype.itemsize, flags=mmap.MAP_SHARED)
else:
fd = _posixshmem.shm_open(self.device[4:], os.O_RDWR, 0o600)
# TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need
shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | MAP_LOCKED | MAP_POPULATE)
shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore # not on OSX
os.close(fd)
buf = UnderlyingDiskBuffer(None, shm)
else:
f = open(self.device, "a+b")
if os.path.getsize(self.device) < size * dtype.itemsize: os.ftruncate(f.fileno(), size * dtype.itemsize)
buf = UnderlyingDiskBuffer(f, mmap.mmap(f.fileno(), size * dtype.itemsize))
return DiskBuffer(buf, size, dtype)
def copyin(self, dest:DiskBuffer, src:memoryview): dest._buf()[:] = src
def copyout(self, dest:memoryview, src:DiskBuffer):
if src.ud.fd is not None:
src.ud.fd.seek(src.offset)
src.ud.fd.readinto(dest)
else:
dest[:] = src._buf()
class DiskDevice(Interpreted):
def __init__(self, device): super().__init__(DiskAllocator(device[5:]), disk_fxn_for_op)
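# the plain-file branch of DiskAllocator._alloc mmaps the file, and DiskBuffer._buf() exposes an
# offset+size window into that mapping. a standalone sketch of the same mechanics on a temp file
# (the 64-byte size and the 16-byte window are arbitrary; POSIX-style, as elsewhere in this diff):
import mmap, tempfile
f = tempfile.NamedTemporaryFile(delete=False)
f.truncate(64)                              # like the os.ftruncate to the requested size
mem = mmap.mmap(f.fileno(), 64)
view = memoryview(mem).cast("B")[16:32]     # the offset + size window, as _buf() returns
view[:] = b"x" * 16                         # copyin writes straight through the mapping
mem.flush()
assert open(f.name, "rb").read()[16:32] == b"x" * 16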

View File

@ -1,105 +1,50 @@
from __future__ import annotations
import os
os.environ['PYOPENCL_NO_CACHE'] = '1'
import pathlib
import pathlib, functools
import numpy as np
import pyopencl as cl
from typing import Optional, List, Tuple
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport, diskcache
from tinygrad.device import Compiled
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport, diskcache, DType
from tinygrad.device import Compiled, LRUAllocator
from tinygrad.renderer.opencl import OpenCLRenderer
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
from tinygrad.codegen.kernel import LinearizerOptions
OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
# TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait()
ROCM_LLVM_PATH = pathlib.Path("/opt/rocm/llvm/bin")
if DEBUG >= 5:
if DEBUG >= 6:
early_exec = fromimport("extra.helpers", "enable_early_exec")()
class CLAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs):
if isinstance(dtype, ImageDType):
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
buf = cl.Image(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
else:
buf = cl.Buffer(CL.cl_ctxs[int(device)], cl.mem_flags.READ_WRITE, size * dtype.itemsize)
setattr(buf, 'device', int(device)) # device is tracked on the underlying buffer
return buf
class _CL:
def __init__(self):
cl_platforms = cl.get_platforms()
platform_devices: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl_platforms] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl_platforms]) if y]
self.devices = [device for device in platform_devices[getenv('CL_PLATFORM', 0)] if device.name not in getenv('CL_EXCLUDE', "").split(",")]
self.cl_platform = self.devices[0].platform
def post_init(self, device=None):
self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in self.devices] if device is None else [cl.Context(devices=[self.devices[device]])]
if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}")
self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs]
self.cl_allocator = CLAllocator(CL.cl_ctxs[0].devices[0].get_info(cl.device_info.GLOBAL_MEM_SIZE))
def synchronize(self):
for q in self.cl_queue: q.finish()
CL = _CL()
if not getenv("DELAYED_RUNTIME_INIT", False): CL.post_init()
class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device})
def _clear_event(self, _): del self.event
def _copyin(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements=['C', 'A']), is_blocking=False)
self.event.set_callback(cl.command_execution_status.COMPLETE, self._clear_event)
def _copyout(self, x:np.ndarray):
assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
CL.cl_allocator.ensure_has_free_space(self.size*self.dtype.itemsize, self._device)
buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data)
mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False)
with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([evt] if (evt:=getattr(self, "event", None)) else []))
def _transfer(self, x):
if "gfx" in CL.cl_ctxs[x._buf.device].devices[0].name:
cl.enqueue_copy_buffer_p2p_amd(CL.cl_platform, CL.cl_queue[x._buf.device], x._buf, self._buf, x.size * x.dtype.itemsize).wait()
else: raise NotImplementedError("p2p transfer between devices not implemented on non-amd")
@diskcache
def compile_gpu(prg:str) -> bytes:
clprg = cl.Program(CL.cl_ctxs[0], prg)
clprg = cl.Program(GPUDevice.compile_context, prg)
clprg.build()
return clprg.get_info(cl.program_info.BINARIES)[0]
class CLProgram:
def __init__(self, name:str, prg:bytes, argdtypes=None, options=None):
self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) for ctx in CL.cl_ctxs]
self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms]
self.clprgs = [clprg.__getattr__(name) for clprg in self._clprgs]
def __init__(self, device:GPUDevice, name:str, prg:bytes, bufs:int=0, vars:int=0):
self.device, self.name, self.clprogram = device, name, cl.Program(device.ctx, [device.ctx.devices[0]], [prg])
self.clprogram.build()
self.clprg = self.clprogram.__getattr__(name)
if DEBUG >= 5 and not OSX:
if 'Adreno' in CL.cl_ctxs[0].devices[0].name:
device_name = self.device.ctx.devices[0].name
if 'Adreno' in device_name:
fromimport('disassemblers.adreno', 'disasm')(prg)
elif CL.cl_ctxs[0].devices[0].name.startswith('gfx'):
elif device_name.startswith('gfx'):
asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], prg))
print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
elif "NVIDIA" in CL.cl_ctxs[0].devices[0].name:
elif "NVIDIA" in device_name:
# print the PTX for NVIDIA.
print(prg.decode('utf-8'))
if argdtypes is not None: self.set_argdtypes(argdtypes)
def set_argdtypes(self, argdtypes): self.argdtypes, _ = argdtypes, [clprg.set_scalar_arg_dtypes(argdtypes) for clprg in self.clprgs]
if vars > 0: self.clprg.set_scalar_arg_dtypes([None]*bufs + [np.int32]*vars)
@staticmethod
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
def max_work_group_size(): return GPUDevice.compile_context.devices[0].max_work_group_size if GPUDevice.compile_context is not None else 1024
def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Optional[Tuple[int,int,int]]=None, wait=False) -> Optional[float]:
if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if x.__class__ is CLBuffer else np.int32 for x in bufs))
cl_bufs, wait_for = [], []
for x in bufs:
if x.__class__ is CLBuffer:
cl_bufs.append(x._buf)
if (event:=getattr(x, "event",None)): wait_for.append(event)
else: cl_bufs.append(x)
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [int(g*l) for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=wait_for)
e = self.clprg(self.device.queue, [int(g*l) for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *bufs)
if wait:
e.wait()
try:
@ -108,4 +53,38 @@ class CLProgram:
return None
return None
GPUDevice = Compiled(CLBuffer, LinearizerOptions(), OpenCLRenderer, compile_gpu, CLProgram, CL.synchronize)
class CLAllocator(LRUAllocator):
def __init__(self, device:GPUDevice):
self.events: List[cl.Event] = []
self.device = device
super().__init__()
def _alloc(self, size:int, dtype:DType):
if size == 0: return None
if isinstance(dtype, ImageDType):
# NOTE: the memory is a bit off here due to padding, it's buf.row_pitch * buf.height * 4 * dtype.itemsize
assert size == prod(dtype.shape), f"image size mismatch {size} != {dtype.shape}"
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
buf = cl.Image(self.device.ctx, cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
else:
buf = cl.Buffer(self.device.ctx, cl.mem_flags.READ_WRITE, size * dtype.itemsize)
return buf
def copyin(self, dest:cl.Buffer, src:memoryview): self.events.append(cl.enqueue_copy(self.device.queue, dest, src, is_blocking=False))
def copyout(self, dest:memoryview, src:cl.Buffer):
self.events.clear()
cl.enqueue_copy(self.device.queue, dest, src, is_blocking=True)
class GPUDevice(Compiled):
devices = None
compile_context = None
def __init__(self, device:str):
if GPUDevice.devices is None:
cl_platforms = cl.get_platforms()
platform_devices: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl_platforms] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl_platforms]) if y]
GPUDevice.devices = [device for device in platform_devices[getenv('CL_PLATFORM', 0)] if device.name not in getenv('CL_EXCLUDE', "").split(",")]
if DEBUG >= 1: print(f"using devices: {[device.hashable_model_and_version_identifier for device in GPUDevice.devices]}")
self.device = int(device.split(":")[1]) if ":" in device else 0
self.ctx = cl.Context(devices=[GPUDevice.devices[self.device]])
if GPUDevice.compile_context is None: GPUDevice.compile_context = self.ctx
self.queue = cl.CommandQueue(self.ctx, device=self.ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE)
super().__init__(CLAllocator(self), LinearizerOptions(), OpenCLRenderer, compile_gpu, functools.partial(CLProgram, self))
def synchronize(self): self.queue.finish()

View File

@ -1,14 +1,10 @@
import numpy as np
import ctypes
import ctypes, functools
import extra.hip_wrapper as hip
from typing import Tuple, List, Any, Dict, cast, Optional, Callable
from tinygrad.helpers import DEBUG, getenv, diskcache
from tinygrad.device import Compiled, CompiledASTRunner, update_stats
from typing import Tuple, cast, Callable, TypeVar
from tinygrad.helpers import DEBUG, DType, getenv, diskcache, from_mv
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator
from tinygrad.renderer.hip import HIPRenderer
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer, RawBuffer, RawMallocBuffer
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException
# TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait()
if DEBUG >= 6:
@ -16,41 +12,12 @@ if DEBUG >= 6:
early_exec = enable_early_exec()
# The default HIP stream is used for everything.
class HIPAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs):
hip.hipSetDevice(device)
return hip.hipMalloc(size * dtype.itemsize)
def _do_free(self, buf): hip.hipFree(buf)
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
MOCKHIP = getenv("MOCKHIP") # for CI. don't run kernels, only check if they compile
class _HIP:
def __init__(self, device=None):
self.default_device = device or getenv("HIP_DEFAULT_DEVICE")
self.device_count = 0 if MOCKHIP else hip.hipGetDeviceCount()
if not MOCKHIP: hip.hipSetDevice(self.default_device)
self.allocator = None if MOCKHIP else HIPAllocator(hip.hipGetDeviceProperties(self.default_device).totalGlobalMem)
HIP = _HIP()
class RawHIPBuffer(RawBufferCopyInOut, RawBufferTransfer):
def __init__(self, size, dtype, device=HIP.default_device, buf=None, allocator=HIP.allocator): super().__init__(size, dtype, buf=buf, allocator=allocator, **{'device': int(device)})
def _copyin(self, x:np.ndarray):
hip.hipSetDevice(self._device)
hip.hipMemcpyAsync(self._buf, np.require(x, requirements='C').ctypes.data_as(ctypes.c_void_p), self.size * self.dtype.itemsize, hip.hipMemcpyHostToDevice, 0)
def _copyout(self, x:np.ndarray):
hip.hipSetDevice(self._device)
hip.hipMemcpy(x.ctypes.data, self._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToHost)
def _transfer(self, x:RawBuffer):
hip.hipSetDevice(x._device)
hip.hipMemcpy(self._buf, x._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToDevice)
@diskcache
def compile_hip(prg) -> bytes:
prog = hip.hiprtcCreateProgram(prg, "<null>", [], [])
arch = "gfx1100" if MOCKHIP else hip.hipGetDeviceProperties(HIP.default_device).gcnArchName
hip.hiprtcCompileProgram(prog, [f'--offload-arch={arch}'])
hip.hiprtcCompileProgram(prog, [f'--offload-arch={HIPDevice.default_arch_name}'])
return hip.hiprtcGetCode(prog)
def time_execution(cb, enable=False):
@ -67,78 +34,53 @@ def time_execution(cb, enable=False):
return ret
class HIPProgram:
def __init__(self, name:str, prg:bytes):
self.modules, self.prgs, self.c_struct_t = [], [], None
def __init__(self, device:int, name:str, prg:bytes, bufs:int, vars:int=0):
self.device, self.c_struct_t = device, None
if DEBUG >= 6:
asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg))
print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
for i in range(HIP.device_count):
hip.hipSetDevice(i)
self.modules.append(hip.hipModuleLoadData(prg))
self.prgs.append(hip.hipModuleGetFunction(self.modules[-1], name))
if MOCKHIP: return
hip.hipSetDevice(self.device)
self.module = hip.hipModuleLoadData(prg)
self.prg = hip.hipModuleGetFunction(self.module, name)
self.c_struct_t = hip.getCStructForType([ctypes.c_void_p]*bufs + [ctypes.c_int]*vars)
def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
if MOCKHIP: return
hip.hipSetDevice(args[0]._device)
if self.c_struct_t is None: self.c_struct_t = hip.getCStructForType([(ctypes.c_void_p if not isinstance(x, int) else ctypes.c_int) for x in args])
c_params = cast(Callable, self.c_struct_t)(*[x._buf if not isinstance(x, int) else x for x in args])
return time_execution(lambda: hip.hipModuleLaunchKernel(self.prgs[args[0]._device], *global_size, *local_size, 0, 0, c_params), enable=wait)
hip.hipSetDevice(self.device)
c_params = cast(Callable, self.c_struct_t)(*args)
return time_execution(lambda: hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, 0, c_params), enable=wait)
def __del__(self):
for module in self.modules: hip.hipModuleUnload(module)
if MOCKHIP: return
hip.hipModuleUnload(self.module)
class HIPGraph:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
# TODO: Only HIPProgram can be captured for now.
if not all(isinstance(ji.prg, CompiledASTRunner) and isinstance(ji.prg.clprg, HIPProgram) for ji in jit_cache): raise GraphException
T = TypeVar("T")
class HIPAllocator(LRUAllocator):
def __init__(self, device):
self.device = device
super().__init__()
def _alloc(self, size: int, dtype: DType):
if size == 0: return None
hip.hipSetDevice(self.device)
return hip.hipMalloc(size * dtype.itemsize)
def _free(self, opaque:T): hip.hipFree(opaque)
def copyin(self, dest:T, src: memoryview):
hip.hipSetDevice(self.device)
hip.hipMemcpyAsync(dest, from_mv(src), len(src), hip.hipMemcpyHostToDevice, 0)
def copyout(self, dest:memoryview, src:T):
hip.hipSetDevice(self.device)
hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost)
def transfer(self, dest:T, src:T, sz:int):
hip.hipSetDevice(self.device)
hip.hipMemcpy(dest, src, sz, hip.hipMemcpyDeviceToDevice)
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
self.graph, graph_node = hip.hipGraphCreate(), None
self.updatable_nodes: Dict[int, Tuple[Any, hip.kernelNodeParamsWrapper]] = {} # Dict[jc index] = tuple(graph_node, node_params)
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
for j,ji in enumerate(self.jit_cache):
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
assert all(x is not None for x in ji.rawbufs) and ji.rawbufs[0] is not None, "buffers could not be None" # for linters
args = [cast(RawBuffer, x)._buf for x in ji.rawbufs] + [var_vals[x] for x in prg.vars]
types = [ctypes.c_void_p] * len(ji.rawbufs) + [ctypes.c_int] * len(prg.vars)
c_params = hip.buildKernelNodeParams(args, types, prg.clprg.prgs[ji.rawbufs[0]._device], *prg.launch_dims(var_vals))
graph_node = hip.hipGraphAddKernelNode(self.graph, [graph_node] if graph_node else [], c_params)
if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs:
self.updatable_nodes[j] = (graph_node, c_params)
self.instance = hip.hipGraphInstantiate(self.graph)
def __del__(self):
hip.hipGraphExecDestroy(self.instance)
hip.hipGraphDestroy(self.graph)
def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# Update cached params structs with the new values.
for (j,i),input_idx in self.input_replace.items():
hip.setKernelNodeParams(self.updatable_nodes[j][1], [input_rawbuffers[input_idx]._buf], [i])
for j in self.jc_idxs_with_updatable_launch_dims:
hip.setKernelNodeLaunchDims(self.updatable_nodes[j][1], *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals))
for j in self.jc_idxs_with_updatable_var_vals:
prg: CompiledASTRunner = cast(CompiledASTRunner, self.jit_cache[j].prg)
hip.setKernelNodeParams(self.updatable_nodes[j][1], [var_vals[x] for x in prg.vars], list(range(len(self.jit_cache[j].rawbufs), len(self.jit_cache[j].rawbufs) + len(prg.vars))))
# Update graph nodes with the updated structs.
for node, params in self.updatable_nodes.values():
hip.hipGraphExecKernelNodeSetParams(self.instance, node, params)
et = time_execution(lambda: hip.hipGraphLaunch(self.instance), enable=wait)
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))
return et
HIPDevice = Compiled(RawHIPBuffer if not MOCKHIP else RawMallocBuffer, LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, HIPProgram, hip.hipDeviceSynchronize, graph=HIPGraph)
class HIPDevice(Compiled):
default_arch_name = "gfx1100"
def __init__(self, device:str):
self.device = int(device.split(":")[1]) if ":" in device else 0
if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = hip.hipGetDeviceProperties(self.device).gcnArchName
super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self.device), LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, functools.partial(HIPProgram, self.device))
def synchronize(self): hip.hipDeviceSynchronize()
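The headline change for Compiled devices is the shared LRUAllocator. A minimal sketch of the idea, assuming only the caching behaviour matters here (the names and the bytearray backing are illustrative, not the real HIP or Metal allocator):
from collections import defaultdict

class ToyLRUAllocator:
  def __init__(self):
    self.cache = defaultdict(list)               # (size, itemsize) -> parked buffers
  def _alloc(self, size, itemsize):              # backend hook, e.g. hipMalloc above
    return bytearray(size * itemsize)
  def alloc(self, size, itemsize):
    key = (size, itemsize)
    # reuse a previously freed buffer of the same size before asking the driver
    return self.cache[key].pop() if self.cache[key] else self._alloc(size, itemsize)
  def free(self, opaque, size, itemsize):
    self.cache[(size, itemsize)].append(opaque)  # park it instead of really freeing
Freed buffers come back on the next same-sized request, which is what keeps JIT loops from hammering the driver allocator.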

View File

@@ -1,11 +1,10 @@
import time, ctypes
from typing import ClassVar
from tinygrad.device import Compiled
from tinygrad.device import Compiled, MallocAllocator
from tinygrad.helpers import getenv, DEBUG, diskcache
from ctypes import CFUNCTYPE
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.llvmir import uops_to_llvm_ir
from tinygrad.runtime.lib import RawMallocBuffer
import llvmlite.binding as llvm
@@ -55,14 +54,14 @@ def compile_llvm(prg, llvmopt=LLVMOPT) -> bytes:
return LLVM.target_machine.emit_object(mod)
class LLVMProgram:
def __init__(self, name:str, lib:bytes):
def __init__(self, name:str, lib:bytes, bufs:int, vars:int=0):
LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
self.fxn = LLVM.engine.get_function_address(name)
self.cfunc = CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*bufs), *([ctypes.c_int]*vars))(self.fxn)
def __call__(self, *bufs, wait=False):
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
if wait: st = time.perf_counter()
cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs])
self.cfunc(*bufs)
if wait: return time.perf_counter()-st
LLVMDevice = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, compile_llvm, LLVMProgram)
LLVMDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, compile_llvm, LLVMProgram)
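The LLVM change above moves the ctypes prototype out of __call__ and into __init__, since the argument layout (bufs pointer args plus vars int args) is now known when the program is built. A standalone sketch of that prototype construction; the address in the comment is a placeholder, not a real symbol:
import ctypes

def make_prototype(bufs: int, vars: int = 0):
  # int kernel(void* buf0, ..., int var0, ...)
  return ctypes.CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*bufs), *([ctypes.c_int]*vars))

proto = make_prototype(bufs=2, vars=1)
# with a real function address from the JIT engine this becomes the cached callable:
#   cfunc = proto(addr); cfunc(out_ptr, in_ptr, var_value)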

View File

@@ -1,155 +1,77 @@
import os, subprocess, pathlib, ctypes, tempfile
from __future__ import annotations
import os, subprocess, pathlib, ctypes, tempfile, functools
import Metal, libdispatch
from typing import List, Any, Tuple, Dict, cast, Optional
from typing import List, Any, Tuple
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup
from tinygrad.device import Compiled, CompiledASTRunner, update_stats
from tinygrad.helpers import prod, getenv, DEBUG, DType, diskcache, unwrap2
from tinygrad.device import Compiled, LRUAllocator
from tinygrad.renderer.metal import MetalRenderer
from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, GraphException
class MetalAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs):
buf_len, max_buf_len = size*dtype.itemsize, METAL.device.maxBufferLength()
assert buf_len < max_buf_len, f"Buffer length of {buf_len/1e9:5.2f} GB exceeds Metal's max buffer length of {max_buf_len/1e9:5.2f} GB."
buf = METAL.device.newBufferWithLength_options_(buf_len, Metal.MTLResourceStorageModeShared)
assert buf, f"Metal buffer allocation failed with {buf}."
return buf
def _do_free(self, buf): buf.release()
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
class _METAL:
def __init__(self):
self.mtl_buffers_in_flight: List[Any] = []
self.device = Metal.MTLCreateSystemDefaultDevice()
self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
self.allocator = MetalAllocator(self.device.dedicatedMemorySize() or self.device.sharedMemorySize())
# TODO: is there a better way to do this?
def synchronize(self):
for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
self.mtl_buffers_in_flight.clear()
METAL = _METAL()
class RawMetalBuffer(RawBufferMapped):
def __init__(self, size:int, dtype:DType):
assert dtype != dtypes.double, f"METAL does not support {dtype.name}"
super().__init__(size, dtype, allocator=METAL.allocator)
def _buffer(self):
METAL.synchronize()
return self._buf.contents().as_buffer(self._buf.length())
def unwrap(x):
ret, err = x
assert err is None, str(err)
return ret
@diskcache
def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes:
assert MetalDevice.compiler_device, "metal device creation is required for metal compile"
if use_xcode:
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
options = Metal.MTLCompileOptions.new()
library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
library = unwrap2(MetalDevice.compiler_device.newLibraryWithSource_options_error_(prg, options, None))
return library.libraryDataContents().bytes().tobytes()
class MetalProgram:
def __init__(self, name:str, lib:bytes):
def __init__(self, device:MetalDevice, name:str, lib:bytes, bufs:int, vars:int=0):
self.device = device
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None))
self.fxn = self.library.newFunctionWithName_(name)
if DEBUG >= 6:
with tempfile.NamedTemporaryFile(delete=True) as shader:
shader.write(lib)
shader.flush()
os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
command_buffer = METAL.mtl_queue.commandBuffer()
command_buffer = self.device.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.setComputePipelineState_(self.pipeline_state)
for i,a in enumerate(bufs):
if isinstance(a, RawMetalBuffer): encoder.setBuffer_offset_atIndex_(a._buf, 0, i)
elif isinstance(a, int): encoder.setBytes_length_atIndex_((arg:=ctypes.c_int32(a)), ctypes.sizeof(arg), i)
else: raise RuntimeError(f"arg at index {i} has unsupported type {type(a)}")
if isinstance(a, int): encoder.setBytes_length_atIndex_((arg:=ctypes.c_int32(a)), ctypes.sizeof(arg), i)
else: encoder.setBuffer_offset_atIndex_(a, 0, i)
encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
encoder.endEncoding()
command_buffer.commit()
if wait:
command_buffer.waitUntilCompleted()
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
METAL.mtl_buffers_in_flight.append(command_buffer)
self.device.mtl_buffers_in_flight.append(command_buffer)
class MetalGraph:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]):
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
class MetalAllocator(LRUAllocator):
def __init__(self, device:MetalDevice):
self.device:MetalDevice = device
super().__init__()
def _alloc(self, size:int, dtype:DType):
if size == 0: return None
ret = self.device.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
if ret is None: raise MemoryError(f"Metal OOM while allocating {size=} {dtype=}")
return ret
def _free(self, opaque): opaque.release()
def _buffer(self, src):
self.device.synchronize()
return src.contents().as_buffer(src.length())
def copyin(self, dest, src:memoryview): self._buffer(dest)[:] = src
def copyout(self, dest:memoryview, src): dest[:] = self._buffer(src)
# create metal batch exec
icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch))
icb_descriptor.setInheritBuffers_(False)
icb_descriptor.setInheritPipelineState_(False)
icb_descriptor.setMaxKernelBufferBindCount_(31)
self.icb = METAL.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0))
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
self.int_buf = RawMetalBuffer(len(var_vals), dtypes.int32)
read_resources, write_resources = [], []
for j,ji in enumerate(self.jit_cache):
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
descriptor = Metal.MTLComputePipelineDescriptor.new()
descriptor.setComputeFunction_(prg.clprg.fxn)
descriptor.setSupportIndirectCommandBuffers_(True)
pipeline_state = unwrap(METAL.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None))
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
icb_command.setComputePipelineState_(pipeline_state)
for i,b in enumerate(ji.rawbufs):
if b is not None:
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
if i == 0: write_resources.append(b._buf)
else: read_resources.append(b._buf)
var_vals_keys = list(var_vals.keys())
for i,v in enumerate(prg.vars):
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
if j not in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = prg.launch_dims(var_vals)
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
icb_command.setBarrier()
self.read_resources, self.write_resources = dedup(read_resources), dedup(write_resources)
self.command_buffer: Any = None
self.int_buf_view = self.int_buf.toCPU() # TODO: this is metal syncing when it doesn't need to
def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# NOTE: you at least can't update the ints if this is running
if self.command_buffer is not None and self.command_buffer in METAL.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers]
for (j,i),input_idx in self.input_replace.items():
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
for j in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
self.int_buf_view[:] = list(var_vals.values())
command_buffer = METAL.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
encoder.useResources_count_usage_(all_read_resources, len(all_read_resources), Metal.MTLResourceUsageRead)
encoder.useResources_count_usage_(self.write_resources, len(self.write_resources), Metal.MTLResourceUsageWrite)
encoder.endEncoding()
command_buffer.commit()
self.command_buffer = command_buffer
if wait:
command_buffer.waitUntilCompleted()
et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
else:
METAL.mtl_buffers_in_flight.append(command_buffer)
et = None
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))
return et
MetalDevice = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, graph=MetalGraph)
class MetalDevice(Compiled):
compiler_device = None
def __init__(self, device:str):
self.device = Metal.MTLCreateSystemDefaultDevice()
if MetalDevice.compiler_device is None: MetalDevice.compiler_device = self.device
self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
self.mtl_buffers_in_flight: List[Any] = []
from tinygrad.runtime.graph.metal import MetalGraph
super().__init__(MetalAllocator(self), LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
def synchronize(self):
for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
self.mtl_buffers_in_flight.clear()
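A minimal round trip through a shared-storage MTLBuffer, the same mechanism MetalAllocator's copyin/copyout use above (macOS with the pyobjc Metal bindings assumed; synchronization is omitted because nothing is in flight):
import Metal

dev = Metal.MTLCreateSystemDefaultDevice()
buf = dev.newBufferWithLength_options_(16, Metal.MTLResourceStorageModeShared)
view = buf.contents().as_buffer(buf.length())   # host-visible view of the MTLBuffer
view[:4] = b"\x01\x02\x03\x04"                  # what copyin does
assert bytes(view[:4]) == b"\x01\x02\x03\x04"   # what copyout reads back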

View File

@@ -1,24 +1,15 @@
import torch
import numpy as np
from typing import Dict, Callable, Optional
from typing import Dict, Callable
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, ReduceOps, Op
from tinygrad.device import Interpreted
from tinygrad.helpers import getenv, dtypes, prod, DType
from tinygrad.device import Interpreted, Allocator
from tinygrad.helpers import getenv, dtypes, DType
from tinygrad.runtime.ops_cpu import einsum_mulacc, shape_to_axis
from tinygrad.runtime.lib import RawBuffer
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
inverse_type_map = {v:k for k,v in type_map.items()}
class RawTorchBuffer(RawBuffer):
def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None): super().__init__(size, dtype, buf)
def _copyin(self, x):
buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
self.size, self.dtype, self._buf = prod(x.shape), type_map[buf.dtype], buf
def _get_buf(self): return self._buf if self._buf is not None else torch.empty([self.size], device=device, dtype=inverse_type_map[self.dtype])
def toCPU(self): return self._get_buf().cpu().numpy()
def output_type(x, y): return x.dtype if type_map[x.dtype].priority > type_map[y.dtype].priority else y.dtype
def match_types(x, y, disallow_bool=False):
up = output_type(x, y)
@@ -35,7 +26,6 @@ torch_fxn_for_op: Dict[Op, Callable] = {
# TODO: torch.tensor should work here. it doesn't due to "overflow" in uint8
#BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).to(device),
BufferOps.LOAD: lambda x: x._get_buf(), BufferOps.STORE: lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x),
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)),
@@ -51,4 +41,9 @@ torch_fxn_for_op: Dict[Op, Callable] = {
TernaryOps.WHERE: lambda x, y, z: torch.where(x != 0, y, z),
}
TorchDevice = Interpreted(RawTorchBuffer, torch_fxn_for_op)
class TorchAllocator(Allocator):
def _alloc(self, size:int, dtype:DType): return torch.empty([size], device=device, dtype=inverse_type_map[dtype])
def copyin(self, dest:torch.Tensor, src:memoryview): dest.copy_(torch.frombuffer(src, dtype=dest.dtype))
def copyout(self, dest:memoryview, src:torch.Tensor): torch.frombuffer(dest, dtype=src.dtype).copy_(src.flatten())
TorchDevice = Interpreted(TorchAllocator(), torch_fxn_for_op)
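The TorchAllocator copy path above can be exercised on its own with CPU torch: host memory is viewed as a tensor with torch.frombuffer and moved with copy_, no numpy round trip needed. Buffer sizes here are arbitrary and the snippet is only a sketch of the same copy direction:
import torch

dest = torch.empty([4], dtype=torch.float32)               # the "_alloc" result
src = memoryview(bytearray(4 * 4))                         # 16 bytes of host memory
torch.frombuffer(src, dtype=torch.float32).fill_(3.0)      # stage some host data in place
dest.copy_(torch.frombuffer(src, dtype=torch.float32))     # copyin
out = memoryview(bytearray(4 * 4))
torch.frombuffer(out, dtype=torch.float32).copy_(dest.flatten())  # copyout
assert bytes(out) == bytes(src)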

View File

@@ -1,9 +1,7 @@
import numpy as np
import functools
from wgpu.utils.device import get_default_device
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, DType
from tinygrad.device import Compiled
from tinygrad.device import Compiled, Allocator
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle
from tinygrad.renderer.wgsl import WGSLLanguage
@@ -12,11 +10,11 @@ import wgpu
wgpu_device = get_default_device()
class WebGPUProgram:
def __init__(self, name: str, prg: str): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg)
def __init__(self, name: str, prg: str, bufs:int=0, vars:int=0): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg)
def __call__(self, *bufs, global_size, local_size, wait=False):
assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]
bindings = [{"binding": i, "resource": {"buffer": x._buf, "offset": 0, "size": x._buf.size}} for i, x in enumerate(bufs)]
bindings = [{"binding": i, "resource": {"buffer": x, "offset": 0, "size": x.size}} for i, x in enumerate(bufs)]
bind_group_layout = wgpu_device.create_bind_group_layout(entries=binding_layouts)
pipeline_layout = wgpu_device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
bind_group = wgpu_device.create_bind_group(layout=bind_group_layout, entries=bindings)
@@ -29,17 +27,14 @@ class WebGPUProgram:
compute_pass.end()
wgpu_device.queue.submit([command_encoder.finish()])
class RawWebGPUAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return wgpu_device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
WebGPUAlloc = RawWebGPUAllocator(wgpu_device.limits['max_buffer_size'])
class RawWebGPUBuffer(RawBuffer):
def __init__(self, size:int, dtype:DType):
class WebGpuAllocator(Allocator):
def _alloc(self, size: int, dtype: DType):
assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64,dtypes.double], f"dtype {dtype} not supported on WEBGPU"
super().__init__(size, dtype, allocator=WebGPUAlloc)
def _copyin(self, x:np.ndarray): wgpu_device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x))
def toCPU(self) -> np.ndarray: return np.frombuffer(wgpu_device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
return wgpu_device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
def copyin(self, dest, src: memoryview): wgpu_device.queue.write_buffer(dest, 0, src)
def copyout(self, dest, src: memoryview): dest[:] = wgpu_device.queue.read_buffer(src, 0) # TODO: remove this copy
renderer = functools.partial(uops_to_cstyle, WGSLLanguage())
WebGpuDevice = Compiled(RawWebGPUBuffer, LinearizerOptions(device="WEBGPU", supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, lambda x: x, WebGPUProgram)
class WebGpuDevice(Compiled):
def __init__(self, device:str):
super().__init__(WebGpuAllocator(), LinearizerOptions(device="WEBGPU", supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]),
functools.partial(uops_to_cstyle, WGSLLanguage()), lambda x: x, WebGPUProgram)
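WebGpuAllocator above is a thin wrapper over wgpu-py's queue copies; assuming a working WebGPU adapter is available, the round trip it performs looks roughly like this:
import wgpu
from wgpu.utils.device import get_default_device

dev = get_default_device()
buf = dev.create_buffer(size=16, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
dev.queue.write_buffer(buf, 0, memoryview(b"\x07" * 16))   # copyin
readback = dev.queue.read_buffer(buf, 0)                   # copyout returns a memoryview
assert bytes(readback) == b"\x07" * 16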

View File

@@ -106,10 +106,10 @@ class Tensor:
return self
def assign(self, x) -> Tensor:
# TODO: this is a hack for writing to DISK
# TODO: this is a hack for writing to DISK. remove with working assign
if self.device.startswith("DISK"):
if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
self.contiguous().realize().lazydata.realized._copyin(x.numpy())
self.contiguous().realize().lazydata.realized.copyin(x.numpy().data)
return self
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
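The DISK assign path now hands x.numpy().data to copyin, i.e. a plain C-contiguous memoryview over the host array rather than a RawBuffer. A quick illustration of what that object is:
import numpy as np

x = np.arange(4, dtype=np.float32)
mv = x.data                                   # the memoryview Buffer.copyin receives
assert isinstance(mv, memoryview) and mv.nbytes == 16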