diff --git a/disassemblers/adreno/__init__.py b/disassemblers/adreno/__init__.py index ac16b3811..34741d7fc 100644 --- a/disassemblers/adreno/__init__.py +++ b/disassemblers/adreno/__init__.py @@ -4,15 +4,19 @@ import pathlib from hexdump import hexdump fxn = None + + def disasm(buf): - global fxn - if fxn is None: - shared = pathlib.Path(__file__).parent / "disasm.so" - if not shared.is_file(): - os.system(f'cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so') - fxn = ctypes.CDLL(shared.as_posix())['disasm'] - #hexdump(buf) - END = b"\x00\x00\x00\x00\x00\x00\x00\x03" - buf = buf[0x510:] # this right? - buf = buf.split(END)[0] + END - fxn(buf, len(buf)) + global fxn + if fxn is None: + shared = pathlib.Path(__file__).parent / "disasm.so" + if not shared.is_file(): + os.system( + f"cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so" + ) + fxn = ctypes.CDLL(shared.as_posix())["disasm"] + # hexdump(buf) + END = b"\x00\x00\x00\x00\x00\x00\x00\x03" + buf = buf[0x510:] # this right? + buf = buf.split(END)[0] + END + fxn(buf, len(buf)) diff --git a/docs/abstractions.py b/docs/abstractions.py index 6971389de..623478151 100644 --- a/docs/abstractions.py +++ b/docs/abstractions.py @@ -23,88 +23,139 @@ from abc import ABC # we will be using the clang backend from tinygrad import Device + Device.DEFAULT = "CLANG" # first, 2+3 as a Tensor, the highest level from tinygrad.tensor import Tensor + a = Tensor([2]) b = Tensor([3]) result = a + b print(f"{a.numpy()} + {b.numpy()} = {result.numpy()}") -assert result.numpy()[0] == 5. +assert result.numpy()[0] == 5.0 # %% # == Tensor (in tinygrad/tensor.py, code 8/10) == # it's worth reading tinygrad/tensor.py. it's pretty beautiful import tinygrad.mlops as mlops + # this is the good old familiar Tensor class class Tensor: - # these two are pretty straightforward - grad: Optional[Tensor] - requires_grad: Optional[bool] + # these two are pretty straightforward + grad: Optional[Tensor] + requires_grad: Optional[bool] - # this is the graph for the autograd engine - _ctx: Optional[Function] + # this is the graph for the autograd engine + _ctx: Optional[Function] - # this is where the data (and other tensor properties) actually live - lazydata: LazyBuffer + # this is where the data (and other tensor properties) actually live + lazydata: LazyBuffer - # high level ops (hlops) are defined on this class. example: relu - def relu(self): return self.maximum(0) + # high level ops (hlops) are defined on this class. example: relu + def relu(self): + return self.maximum(0) + + # log is an mlop, this is the wrapper function in Tensor + def log(self): + return mlops.Log.apply(self) - # log is an mlop, this is the wrapper function in Tensor - def log(self): return mlops.Log.apply(self) # all the definitions of the derivatives are subclasses of Function (like mlops.Log) # there's only 18 mlops for derivatives for everything (in tinygrad/mlops.py, code 9/10) # if you read one file, read mlops.py. if you read two files, also read tinygrad/tensor.py # you can differentiate the world using the chain rule class Function: - # example types of forward and backward - def forward(self, x:LazyBuffer) -> LazyBuffer: pass - def backward(self, x:LazyBuffer) -> LazyBuffer: pass + # example types of forward and backward + def forward(self, x: LazyBuffer) -> LazyBuffer: + pass + + def backward(self, x: LazyBuffer) -> LazyBuffer: + pass + # %% # == LazyBuffer (in tinygrad/lazy.py, code 5/10) == from tinygrad.helpers import DType + # this is where the properties live that you thought were a part of Tensor # LazyBuffer is like a Tensor without derivatives, at the mlop layer class LazyBuffer: - # these three define the "type" of the buffer, and they are returned as Tensor properties - device: str - shape: Tuple[int, ...] - dtype: DType + # these three define the "type" of the buffer, and they are returned as Tensor properties + device: str + shape: Tuple[int, ...] + dtype: DType - # a ShapeTracker is used to track things like reshapes and permutes - # all MovementOps are zero copy in tinygrad! - # the ShapeTracker specifies how the data in the RawBuffer matches to the shape - # we'll come back to this later - st: ShapeTracker + # a ShapeTracker is used to track things like reshapes and permutes + # all MovementOps are zero copy in tinygrad! + # the ShapeTracker specifies how the data in the RawBuffer matches to the shape + # we'll come back to this later + st: ShapeTracker - # if the LazyBuffer is realized, it has a Buffer - # we will come back to Buffer later - realized: Optional[Buffer] + # if the LazyBuffer is realized, it has a Buffer + # we will come back to Buffer later + realized: Optional[Buffer] + + # if the lazybuffer is unrealized, it has a LazyOp + # this LazyOp describes the computation needed to realize this LazyBuffer + op: Optional[LazyOp] - # if the lazybuffer is unrealized, it has a LazyOp - # this LazyOp describes the computation needed to realize this LazyBuffer - op: Optional[LazyOp] # LazyOp (in tinygrad/ops.py, code 4/10) # in a tree they form an Abstract Syntax Tree for a single GPU kernel class LazyOp: - op: Op # the type of the compute - src: Tuple[Union[LazyOp, LazyBuffer], ...] # the sources - arg: Optional[Any] = None # and an optional static argument + op: Op # the type of the compute + src: Tuple[Union[LazyOp, LazyBuffer], ...] # the sources + arg: Optional[Any] = None # and an optional static argument + # there's currently 26 Ops you have to implement for an accelerator. -class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto() -class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPLT = auto(); MAX = auto() -class ReduceOps(Enum): SUM = auto(); MAX = auto() -class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() -class TernaryOps(Enum): MULACC = auto(); WHERE = auto() -class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() +class UnaryOps(Enum): + EXP2 = auto() + LOG2 = auto() + CAST = auto() + SIN = auto() + SQRT = auto() + + +class BinaryOps(Enum): + ADD = auto() + SUB = auto() + MUL = auto() + DIV = auto() + CMPLT = auto() + MAX = auto() + + +class ReduceOps(Enum): + SUM = auto() + MAX = auto() + + +class MovementOps(Enum): + RESHAPE = auto() + PERMUTE = auto() + EXPAND = auto() + PAD = auto() + SHRINK = auto() + STRIDE = auto() + + +class TernaryOps(Enum): + MULACC = auto() + WHERE = auto() + + +class LoadOps(Enum): + EMPTY = auto() + CONST = auto() + FROM = auto() + CONTIGUOUS = auto() + CUSTOM = auto() + + # NOTE: if you have a CompiledBuffer(DeviceBuffer) # you do not need to implement the MovementOps # as they are handled by the ShapeTracker(in tinygrad/shape/shapetracker.py, code 7/10) @@ -135,14 +186,16 @@ assert len(lazyop.src) == 2 # again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first assert lazyop.src[0].op.op == LoadOps.FROM assert lazyop.src[0].op.src[0].device == "CPU" -assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]" +assert ( + lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2 +), "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]" assert result.lazydata.realized is None, "the LazyBuffer is not realized yet" # now we realize the LazyBuffer result.realize() assert result.lazydata.realized is not None, "the LazyBuffer is realized!" # this brings us nicely to DeviceBuffer, of which the realized ClangBuffer is a subclass -#assert 'RawMallocBuffer' in str(type(result.lazydata.realized)) +# assert 'RawMallocBuffer' in str(type(result.lazydata.realized)) # getting ahead of ourselves, but we can copy the DeviceBuffer toCPU assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU, it's 5" @@ -151,41 +204,58 @@ assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU, # Now you have a choice, you can either write a "Interpreted" backend or "Compiled" backend + # Interpreted backends are very simple (example: CPU and TORCH) class Interpreted: - # and they have a lookup table to functions for the Ops - fxn_for_op: Dict[Op, Callable] = { - UnaryOps.EXP2: lambda x: np.exp2(x), - BinaryOps.ADD: lambda x,y: x+y} + # and they have a lookup table to functions for the Ops + fxn_for_op: Dict[Op, Callable] = { + UnaryOps.EXP2: lambda x: np.exp2(x), + BinaryOps.ADD: lambda x, y: x + y, + } + # Compiled backends take a little more (example: GPU and LLVM) class Compiled: - # a code generator, which compiles the AST - codegen: Type[Linearizer] + # a code generator, which compiles the AST + codegen: Type[Linearizer] + + # and a runtime, which runs the generated code + runtime: Type[Runtime] - # and a runtime, which runs the generated code - runtime: Type[Runtime] # Runtime is what actually runs the kernels for a compiled backend class Runtime(ABC): - # `name` is the name of the function, and `prg` is the code - # the constructor compiles the code - def __init__(self, name:str, prg:str): pass - # call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention - def __call__(self, *bufs:List[Buffer], global_size:Optional[List[int]], local_size:Optional[List[int]]): pass + # `name` is the name of the function, and `prg` is the code + # the constructor compiles the code + def __init__(self, name: str, prg: str): + pass + + # call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention + def __call__( + self, + *bufs: List[Buffer], + global_size: Optional[List[int]], + local_size: Optional[List[int]], + ): + pass + # %% # == Buffer (in tinygrad/device.py, code 6/10) == import numpy as np + # Buffer is where the data is actually held. it's pretty close to just memory class Buffer(ABC): - # create an empty rawbuffer that holds `size` elements of type `dtype` - # `opaque` is an opaque container class - def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None): pass + # create an empty rawbuffer that holds `size` elements of type `dtype` + # `opaque` is an opaque container class + def __init__(self, device: str, size: int, dtype: DType, opaque: Any = None): + pass + + # toCPU converts the RawBuffer to a numpy array with shape (size,) + def toCPU(self) -> np.ndarray: + pass - # toCPU converts the RawBuffer to a numpy array with shape (size,) - def toCPU(self) -> np.ndarray: pass # %% # == Example: 2+3 in raw clang == @@ -205,6 +275,7 @@ from tinygrad.runtime.ops_clang import ClangProgram, compile_clang # then we copy the numpy in to RawMallocBuffers # last, we create an empty output buffer from tinygrad.helpers import dtypes + input_a, input_b = MallocAllocator.alloc(4), MallocAllocator.alloc(4) output = MallocAllocator.alloc(4) @@ -214,12 +285,14 @@ MallocAllocator.copyin(input_a, numpy_a.data.cast("B")) MallocAllocator.copyin(input_b, numpy_b.data.cast("B")) # compile the program, run it, and 2+3 does indeed equal 5 -program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}")) +program = ClangProgram( + "add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}") +) program(output, input_a, input_b) numpy_out = np.empty(1, dtype=np.float32) MallocAllocator.copyout(numpy_out.data.cast("B"), output) assert numpy_out[0] == 5, "it's still 5" -np.testing.assert_allclose(numpy_out, numpy_a+numpy_b) +np.testing.assert_allclose(numpy_out, numpy_a + numpy_b) # %% # == Linearizer (in tinygrad/codegen/linearizer.py, code 4/10) == @@ -229,35 +302,52 @@ np.testing.assert_allclose(numpy_out, numpy_a+numpy_b) # the first step of transforming an AST into code is to "linearize" it, think like toposort on the AST # for that, we use the Linearizer, which turns an AST into a list of (linear) UOps -class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); + +class UOps(Enum): + LOOP = auto() + DEFINE_LOCAL = auto() + LOAD = auto() + ALU = auto() + CONST = auto() + ENDLOOP = auto() + STORE = auto() + class UOp: - uop: UOps - dtype: Optional[DType] - vin: Tuple[UOp, ...] - arg: Any - num: int # UOps are unique + uop: UOps + dtype: Optional[DType] + vin: Tuple[UOp, ...] + arg: Any + num: int # UOps are unique + class Linearizer: - # create the kernel with the AST - # NOTE: the AST contains the CompiledBuffers themselves as the root nodes. this will change - def __init__(self, ast:LazyOp): pass - def linearize(self): pass + # create the kernel with the AST + # NOTE: the AST contains the CompiledBuffers themselves as the root nodes. this will change + def __init__(self, ast: LazyOp): + pass + + def linearize(self): + pass + + # when linearize is run, it fills in this list + uops: List[UOp] - # when linearize is run, it fills in this list - uops: List[UOp] from tinygrad.tensor import Tensor + result = Tensor(2).realize() + Tensor(3).realize() # use the real Linearizer to linearize 2+3 from tinygrad.codegen.linearizer import Linearizer + sched = result.lazydata.schedule() linearizer = Linearizer(sched[-1].ast) linearizer.linearize() # print the uops -for uop in linearizer.uops: print(uop) +for uop in linearizer.uops: + print(uop) # output: """ @@ -275,13 +365,15 @@ for uop in linearizer.uops: print(uop) # here, we have an example where we fetch the generated code from the JIT from tinygrad.tensor import Tensor + result = Tensor(2) + Tensor(3) # we have a global cache used by the JIT # from there, we can see the generated clang code from tinygrad.jit import CacheCollector -CacheCollector.start() # enables the cache -result.realize() # create the program and runs it + +CacheCollector.start() # enables the cache +result.realize() # create the program and runs it cache_saved = CacheCollector.finish() # disable the cache # there's one ASTRunner in the cache @@ -310,22 +402,24 @@ from tinygrad.shape.shapetracker import ShapeTracker a = ShapeTracker.from_shape((10, 10)) # you'll see it has one view. the (10, 1 are the strides) -print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)]) +print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)]) # we can permute it, and the strides change -a = a.permute((1,0)) -print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)]) +a = a.permute((1, 0)) +print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)]) # we can then reshape it, and the strides change again # note how the permute stays applied -a = a.reshape((5,2,5,2)) -print(a) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)]) +a = a.reshape((5, 2, 5, 2)) +print( + a +) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)]) # now, if we were to reshape it to a (100,) shape tensor, we have to create a second view a = a.reshape((100,)) -print(a) # ShapeTracker(shape=(100,), views=[ - # View((5, 2, 5, 2), (2, 1, 20, 10), 0), - # View((100,), (1,), 0)]) +print(a) # ShapeTracker(shape=(100,), views=[ +# View((5, 2, 5, 2), (2, 1, 20, 10), 0), +# View((100,), (1,), 0)]) # Views stack on top of each other, to allow zero copy for any number of MovementOps # we can render a Python expression for the index at any time @@ -333,22 +427,22 @@ idx, _ = a.expr_idxs() print(idx.render()) # (((idx0%10)*10)+(idx0//10)) # of course, if we reshape it back, the indexes get simple again -a = a.reshape((10,10)) +a = a.reshape((10, 10)) idx, _ = a.expr_idxs() print(idx.render()) # ((idx1*10)+idx0) # the ShapeTracker still has two views though... -print(a) # ShapeTracker(shape=(10, 10), views=[ - # View((5, 2, 5, 2), (2, 1, 20, 10), 0), - # View((10, 10), (10, 1), 0)]) +print(a) # ShapeTracker(shape=(10, 10), views=[ +# View((5, 2, 5, 2), (2, 1, 20, 10), 0), +# View((10, 10), (10, 1), 0)]) # ...until we simplify it! a = a.simplify() -print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)]) +print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)]) # and now we permute it back -a = a.permute((1,0)) -print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)]) +a = a.permute((1, 0)) +print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)]) # and it's even contiguous assert a.contiguous == True @@ -365,17 +459,17 @@ a = Variable("a", 0, 10) b = Variable("b", 0, 10) # some math examples -print((a*10).min, (a*10).max) # you'll see a*10 has a min of 0 and max of 100 -print((a+b).min, (a+b).max) # 0 20, you get the idea +print((a * 10).min, (a * 10).max) # you'll see a*10 has a min of 0 and max of 100 +print((a + b).min, (a + b).max) # 0 20, you get the idea # but complex expressions are where it gets fun -expr = (a + b*10) % 10 -print(expr.render()) # (a%10) +expr = (a + b * 10) % 10 +print(expr.render()) # (a%10) # as you can see, b is gone! # one more -expr = (a*40 + b) // 20 -print(expr.render()) # (a*2) +expr = (a * 40 + b) // 20 +print(expr.render()) # (a*2) print(expr.min, expr.max) # 0 20 # this is just "(a*2)" # since b only has a range from 0-10, it can't affect the output diff --git a/docs/beautiful.py b/docs/beautiful.py index 83eae774e..b94d0b1de 100644 --- a/docs/beautiful.py +++ b/docs/beautiful.py @@ -15,8 +15,8 @@ a = MallocAllocator.alloc(4) b = MallocAllocator.alloc(4) # load in some values (little endian) -MallocAllocator.copyin(a, bytearray([2,0,0,0])) -MallocAllocator.copyin(b, bytearray([3,0,0,0])) +MallocAllocator.copyin(a, bytearray([2, 0, 0, 0])) +MallocAllocator.copyin(b, bytearray([3, 0, 0, 0])) # compile a program to a binary lib = compile_clang("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }") @@ -34,7 +34,7 @@ assert val == 5 print("******** second, the Device ***********") -DEVICE = "CLANG" # NOTE: you can change this! +DEVICE = "CLANG" # NOTE: you can change this! import struct from tinygrad.helpers import dtypes @@ -49,14 +49,21 @@ b = Buffer(DEVICE, 1, dtypes.int32).copyin(memoryview(bytearray(struct.pack("I", # NOTE: a._buf is the same as the return from MallocAllocator.alloc # describe the computation -ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,)))) -ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,)))) +ld_1 = LazyOp( + BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,))) +) +ld_2 = LazyOp( + BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))) +) alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2)) -st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,)))) +st_0 = LazyOp( + BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))) +) # convert the computation to a "linearized" format (print the format) lin = Device[DEVICE].get_linearizer(st_0).linearize() -for u in lin.uops: print(u) +for u in lin.uops: + print(u) # compile a program (and print the source) fxn = Device[DEVICE].to_program(lin) @@ -67,7 +74,7 @@ print(fxn.prg) fxn.exec([out, a, b]) # check the data out -print(val:=out.toCPU().item()) +print(val := out.toCPU().item()) assert val == 5 @@ -79,6 +86,7 @@ from tinygrad.realize import run_schedule # allocate some values + load in values # TODO: remove numpy here import numpy as np + a = LazyBuffer.fromCPU(np.array([2], np.int32)).copy_to_device(DEVICE) b = LazyBuffer.fromCPU(np.array([3], np.int32)).copy_to_device(DEVICE) @@ -87,10 +95,12 @@ out = a.e(BinaryOps.ADD, b) # schedule the computation as a list of kernels sched = out.schedule() -for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG +for si in sched: + print(si.ast.op) # NOTE: the first two convert it to CLANG # DEBUGGING: print the compute ast as a tree from tinygrad.graph import print_tree + print_tree(sched[-1].ast) # NOTE: sched[-1].ast is the same as st_0 above @@ -98,7 +108,7 @@ print_tree(sched[-1].ast) run_schedule(sched) # check the data out -print(val:=out.realized.toCPU().item()) +print(val := out.realized.toCPU().item()) assert val == 5 @@ -111,5 +121,5 @@ b = Tensor([3], dtype=dtypes.int32, device=DEVICE) out = a + b # check the data out -print(val:=out.item()) +print(val := out.item()) assert val == 5 diff --git a/examples/beautiful_cartpole.py b/examples/beautiful_cartpole.py index f20c84880..a6559b699 100644 --- a/examples/beautiful_cartpole.py +++ b/examples/beautiful_cartpole.py @@ -1,114 +1,135 @@ from typing import Tuple import time from tinygrad import Tensor, TinyJit, nn, Variable -from tinygrad.helpers import dtypes # TODO: wouldn't need this if argmax returned the right dtype +from tinygrad.helpers import ( + dtypes, +) # TODO: wouldn't need this if argmax returned the right dtype import gymnasium as gym from tqdm import trange import numpy as np # TODO: remove numpy import + class ActorCritic: - def __init__(self, in_features, out_features, hidden_state=32): - self.l1 = nn.Linear(in_features, hidden_state) - self.l2 = nn.Linear(hidden_state, out_features) + def __init__(self, in_features, out_features, hidden_state=32): + self.l1 = nn.Linear(in_features, hidden_state) + self.l2 = nn.Linear(hidden_state, out_features) - self.c1 = nn.Linear(in_features, hidden_state) - self.c2 = nn.Linear(hidden_state, 1) + self.c1 = nn.Linear(in_features, hidden_state) + self.c2 = nn.Linear(hidden_state, 1) - def __call__(self, obs:Tensor) -> Tuple[Tensor, Tensor]: - x = self.l1(obs).tanh() - act = self.l2(x).log_softmax() - x = self.c1(obs).relu() - return act, self.c2(x) + def __call__(self, obs: Tensor) -> Tuple[Tensor, Tensor]: + x = self.l1(obs).tanh() + act = self.l2(x).log_softmax() + x = self.c1(obs).relu() + return act, self.c2(x) + + +def evaluate(model: ActorCritic, test_env: gym.Env) -> float: + (obs, _), terminated, truncated = test_env.reset(), False, False + total_rew = 0.0 + while not terminated and not truncated: + act = model(Tensor(obs))[0].argmax().cast(dtypes.int32).item() + obs, rew, terminated, truncated, _ = test_env.step(act) + total_rew += float(rew) + return total_rew -def evaluate(model:ActorCritic, test_env:gym.Env) -> float: - (obs, _), terminated, truncated = test_env.reset(), False, False - total_rew = 0.0 - while not terminated and not truncated: - act = model(Tensor(obs))[0].argmax().cast(dtypes.int32).item() - obs, rew, terminated, truncated, _ = test_env.step(act) - total_rew += float(rew) - return total_rew # TODO: time should be < 5s on M1 Max if __name__ == "__main__": - env = gym.make('CartPole-v1') + env = gym.make("CartPole-v1") - model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n)) # type: ignore - opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-2) + model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n)) # type: ignore + opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-2) - @TinyJit - def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]: - with Tensor.train(): - log_dist, value = model(x) + @TinyJit + def train_step( + x: Tensor, selected_action: Tensor, reward: Tensor, old_log_dist: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + with Tensor.train(): + log_dist, value = model(x) - # get advantage - advantage = reward.reshape(-1, 1) - value - mask = selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1) - masked_advantage = mask * advantage.detach() + # get advantage + advantage = reward.reshape(-1, 1) - value + mask = selected_action.reshape(-1, 1) == Tensor.arange( + log_dist.shape[1] + ).reshape(1, -1).expand(selected_action.shape[0], -1) + masked_advantage = mask * advantage.detach() - # PPO - ratios = (log_dist - old_log_dist).exp() * masked_advantage - clipped_ratios = ratios.clip(1-0.2, 1+0.2) * masked_advantage - action_loss = -ratios.minimum(clipped_ratios).sum(-1).mean() + # PPO + ratios = (log_dist - old_log_dist).exp() * masked_advantage + clipped_ratios = ratios.clip(1 - 0.2, 1 + 0.2) * masked_advantage + action_loss = -ratios.minimum(clipped_ratios).sum(-1).mean() - entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean() # this encourages diversity - critic_loss = advantage.square().mean() - opt.zero_grad() - (action_loss + entropy_loss*0.0005 + critic_loss).backward() - opt.step() - return action_loss.realize(), entropy_loss.realize(), critic_loss.realize() + entropy_loss = ( + (log_dist.exp() * log_dist).sum(-1).mean() + ) # this encourages diversity + critic_loss = advantage.square().mean() + opt.zero_grad() + (action_loss + entropy_loss * 0.0005 + critic_loss).backward() + opt.step() + return action_loss.realize(), entropy_loss.realize(), critic_loss.realize() - @TinyJit - def get_action_dist(obs:Tensor) -> Tensor: - # TODO: with no_grad - Tensor.no_grad = True - ret = model(obs)[0].exp().realize() - Tensor.no_grad = False - return ret + @TinyJit + def get_action_dist(obs: Tensor) -> Tensor: + # TODO: with no_grad + Tensor.no_grad = True + ret = model(obs)[0].exp().realize() + Tensor.no_grad = False + return ret - BS = 256 - MAX_REPLAY_BUFFER = 2000 - st, steps = time.perf_counter(), 0 - Xn, An, Rn = [], [], [] - for i in (t:=trange(40)): - get_action_dist.reset() # NOTE: if you don't reset the jit here it captures the wrong model on the first run through + BS = 256 + MAX_REPLAY_BUFFER = 2000 + st, steps = time.perf_counter(), 0 + Xn, An, Rn = [], [], [] + for i in (t := trange(40)): + get_action_dist.reset() # NOTE: if you don't reset the jit here it captures the wrong model on the first run through - obs:np.ndarray = env.reset()[0] - rews, terminated, truncated = [], False, False - # NOTE: we don't want to early stop since then the rewards are wrong for the last episode - while not terminated and not truncated: - # pick actions - # TODO: move the multinomial into jitted tinygrad when JIT rand works - # TODO: what's the temperature here? - act = get_action_dist(Tensor(obs)).multinomial().item() + obs: np.ndarray = env.reset()[0] + rews, terminated, truncated = [], False, False + # NOTE: we don't want to early stop since then the rewards are wrong for the last episode + while not terminated and not truncated: + # pick actions + # TODO: move the multinomial into jitted tinygrad when JIT rand works + # TODO: what's the temperature here? + act = get_action_dist(Tensor(obs)).multinomial().item() - # save this state action pair - # TODO: don't use np.copy here on the CPU, what's the tinygrad way to do this and keep on device? need __setitem__ assignment - Xn.append(np.copy(obs)) - An.append(act) + # save this state action pair + # TODO: don't use np.copy here on the CPU, what's the tinygrad way to do this and keep on device? need __setitem__ assignment + Xn.append(np.copy(obs)) + An.append(act) - obs, rew, terminated, truncated, _ = env.step(act) - rews.append(float(rew)) - steps += len(rews) + obs, rew, terminated, truncated, _ = env.step(act) + rews.append(float(rew)) + steps += len(rews) - # reward to go - # TODO: move this into tinygrad - discounts = np.power(0.99, np.arange(len(rews))) - Rn += [np.sum(rews[i:] * discounts[:len(rews)-i]) for i in range(len(rews))] + # reward to go + # TODO: move this into tinygrad + discounts = np.power(0.99, np.arange(len(rews))) + Rn += [np.sum(rews[i:] * discounts[: len(rews) - i]) for i in range(len(rews))] - Xn, An, Rn = Xn[-MAX_REPLAY_BUFFER:], An[-MAX_REPLAY_BUFFER:], Rn[-MAX_REPLAY_BUFFER:] - X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn) + Xn, An, Rn = ( + Xn[-MAX_REPLAY_BUFFER:], + An[-MAX_REPLAY_BUFFER:], + Rn[-MAX_REPLAY_BUFFER:], + ) + X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn) - # TODO: make this work - #vsz = Variable("sz", 1, MAX_REPLAY_BUFFER-1).bind(len(Xn)) - #X, A, R = Tensor(Xn).reshape(vsz, None), Tensor(An).reshape(vsz), Tensor(Rn).reshape(vsz) + # TODO: make this work + # vsz = Variable("sz", 1, MAX_REPLAY_BUFFER-1).bind(len(Xn)) + # X, A, R = Tensor(Xn).reshape(vsz, None), Tensor(An).reshape(vsz), Tensor(Rn).reshape(vsz) - old_log_dist = model(X)[0] # TODO: could save these instead of recomputing - for i in range(5): - samples = Tensor.randint(BS, high=X.shape[0]).realize() # TODO: remove the need for this - # TODO: is this recompiling based on the shape? - action_loss, entropy_loss, critic_loss = train_step(X[samples], A[samples], R[samples], old_log_dist[samples]) - t.set_description(f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.2f} entropy_loss: {entropy_loss.item():7.2f} critic_loss: {critic_loss.item():7.2f} reward: {sum(rews):6.2f}") + old_log_dist = model(X)[0] # TODO: could save these instead of recomputing + for i in range(5): + samples = Tensor.randint( + BS, high=X.shape[0] + ).realize() # TODO: remove the need for this + # TODO: is this recompiling based on the shape? + action_loss, entropy_loss, critic_loss = train_step( + X[samples], A[samples], R[samples], old_log_dist[samples] + ) + t.set_description( + f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.2f} entropy_loss: {entropy_loss.item():7.2f} critic_loss: {critic_loss.item():7.2f} reward: {sum(rews):6.2f}" + ) - test_rew = evaluate(model, gym.make('CartPole-v1', render_mode='human')) - print(f"test reward: {test_rew}") + test_rew = evaluate(model, gym.make("CartPole-v1", render_mode="human")) + print(f"test reward: {test_rew}") diff --git a/examples/beautiful_mnist.py b/examples/beautiful_mnist.py index b3195093d..152872440 100644 --- a/examples/beautiful_mnist.py +++ b/examples/beautiful_mnist.py @@ -4,42 +4,61 @@ from tinygrad import Tensor, TinyJit, nn, GlobalCounters from extra.datasets import fetch_mnist from tqdm import trange -class Model: - def __init__(self): - self.layers: List[Callable[[Tensor], Tensor]] = [ - nn.Conv2d(1, 32, 5), Tensor.relu, - nn.Conv2d(32, 32, 5), Tensor.relu, - nn.BatchNorm2d(32), Tensor.max_pool2d, - nn.Conv2d(32, 64, 3), Tensor.relu, - nn.Conv2d(64, 64, 3), Tensor.relu, - nn.BatchNorm2d(64), Tensor.max_pool2d, - lambda x: x.flatten(1), nn.Linear(576, 10)] - def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers) +class Model: + def __init__(self): + self.layers: List[Callable[[Tensor], Tensor]] = [ + nn.Conv2d(1, 32, 5), + Tensor.relu, + nn.Conv2d(32, 32, 5), + Tensor.relu, + nn.BatchNorm2d(32), + Tensor.max_pool2d, + nn.Conv2d(32, 64, 3), + Tensor.relu, + nn.Conv2d(64, 64, 3), + Tensor.relu, + nn.BatchNorm2d(64), + Tensor.max_pool2d, + lambda x: x.flatten(1), + nn.Linear(576, 10), + ] + + def __call__(self, x: Tensor) -> Tensor: + return x.sequential(self.layers) + if __name__ == "__main__": - X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True) + X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True) - model = Model() - opt = nn.optim.Adam(nn.state.get_parameters(model)) + model = Model() + opt = nn.optim.Adam(nn.state.get_parameters(model)) - # TODO: there's a compiler error if you comment out TinyJit since randint isn't being realized and there's something weird with int - @TinyJit - def train_step(samples:Tensor) -> Tensor: - with Tensor.train(): - opt.zero_grad() - # TODO: this "gather" of samples is very slow. will be under 5s when this is fixed - loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward() - opt.step() - return loss.realize() + # TODO: there's a compiler error if you comment out TinyJit since randint isn't being realized and there's something weird with int + @TinyJit + def train_step(samples: Tensor) -> Tensor: + with Tensor.train(): + opt.zero_grad() + # TODO: this "gather" of samples is very slow. will be under 5s when this is fixed + loss = ( + model(X_train[samples]) + .sparse_categorical_crossentropy(Y_train[samples]) + .backward() + ) + opt.step() + return loss.realize() - @TinyJit - def get_test_acc() -> Tensor: return ((model(X_test).argmax(axis=1) == Y_test).mean()*100).realize() + @TinyJit + def get_test_acc() -> Tensor: + return ((model(X_test).argmax(axis=1) == Y_test).mean() * 100).realize() - test_acc = float('nan') - for i in (t:=trange(70)): - GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing - samples = Tensor.randint(512, high=X_train.shape[0]) # TODO: put this in the JIT when rand is fixed - loss = train_step(samples) - if i%10 == 9: test_acc = get_test_acc().item() - t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%") + test_acc = float("nan") + for i in (t := trange(70)): + GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing + samples = Tensor.randint( + 512, high=X_train.shape[0] + ) # TODO: put this in the JIT when rand is fixed + loss = train_step(samples) + if i % 10 == 9: + test_acc = get_test_acc().item() + t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%") diff --git a/examples/benchmark_train_efficientnet.py b/examples/benchmark_train_efficientnet.py index a28f013ae..690e6f7a8 100755 --- a/examples/benchmark_train_efficientnet.py +++ b/examples/benchmark_train_efficientnet.py @@ -10,8 +10,10 @@ from tinygrad.helpers import GlobalCounters from tinygrad.helpers import getenv from tinygrad.jit import CacheCollector + def tensors_allocated(): - return sum(isinstance(x, Tensor) for x in gc.get_objects()) + return sum(isinstance(x, Tensor) for x in gc.get_objects()) + NUM = getenv("NUM", 2) BS = getenv("BS", 8) @@ -22,46 +24,53 @@ ADAM = getenv("ADAM", 0) CLCACHE = getenv("CLCACHE", 0) if __name__ == "__main__": - print(f"NUM:{NUM} BS:{BS} CNT:{CNT}") - model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False) - parameters = get_parameters(model) - for p in parameters: p.realize() - if ADAM: optimizer = optim.Adam(parameters, lr=0.001) - else: optimizer = optim.SGD(parameters, lr=0.001) - - Tensor.training = TRAINING - Tensor.no_grad = not BACKWARD - for i in trange(CNT): - GlobalCounters.reset() - cpy = time.monotonic() - x_train = Tensor.randn(BS, 3, 224, 224, requires_grad=False).realize() - y_train = Tensor.randn(BS, 1000, requires_grad=False).realize() - - # TODO: replace with TinyJit - if i < 3 or not CLCACHE: - st = time.monotonic() - out = model.forward(x_train) - loss = out.log_softmax().mul(y_train).mean() - if i == 2 and CLCACHE: CacheCollector.start() - if BACKWARD: - optimizer.zero_grad() - loss.backward() - optimizer.step() - mt = time.monotonic() - loss.realize() - for p in parameters: + print(f"NUM:{NUM} BS:{BS} CNT:{CNT}") + model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False) + parameters = get_parameters(model) + for p in parameters: p.realize() - et = time.monotonic() + if ADAM: + optimizer = optim.Adam(parameters, lr=0.001) else: - st = mt = time.monotonic() - for prg, args in cl_cache: prg(*args) - et = time.monotonic() + optimizer = optim.SGD(parameters, lr=0.001) - if i == 2 and CLCACHE: - cl_cache = CacheCollector.finish() + Tensor.training = TRAINING + Tensor.no_grad = not BACKWARD + for i in trange(CNT): + GlobalCounters.reset() + cpy = time.monotonic() + x_train = Tensor.randn(BS, 3, 224, 224, requires_grad=False).realize() + y_train = Tensor.randn(BS, 1000, requires_grad=False).realize() - mem_used = GlobalCounters.mem_used - loss_cpu = loss.detach().numpy() - cl = time.monotonic() + # TODO: replace with TinyJit + if i < 3 or not CLCACHE: + st = time.monotonic() + out = model.forward(x_train) + loss = out.log_softmax().mul(y_train).mean() + if i == 2 and CLCACHE: + CacheCollector.start() + if BACKWARD: + optimizer.zero_grad() + loss.backward() + optimizer.step() + mt = time.monotonic() + loss.realize() + for p in parameters: + p.realize() + et = time.monotonic() + else: + st = mt = time.monotonic() + for prg, args in cl_cache: + prg(*args) + et = time.monotonic() - print(f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS") + if i == 2 and CLCACHE: + cl_cache = CacheCollector.finish() + + mem_used = GlobalCounters.mem_used + loss_cpu = loss.detach().numpy() + cl = time.monotonic() + + print( + f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS" + ) diff --git a/examples/coder.py b/examples/coder.py index 488237c3e..5962152b9 100644 --- a/examples/coder.py +++ b/examples/coder.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import os, sys, traceback + sys.path.append(os.getcwd()) from io import StringIO @@ -9,99 +10,148 @@ from tinygrad.helpers import Timing, colored, getenv, fetch from extra.models.llama import Transformer, convert_from_huggingface from sentencepiece import SentencePieceProcessor + def create_fixed_tokenizer(output_file): - print("creating fixed tokenizer") - import extra.junk.sentencepiece_model_pb2 as spb2 - mp = spb2.ModelProto() - mp.ParseFromString(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true").read_bytes()) - mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0)) - mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0)) - with open(output_file, "wb") as f: - f.write(mp.SerializeToString()) + print("creating fixed tokenizer") + import extra.junk.sentencepiece_model_pb2 as spb2 + + mp = spb2.ModelProto() + mp.ParseFromString( + fetch( + "https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true" + ).read_bytes() + ) + mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0)) + mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0)) + with open(output_file, "wb") as f: + f.write(mp.SerializeToString()) + # TODO: make loading bf16 fast so we can remove this def create_model_cache(output_file, model): - print(f"creating model cache at {output_file}") - # TODO: add read only Tensors - with Timing("download weights: "): - part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true")) - part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true")) + print(f"creating model cache at {output_file}") + # TODO: add read only Tensors + with Timing("download weights: "): + part1 = nn.state.torch_load( + fetch( + "https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true" + ) + ) + part2 = nn.state.torch_load( + fetch( + "https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true" + ) + ) - with Timing("weights -> model: "): - nn.state.load_state_dict(model, convert_from_huggingface(part1, model, 32, 8), strict=False) - nn.state.load_state_dict(model, convert_from_huggingface(part2, model, 32, 8), strict=False) + with Timing("weights -> model: "): + nn.state.load_state_dict( + model, convert_from_huggingface(part1, model, 32, 8), strict=False + ) + nn.state.load_state_dict( + model, convert_from_huggingface(part2, model, 32, 8), strict=False + ) - with Timing("saving float16 cache: "): - nn.state.safe_save(nn.state.get_state_dict(model), output_file) + with Timing("saving float16 cache: "): + nn.state.safe_save(nn.state.get_state_dict(model), output_file) + + print("cache created, rerun to use") + exit(0) - print("cache created, rerun to use") - exit(0) if __name__ == "__main__": - Tensor.no_grad = True + Tensor.no_grad = True - # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json - with Timing("create model: "): - model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096) + # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json + with Timing("create model: "): + model = Transformer( + 4096, + 14336, + n_heads=32, + n_layers=32, + norm_eps=1e-5, + vocab_size=32002, + n_kv_heads=8, + max_context=4096, + ) - cached_model = "/tmp/cached_openhermes.safetensors" - if not os.path.isfile(cached_model): create_model_cache(cached_model, model) - with Timing("loading float16 cache: "): - nn.state.load_state_dict(model, nn.state.safe_load(cached_model)) + cached_model = "/tmp/cached_openhermes.safetensors" + if not os.path.isfile(cached_model): + create_model_cache(cached_model, model) + with Timing("loading float16 cache: "): + nn.state.load_state_dict(model, nn.state.safe_load(cached_model)) - if not os.path.isfile("/tmp/tokenizer.model"): create_fixed_tokenizer("/tmp/tokenizer.model") - spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model") + if not os.path.isfile("/tmp/tokenizer.model"): + create_fixed_tokenizer("/tmp/tokenizer.model") + spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model") - # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json - # "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", - IM_END = 32000 - IM_START = 32001 - def encode_prompt(k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n") - def start_prompt(k): return [IM_START]+spp.encode(f"{k}\n") - def output(outputted, toks, color): - cur = spp.decode(toks)[len(outputted):] - sys.stdout.write(colored(cur, color)) - sys.stdout.flush() - outputted += cur - return outputted + # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json + # "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + IM_END = 32000 + IM_START = 32001 - # *** app below this line *** + def encode_prompt(k, v): + return [IM_START] + spp.encode(f"{k}\n{v}") + [IM_END] + spp.encode("\n") - toks = [spp.bos_id()] + encode_prompt("system", "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input") + def start_prompt(k): + return [IM_START] + spp.encode(f"{k}\n") - PROMPT = getenv("PROMPT", 1) - temperature = getenv("TEMP", 0.7) + def output(outputted, toks, color): + cur = spp.decode(toks)[len(outputted) :] + sys.stdout.write(colored(cur, color)) + sys.stdout.flush() + outputted += cur + return outputted - start_pos = 0 - outputted = output("", toks, "green") - turn = True - while 1: - if PROMPT: - toks += encode_prompt("user", input("Q: ")) + start_prompt("assistant") - else: - toks += start_prompt("user" if turn else "assistant") - turn = not turn - old_output_len = len(outputted) + # *** app below this line *** + + toks = [spp.bos_id()] + encode_prompt( + "system", + "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input", + ) + + PROMPT = getenv("PROMPT", 1) + temperature = getenv("TEMP", 0.7) + + start_pos = 0 + outputted = output("", toks, "green") + turn = True while 1: - tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).multinomial().item() - start_pos = len(toks) - toks.append(tok) - outputted = output(outputted, toks, "blue" if not turn else "cyan") - if tok == IM_END: break - if tok == spp.eos_id(): break - new_output = outputted[old_output_len:] + if PROMPT: + toks += encode_prompt("user", input("Q: ")) + start_prompt("assistant") + else: + toks += start_prompt("user" if turn else "assistant") + turn = not turn + old_output_len = len(outputted) + while 1: + tok = ( + model(Tensor([toks[start_pos:]]), start_pos, temperature) + .multinomial() + .item() + ) + start_pos = len(toks) + toks.append(tok) + outputted = output(outputted, toks, "blue" if not turn else "cyan") + if tok == IM_END: + break + if tok == spp.eos_id(): + break + new_output = outputted[old_output_len:] - if new_output.endswith("```") and '```python\n' in new_output: - python_code = new_output.split('```python\n')[1].split("```")[0] - # AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things. - if input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower() == 'y': - my_stdout = StringIO() - try: - with redirect_stdout(my_stdout): exec(python_code) - result = my_stdout.getvalue() - except Exception as e: - result = ''.join(traceback.format_exception_only(e)) - toks += spp.encode(f"\nOutput:\n```\n{result}```") - outputted = output(outputted, toks, "yellow") - old_output_len = len(outputted) - print("") \ No newline at end of file + if new_output.endswith("```") and "```python\n" in new_output: + python_code = new_output.split("```python\n")[1].split("```")[0] + # AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things. + if ( + input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower() + == "y" + ): + my_stdout = StringIO() + try: + with redirect_stdout(my_stdout): + exec(python_code) + result = my_stdout.getvalue() + except Exception as e: + result = "".join(traceback.format_exception_only(e)) + toks += spp.encode(f"\nOutput:\n```\n{result}```") + outputted = output(outputted, toks, "yellow") + old_output_len = len(outputted) + print("") diff --git a/examples/compile_efficientnet.py b/examples/compile_efficientnet.py index 16c54ff9d..2f4d93023 100644 --- a/examples/compile_efficientnet.py +++ b/examples/compile_efficientnet.py @@ -7,32 +7,54 @@ from tinygrad.helpers import getenv, fetch import ast if __name__ == "__main__": - model = EfficientNet(0) - model.load_from_pretrained() - mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else "" - prg, inp_sizes, out_sizes, state = export_model(model, mode, Tensor.randn(1,3,224,224)) - dirname = Path(__file__).parent - if getenv("CLANG", "") == "": - safe_save(state, (dirname / "net.safetensors").as_posix()) - ext = "js" if getenv("WEBGPU", "") != "" else "json" - with open(dirname / f"net.{ext}", "w") as text_file: - text_file.write(prg) - else: - cprog = [prg] - # image library! - cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").read_text().replace("half", "_half")] + model = EfficientNet(0) + model.load_from_pretrained() + mode = ( + "clang" + if getenv("CLANG", "") != "" + else "webgpu" + if getenv("WEBGPU", "") != "" + else "" + ) + prg, inp_sizes, out_sizes, state = export_model( + model, mode, Tensor.randn(1, 3, 224, 224) + ) + dirname = Path(__file__).parent + if getenv("CLANG", "") == "": + safe_save(state, (dirname / "net.safetensors").as_posix()) + ext = "js" if getenv("WEBGPU", "") != "" else "json" + with open(dirname / f"net.{ext}", "w") as text_file: + text_file.write(prg) + else: + cprog = [prg] + # image library! + cprog += [ + "#define STB_IMAGE_IMPLEMENTATION", + fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h") + .read_text() + .replace("half", "_half"), + ] - # imagenet labels, move to datasets? - lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text()) - lbls = ['"'+lbls[i]+'"' for i in range(1000)] - inputs = "\n".join([f"float {inp}[{inp_size}];" for inp,inp_size in inp_sizes.items()]) - outputs = "\n".join([f"float {out}[{out_size}];" for out,out_size in out_sizes.items()]) - cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};") - cprog.append(inputs) - cprog.append(outputs) + # imagenet labels, move to datasets? + lbls = ast.literal_eval( + fetch( + "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt" + ).read_text() + ) + lbls = ['"' + lbls[i] + '"' for i in range(1000)] + inputs = "\n".join( + [f"float {inp}[{inp_size}];" for inp, inp_size in inp_sizes.items()] + ) + outputs = "\n".join( + [f"float {out}[{out_size}];" for out, out_size in out_sizes.items()] + ) + cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};") + cprog.append(inputs) + cprog.append(outputs) - # buffers (empty + weights) - cprog.append(""" + # buffers (empty + weights) + cprog.append( + """ int main(int argc, char* argv[]) { int DEBUG = getenv("DEBUG") != NULL ? atoi(getenv("DEBUG")) : 0; int X=0, Y=0, chan=0; @@ -62,8 +84,9 @@ if __name__ == "__main__": } if (DEBUG) printf("category : %d (%s) with %f\\n", best_idx, lbls[best_idx], best); else printf("%s\\n", lbls[best_idx]); - }""") + }""" + ) - # CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg - # category : 281 (tabby, tabby cat) with 9.452788 - print('\n'.join(cprog)) + # CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg + # category : 281 (tabby, tabby cat) with 9.452788 + print("\n".join(cprog)) diff --git a/examples/compile_tensorflow.py b/examples/compile_tensorflow.py index 43e5685b2..aaa867ebc 100644 --- a/examples/compile_tensorflow.py +++ b/examples/compile_tensorflow.py @@ -1,8 +1,9 @@ # An example to compile a small Tensorflow model to extremely portable C code import os, sys -os.environ["CLANG"] = '1' -os.environ["GPU"] = '1' + +os.environ["CLANG"] = "1" +os.environ["GPU"] = "1" import numpy as np import subprocess @@ -12,55 +13,66 @@ from examples.compile_efficientnet import compile_net from extra.onnx import get_run_onnx from tinygrad.tensor import Tensor + def get_uncompiled_model2(dataset_size=32, output_size=4): - inputs = tf.keras.Input(shape=(dataset_size,), name="inputs") - x = tf.keras.layers.Dense(16, activation="relu", name="dense_1")(inputs) - x = tf.keras.layers.BatchNormalization()(x) - x = tf.keras.layers.Dense(32, activation="relu", name="dense_2")(x) - outputs = tf.keras.layers.Dense(output_size, activation="sigmoid", name="predictions")(x) - model = tf.keras.Model(inputs=inputs, outputs=outputs) - return model + inputs = tf.keras.Input(shape=(dataset_size,), name="inputs") + x = tf.keras.layers.Dense(16, activation="relu", name="dense_1")(inputs) + x = tf.keras.layers.BatchNormalization()(x) + x = tf.keras.layers.Dense(32, activation="relu", name="dense_2")(x) + outputs = tf.keras.layers.Dense( + output_size, activation="sigmoid", name="predictions" + )(x) + model = tf.keras.Model(inputs=inputs, outputs=outputs) + return model + def create_onnx_model(keras_model): - input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')] - onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13) - return onnx_model + input_signature = [tf.TensorSpec([1, 32], tf.float32, name="x")] + onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13) + return onnx_model + def compile_onnx_model(onnx_model): - run_onnx = get_run_onnx(onnx_model) + run_onnx = get_run_onnx(onnx_model) - from tinygrad.jit import TinyJit - @TinyJit - def run(x): return run_onnx({"x": x}, debug=False)['predictions'].realize() + from tinygrad.jit import TinyJit - the_input = Tensor.randn(1,32) - the_output = run(the_input) - the_output = run(the_input) + @TinyJit + def run(x): + return run_onnx({"x": x}, debug=False)["predictions"].realize() - special_names = {id(the_input.lazydata.realized.cl): "input", id(the_output.lazydata.realized.cl): "outputs"} - cprog, statements, bufs, bufs_to_save = compile_net(run, special_names) - cprog = ["#include ", "#include ", "#include "] + cprog + the_input = Tensor.randn(1, 32) + the_output = run(the_input) + the_output = run(the_input) - # buffers (all except input) - cprog += [f"float {x[0]}[{x[1]}];" for x in bufs.values() if x[0] != "input"] + special_names = { + id(the_input.lazydata.realized.cl): "input", + id(the_output.lazydata.realized.cl): "outputs", + } + cprog, statements, bufs, bufs_to_save = compile_net(run, special_names) + cprog = ["#include ", "#include ", "#include "] + cprog - # weights - cprog.append("void initialize(float *weights) {") - weights = bytes() - for name,cl in bufs_to_save.items(): - cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {len(cl)});") - weights += bytes(memoryview(cl)[0:len(cl)//4]) - cprog.append("}") + # buffers (all except input) + cprog += [f"float {x[0]}[{x[1]}];" for x in bufs.values() if x[0] != "input"] - # write the weights to disk - with open("/tmp/tf_weights", "wb") as f: - f.write(weights) + # weights + cprog.append("void initialize(float *weights) {") + weights = bytes() + for name, cl in bufs_to_save.items(): + cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {len(cl)});") + weights += bytes(memoryview(cl)[0 : len(cl) // 4]) + cprog.append("}") - # the net - cprog += ["float *infer(float *input) {"] + statements + ["return outputs;", "}"] + # write the weights to disk + with open("/tmp/tf_weights", "wb") as f: + f.write(weights) - # test program - cprog.append(f"""int main(int argc, char *argv[]) {{ + # the net + cprog += ["float *infer(float *input) {"] + statements + ["return outputs;", "}"] + + # test program + cprog.append( + f"""int main(int argc, char *argv[]) {{ // read in the weights from disk FILE *f = fopen("/tmp/tf_weights", "rb"); float *weights = (float *)malloc({len(weights)}); @@ -75,30 +87,42 @@ def compile_onnx_model(onnx_model): for (int i = 0; i < 32; i++) scanf("%f", &input[i]); float *outputs = infer(input); printf("%f %f %f %f\\n", outputs[0], outputs[1], outputs[2], outputs[3]); - }}""") + }}""" + ) - # ready the program - prg = '\n'.join(cprog) - print(prg) + # ready the program + prg = "\n".join(cprog) + print(prg) - # add test weights - subprocess.check_output(['clang', '-O2', '-lm', '-fPIC', '-x', 'c', '-', '-o', "/tmp/tf_test"], input=prg.encode('utf-8')) + # add test weights + subprocess.check_output( + ["clang", "-O2", "-lm", "-fPIC", "-x", "c", "-", "-o", "/tmp/tf_test"], + input=prg.encode("utf-8"), + ) - tinygrad_output = [x for x in the_output.numpy()[0]] - print("tinygrad:", tinygrad_output, file=sys.stderr) + tinygrad_output = [x for x in the_output.numpy()[0]] + print("tinygrad:", tinygrad_output, file=sys.stderr) - c_input = ' '.join(["%f" % x for x in the_input[0].numpy()])+"\n" - c_output = [float(x) for x in subprocess.check_output(["/tmp/tf_test"], input=c_input.encode('utf-8')).decode('utf-8').strip().split(" ")] - print("compiled:", c_output, file=sys.stderr) + c_input = " ".join(["%f" % x for x in the_input[0].numpy()]) + "\n" + c_output = [ + float(x) + for x in subprocess.check_output( + ["/tmp/tf_test"], input=c_input.encode("utf-8") + ) + .decode("utf-8") + .strip() + .split(" ") + ] + print("compiled:", c_output, file=sys.stderr) + + np.testing.assert_allclose(tinygrad_output, c_output, atol=1e-5, rtol=1e-5) + return the_input.numpy(), c_output - np.testing.assert_allclose(tinygrad_output, c_output, atol=1e-5, rtol=1e-5) - return the_input.numpy(), c_output if __name__ == "__main__": - keras_model = get_uncompiled_model2() - onnx_model = create_onnx_model(keras_model) - test_input, test_output = compile_onnx_model(onnx_model) - tf_output = keras_model(test_input).numpy()[0] - print("keras: ", tf_output, file=sys.stderr) - np.testing.assert_allclose(tf_output, test_output, atol=1e-5, rtol=1e-5) - + keras_model = get_uncompiled_model2() + onnx_model = create_onnx_model(keras_model) + test_input, test_output = compile_onnx_model(onnx_model) + tf_output = keras_model(test_input).numpy()[0] + print("keras: ", tf_output, file=sys.stderr) + np.testing.assert_allclose(tf_output, test_output, atol=1e-5, rtol=1e-5) diff --git a/examples/conversation.py b/examples/conversation.py index 0e2d17bb6..27463bb1a 100644 --- a/examples/conversation.py +++ b/examples/conversation.py @@ -12,7 +12,14 @@ import pyaudio import yaml from llama import LLaMa from vits import MODELS as VITS_MODELS -from vits import Y_LENGTH_ESTIMATE_SCALARS, HParams, Synthesizer, TextMapper, get_hparams_from_file, load_model +from vits import ( + Y_LENGTH_ESTIMATE_SCALARS, + HParams, + Synthesizer, + TextMapper, + get_hparams_from_file, + load_model, +) from whisper import init_whisper, transcribe_waveform from sentencepiece import SentencePieceProcessor @@ -29,316 +36,557 @@ IM_END = 32002 # Functions for encoding prompts to chatml md -def encode_prompt(spp, k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n") -def start_prompt(spp, k): return [IM_START]+spp.encode(f"{k}\n") +def encode_prompt(spp, k, v): + return [IM_START] + spp.encode(f"{k}\n{v}") + [IM_END] + spp.encode("\n") + + +def start_prompt(spp, k): + return [IM_START] + spp.encode(f"{k}\n") + def chunks(lst, n): - for i in range(0, len(lst), n): yield lst[i:i + n] + for i in range(0, len(lst), n): + yield lst[i : i + n] + def create_fixed_tokenizer(): - """Function needed for extending tokenizer with additional chat tokens""" - import extra.junk.sentencepiece_model_pb2 as spb2 - tokenizer_path = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model") - if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003: - print("creating fixed tokenizer") - mp = spb2.ModelProto() - mp.ParseFromString(tokenizer_path.read_bytes()) - # https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/blob/main/added_tokens.json - mp.pieces.append(spb2.ModelProto.SentencePiece(piece="[PAD]", score=0)) - mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0)) - mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0)) - tokenizer_path.write_bytes(mp.SerializeToString()) - return tokenizer_path + """Function needed for extending tokenizer with additional chat tokens""" + import extra.junk.sentencepiece_model_pb2 as spb2 + + tokenizer_path = fetch( + "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model" + ) + if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003: + print("creating fixed tokenizer") + mp = spb2.ModelProto() + mp.ParseFromString(tokenizer_path.read_bytes()) + # https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/blob/main/added_tokens.json + mp.pieces.append(spb2.ModelProto.SentencePiece(piece="[PAD]", score=0)) + mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0)) + mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0)) + tokenizer_path.write_bytes(mp.SerializeToString()) + return tokenizer_path + + +def llama_prepare( + llama: LLaMa, temperature: float, pre_prompt_path: Path +) -> tuple[list[int], str, str, str]: + """Prepares a llama model from a specified pre-prompt file""" + with open(str(pre_prompt_path)) as f: + config = yaml.safe_load(f.read()) + toks = [llama.tokenizer.bos_id()] + encode_prompt( + llama.tokenizer, "system", config["pre_prompt"].replace("\n", " ") + ) + for i in config["examples"]: + toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"]) + toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"]) + llama.model(Tensor([toks]), 0, temperature).realize() # NOTE: outputs are not used + return ( + toks, + config["user_delim"], + config["resp_delim"], + len(toks), + llama.tokenizer.decode(toks), + ) -def llama_prepare(llama: LLaMa, temperature: float, pre_prompt_path: Path) -> tuple[list[int], str, str, str]: - """Prepares a llama model from a specified pre-prompt file""" - with open(str(pre_prompt_path)) as f: - config = yaml.safe_load(f.read()) - toks = [llama.tokenizer.bos_id()] + encode_prompt(llama.tokenizer, "system", config["pre_prompt"].replace("\n", " ")) - for i in config["examples"]: - toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"]) - toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"]) - llama.model(Tensor([toks]), 0, temperature).realize() # NOTE: outputs are not used - return toks, config["user_delim"], config["resp_delim"], len(toks), llama.tokenizer.decode(toks) def llama_generate( - llama: LLaMa, - toks: list[int], - outputted: str, - prompt: str, - start_pos: int, - user_delim: str, - resp_delim: str, - temperature=0.7, - max_tokens=1000 + llama: LLaMa, + toks: list[int], + outputted: str, + prompt: str, + start_pos: int, + user_delim: str, + resp_delim: str, + temperature=0.7, + max_tokens=1000, ): - """Generates an output for the specified prompt""" - toks += encode_prompt(llama.tokenizer, user_delim, prompt) - toks += start_prompt(llama.tokenizer, resp_delim) + """Generates an output for the specified prompt""" + toks += encode_prompt(llama.tokenizer, user_delim, prompt) + toks += start_prompt(llama.tokenizer, resp_delim) - outputted = llama.tokenizer.decode(toks) - init_length = len(outputted) - for _ in range(max_tokens): - probs_np = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).numpy() - token = int(np.random.choice(len(probs_np), p=probs_np)) - start_pos = len(toks) - toks.append(token) + outputted = llama.tokenizer.decode(toks) + init_length = len(outputted) + for _ in range(max_tokens): + probs_np = llama.model( + Tensor([toks[start_pos:]]), start_pos, temperature + ).numpy() + token = int(np.random.choice(len(probs_np), p=probs_np)) + start_pos = len(toks) + toks.append(token) - cur = llama.tokenizer.decode(toks) + cur = llama.tokenizer.decode(toks) + + # Print is just for debugging + sys.stdout.write(cur[len(outputted) :]) + sys.stdout.flush() + outputted = cur + if toks[-1] == IM_END: + break + else: + toks.append(IM_END) + print() # because the output is flushed + return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "") - # Print is just for debugging - sys.stdout.write(cur[len(outputted):]) - sys.stdout.flush() - outputted = cur - if toks[-1] == IM_END: break - else: - toks.append(IM_END) - print() # because the output is flushed - return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "") def tts( - text_to_synthesize: str, - synth: Synthesizer, - hps: HParams, - emotion_embedding: Path, - speaker_id: int, - model_to_use: str, - noise_scale: float, - noise_scale_w: float, - length_scale: float, - estimate_max_y_length: bool, - text_mapper: TextMapper, - model_has_multiple_speakers: bool, - batch_size=600, - vits_batch_size=1000 + text_to_synthesize: str, + synth: Synthesizer, + hps: HParams, + emotion_embedding: Path, + speaker_id: int, + model_to_use: str, + noise_scale: float, + noise_scale_w: float, + length_scale: float, + estimate_max_y_length: bool, + text_mapper: TextMapper, + model_has_multiple_speakers: bool, + batch_size=600, + vits_batch_size=1000, ): - if model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower()) + if model_to_use == "mmts-tts": + text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower()) - # Convert the input text to a tensor. - stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners) - init_shape = stn_tst.shape - assert init_shape[0] < batch_size, "text is too long" - x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze(0), Tensor([init_shape[0]], dtype=dtypes.int64) - sid = Tensor([speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None + # Convert the input text to a tensor. + stn_tst = text_mapper.get_text( + text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners + ) + init_shape = stn_tst.shape + assert init_shape[0] < batch_size, "text is too long" + x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze( + 0 + ), Tensor([init_shape[0]], dtype=dtypes.int64) + sid = ( + Tensor([speaker_id], dtype=dtypes.int64) + if model_has_multiple_speakers + else None + ) + + # Perform inference. + audio_tensor = synth.infer( + x_tst, + x_tst_lengths, + sid, + noise_scale, + length_scale, + noise_scale_w, + emotion_embedding=emotion_embedding, + max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] + if estimate_max_y_length + else None, + batch_size=vits_batch_size, + )[0, 0] + # Save the audio output. + audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16) + return audio_data - # Perform inference. - audio_tensor = synth.infer(x_tst, x_tst_lengths, sid, noise_scale, length_scale, noise_scale_w, emotion_embedding=emotion_embedding, - max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None, batch_size=vits_batch_size)[0, 0] - # Save the audio output. - audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16) - return audio_data def init_vits( - model_to_use: str, - emotion_path: Path, - speaker_id: int, - seed: int, + model_to_use: str, + emotion_path: Path, + speaker_id: int, + seed: int, ): - model_config = VITS_MODELS[model_to_use] + model_config = VITS_MODELS[model_to_use] - # Load the hyperparameters from the config file. - hps = get_hparams_from_file(fetch(model_config[0])) + # Load the hyperparameters from the config file. + hps = get_hparams_from_file(fetch(model_config[0])) - # If model has multiple speakers, validate speaker id and retrieve name if available. - model_has_multiple_speakers = hps.data.n_speakers > 0 - if model_has_multiple_speakers: - if speaker_id >= hps.data.n_speakers: raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.") - if hps.__contains__("speakers"): # maps speaker ids to names - speakers = hps.speakers - if isinstance(speakers, list): speakers = {speaker: i for i, speaker in enumerate(speakers)} + # If model has multiple speakers, validate speaker id and retrieve name if available. + model_has_multiple_speakers = hps.data.n_speakers > 0 + if model_has_multiple_speakers: + if speaker_id >= hps.data.n_speakers: + raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.") + if hps.__contains__("speakers"): # maps speaker ids to names + speakers = hps.speakers + if isinstance(speakers, list): + speakers = {speaker: i for i, speaker in enumerate(speakers)} - # Load emotions if any. TODO: find an english model with emotions, this is untested atm. - emotion_embedding = None - if emotion_path is not None: - if emotion_path.endswith(".npy"): emotion_embedding = Tensor(np.load(emotion_path), dtype=dtypes.int64).unsqueeze(0) - else: raise ValueError("Emotion path must be a .npy file.") + # Load emotions if any. TODO: find an english model with emotions, this is untested atm. + emotion_embedding = None + if emotion_path is not None: + if emotion_path.endswith(".npy"): + emotion_embedding = Tensor( + np.load(emotion_path), dtype=dtypes.int64 + ).unsqueeze(0) + else: + raise ValueError("Emotion path must be a .npy file.") - # Load symbols, instantiate TextMapper and clean the text. - if hps.__contains__("symbols"): symbols = hps.symbols - elif model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()] - else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ") - text_mapper = TextMapper(apply_cleaners=True, symbols=symbols) + # Load symbols, instantiate TextMapper and clean the text. + if hps.__contains__("symbols"): + symbols = hps.symbols + elif model_to_use == "mmts-tts": + symbols = [ + x.replace("\n", "") + for x in fetch( + "https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt" + ) + .open(encoding="utf-8") + .readlines() + ] + else: + symbols = ( + ["_"] + + list(';:,.!?¡¿—…"«»“” ') + + list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") + + list( + "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" + ) + ) + text_mapper = TextMapper(apply_cleaners=True, symbols=symbols) - # Load the model. - Tensor.no_grad = True - if seed is not None: - Tensor.manual_seed(seed) - np.random.seed(seed) - net_g = load_model(text_mapper.symbols, hps, model_config) + # Load the model. + Tensor.no_grad = True + if seed is not None: + Tensor.manual_seed(seed) + np.random.seed(seed) + net_g = load_model(text_mapper.symbols, hps, model_config) + + return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers - return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers @contextmanager def output_stream(num_channels: int, sample_rate: int): - try: - p = pyaudio.PyAudio() - stream = p.open(format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True) - yield stream - except KeyboardInterrupt: pass - finally: - stream.stop_stream() - stream.close() - p.terminate() + try: + p = pyaudio.PyAudio() + stream = p.open( + format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True + ) + yield stream + except KeyboardInterrupt: + pass + finally: + stream.stop_stream() + stream.close() + p.terminate() + @contextmanager def log_writer(): - try: - logs = [] - yield logs - finally: - sep = "="*os.get_terminal_size()[1] - print(f"{sep[:-1]}\nCHAT LOG") - print(*logs, sep="\n") - print(sep) + try: + logs = [] + yield logs + finally: + sep = "=" * os.get_terminal_size()[1] + print(f"{sep[:-1]}\nCHAT LOG") + print(*logs, sep="\n") + print(sep) + def listener(q: mp.Queue, event: mp.Event): - try: - p = pyaudio.PyAudio() - stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK) - did_print = False - while True: - data = stream.read(CHUNK) # read data to avoid overflow - if event.is_set(): - if not did_print: - print("listening") - did_print = True - q.put(((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)) - else: + try: + p = pyaudio.PyAudio() + stream = p.open( + format=pyaudio.paInt16, + channels=1, + rate=RATE, + input=True, + frames_per_buffer=CHUNK, + ) did_print = False - finally: - stream.stop_stream() - stream.close() - p.terminate() + while True: + data = stream.read(CHUNK) # read data to avoid overflow + if event.is_set(): + if not did_print: + print("listening") + did_print = True + q.put(((np.frombuffer(data, np.int16) / 32768).astype(np.float32) * 3)) + else: + did_print = False + finally: + stream.stop_stream() + stream.close() + p.terminate() + + +def mp_output_stream( + q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int +): + with output_stream(num_channels, sample_rate) as stream: + while True: + try: + stream.write(q.get()) + counter.value += 1 + except KeyboardInterrupt: + break -def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int): - with output_stream(num_channels, sample_rate) as stream: - while True: - try: - stream.write(q.get()) - counter.value += 1 - except KeyboardInterrupt: - break if __name__ == "__main__": - import nltk - nltk.download("punkt") - Tensor.no_grad = True - # Parse CLI arguments - parser = argparse.ArgumentParser("Have a tiny conversation with tinygrad") + import nltk - # Whisper args - parser.add_argument("--whisper_model_name", type=str, default="tiny.en") + nltk.download("punkt") + Tensor.no_grad = True + # Parse CLI arguments + parser = argparse.ArgumentParser("Have a tiny conversation with tinygrad") - # LLAMA args - parser.add_argument("--llama_pre_prompt_path", type=Path, default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml", help="Path to yaml file which contains all pre-prompt data needed. ") - parser.add_argument("--llama_count", type=int, default=1000, help="Max number of tokens to generate") - parser.add_argument("--llama_temperature", type=float, default=0.7, help="Temperature in the softmax") - parser.add_argument("--llama_quantize", action="store_true", help="Quantize the weights to int8 in memory") - parser.add_argument("--llama_model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file") - parser.add_argument("--llama_gen", type=str, default="tiny", required=False, help="Generation of the model to use") - parser.add_argument("--llama_size", type=str, default="1B-Chat", required=False, help="Size of model to use") - parser.add_argument("--llama_tokenizer", type=Path, default=None, required=False, help="Path to llama tokenizer.model") + # Whisper args + parser.add_argument("--whisper_model_name", type=str, default="tiny.en") - # vits args - parser.add_argument("--vits_model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.") - parser.add_argument("--vits_speaker_id", type=int, default=12, help="Specify the speaker ID. Default is 6.") - parser.add_argument("--vits_noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.") - parser.add_argument("--vits_noise_scale_w", type=float, default=0.8, help="Specify the noise scale w. Default is 0.8.") - parser.add_argument("--vits_length_scale", type=float, default=1, help="Specify the length scale. Default is 1.") - parser.add_argument("--vits_seed", type=int, default=None, help="Specify the seed (set to None if no seed). Default is 1337.") - parser.add_argument("--vits_num_channels", type=int, default=1, help="Specify the number of audio output channels. Default is 1.") - parser.add_argument("--vits_sample_width", type=int, default=2, help="Specify the number of bytes per sample, adjust if necessary. Default is 2.") - parser.add_argument("--vits_emotion_path", type=Path, default=None, help="Specify the path to emotion reference.") - parser.add_argument("--vits_estimate_max_y_length", type=str, default=False, help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.") - parser.add_argument("--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary.") - - # conversation args - parser.add_argument("--max_sentence_length", type=int, default=20, help="Max words in one sentence to pass to vits") - - args = parser.parse_args() - - # Init models - model, enc = init_whisper(args.whisper_model_name) - synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(args.vits_model_to_use, args.vits_emotion_path, args.vits_speaker_id, args.vits_seed) - - # Download tinyllama chat as a default model - if args.llama_model is None: - args.llama_model = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors", "tinyllamachat.safetensors") - args.llama_gen = "tiny" - args.llama_size = "1B-Chat" - # Add 3 more tokens to the tokenizer - if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"): args.llama_tokenizer = create_fixed_tokenizer() - tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model" - llama = LLaMa.build(args.llama_model, tokenizer_path, args.llama_gen, args.llama_size, args.llama_quantize) - toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(llama, args.llama_temperature, args.llama_pre_prompt_path) - - # Start child process for mic input - q = mp.Queue() - is_listening_event = mp.Event() - p = mp.Process(target=listener, args=(q, is_listening_event,)) - p.daemon = True - p.start() - - # Start child process for speaker output - out_q = mp.Queue() - out_counter = mp.Value("i", 0) - out_p = mp.Process(target=mp_output_stream, args=(out_q, out_counter, args.vits_num_channels, hps.data.sampling_rate,)) - out_p.daemon = True - out_p.start() - - # JIT tts - for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]: - tts( - i, synth, hps, emotion_embedding, - args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale, - args.vits_noise_scale_w, args.vits_length_scale, - args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers + # LLAMA args + parser.add_argument( + "--llama_pre_prompt_path", + type=Path, + default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml", + help="Path to yaml file which contains all pre-prompt data needed. ", + ) + parser.add_argument( + "--llama_count", type=int, default=1000, help="Max number of tokens to generate" + ) + parser.add_argument( + "--llama_temperature", + type=float, + default=0.7, + help="Temperature in the softmax", + ) + parser.add_argument( + "--llama_quantize", + action="store_true", + help="Quantize the weights to int8 in memory", + ) + parser.add_argument( + "--llama_model", + type=Path, + default=None, + help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file", + ) + parser.add_argument( + "--llama_gen", + type=str, + default="tiny", + required=False, + help="Generation of the model to use", + ) + parser.add_argument( + "--llama_size", + type=str, + default="1B-Chat", + required=False, + help="Size of model to use", + ) + parser.add_argument( + "--llama_tokenizer", + type=Path, + default=None, + required=False, + help="Path to llama tokenizer.model", ) - # Start the pipeline - with log_writer() as log: - while True: - tokens = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]] - total = np.array([]) - out_counter.value = 0 + # vits args + parser.add_argument( + "--vits_model_to_use", + default="vctk", + help="Specify the model to use. Default is 'vctk'.", + ) + parser.add_argument( + "--vits_speaker_id", + type=int, + default=12, + help="Specify the speaker ID. Default is 6.", + ) + parser.add_argument( + "--vits_noise_scale", + type=float, + default=0.667, + help="Specify the noise scale. Default is 0.667.", + ) + parser.add_argument( + "--vits_noise_scale_w", + type=float, + default=0.8, + help="Specify the noise scale w. Default is 0.8.", + ) + parser.add_argument( + "--vits_length_scale", + type=float, + default=1, + help="Specify the length scale. Default is 1.", + ) + parser.add_argument( + "--vits_seed", + type=int, + default=None, + help="Specify the seed (set to None if no seed). Default is 1337.", + ) + parser.add_argument( + "--vits_num_channels", + type=int, + default=1, + help="Specify the number of audio output channels. Default is 1.", + ) + parser.add_argument( + "--vits_sample_width", + type=int, + default=2, + help="Specify the number of bytes per sample, adjust if necessary. Default is 2.", + ) + parser.add_argument( + "--vits_emotion_path", + type=Path, + default=None, + help="Specify the path to emotion reference.", + ) + parser.add_argument( + "--vits_estimate_max_y_length", + type=str, + default=False, + help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.", + ) + parser.add_argument( + "--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary." + ) - s = time.perf_counter() - is_listening_event.set() - prev_text = None - while True: - for _ in range(RATE // CHUNK): total = np.concatenate([total, q.get()]) - txt = transcribe_waveform(model, enc, [total], truncate=True) - print(txt, end="\r") - if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()): continue - if prev_text is not None and prev_text == txt: - is_listening_event.clear() - break - prev_text = txt - print() # to avoid llama printing on the same line - log.append(f"{user_delim.capitalize()}: {txt}") + # conversation args + parser.add_argument( + "--max_sentence_length", + type=int, + default=20, + help="Max words in one sentence to pass to vits", + ) - # Generate with llama - with Timing("llama generation: "): - outputted, start_pos, response = llama_generate( - llama, toks, outputted, txt, start_pos, - user_delim=user_delim, resp_delim=resp_delim, temperature=args.llama_temperature, - max_tokens=args.llama_count + args = parser.parse_args() + + # Init models + model, enc = init_whisper(args.whisper_model_name) + synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits( + args.vits_model_to_use, + args.vits_emotion_path, + args.vits_speaker_id, + args.vits_seed, + ) + + # Download tinyllama chat as a default model + if args.llama_model is None: + args.llama_model = fetch( + "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors", + "tinyllamachat.safetensors", ) - log.append(f"{resp_delim.capitalize()}: {response}") + args.llama_gen = "tiny" + args.llama_size = "1B-Chat" + # Add 3 more tokens to the tokenizer + if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"): + args.llama_tokenizer = create_fixed_tokenizer() + tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model" + llama = LLaMa.build( + args.llama_model, + tokenizer_path, + args.llama_gen, + args.llama_size, + args.llama_quantize, + ) + toks, user_delim, resp_delim, start_pos, outputted = llama_prepare( + llama, args.llama_temperature, args.llama_pre_prompt_path + ) - # Convert to voice - with Timing("tts: "): - sentences = nltk.sent_tokenize(response.replace('"', "")) - for i in sentences: - total = np.array([], dtype=np.int16) - for j in chunks(i.split(), args.max_sentence_length): - audio_data = tts( - " ".join(j), synth, hps, emotion_embedding, - args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale, - args.vits_noise_scale_w, args.vits_length_scale, - args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers - ) - total = np.concatenate([total, audio_data]) - out_q.put(total.tobytes()) - while out_counter.value < len(sentences): continue - log.append(f"Total: {time.perf_counter() - s}") + # Start child process for mic input + q = mp.Queue() + is_listening_event = mp.Event() + p = mp.Process( + target=listener, + args=( + q, + is_listening_event, + ), + ) + p.daemon = True + p.start() + + # Start child process for speaker output + out_q = mp.Queue() + out_counter = mp.Value("i", 0) + out_p = mp.Process( + target=mp_output_stream, + args=( + out_q, + out_counter, + args.vits_num_channels, + hps.data.sampling_rate, + ), + ) + out_p.daemon = True + out_p.start() + + # JIT tts + for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]: + tts( + i, + synth, + hps, + emotion_embedding, + args.vits_speaker_id, + args.vits_model_to_use, + args.vits_noise_scale, + args.vits_noise_scale_w, + args.vits_length_scale, + args.vits_estimate_max_y_length, + text_mapper, + model_has_multiple_speakers, + ) + + # Start the pipeline + with log_writer() as log: + while True: + tokens = [ + enc._special_tokens["<|startoftranscript|>"], + enc._special_tokens["<|notimestamps|>"], + ] + total = np.array([]) + out_counter.value = 0 + + s = time.perf_counter() + is_listening_event.set() + prev_text = None + while True: + for _ in range(RATE // CHUNK): + total = np.concatenate([total, q.get()]) + txt = transcribe_waveform(model, enc, [total], truncate=True) + print(txt, end="\r") + if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()): + continue + if prev_text is not None and prev_text == txt: + is_listening_event.clear() + break + prev_text = txt + print() # to avoid llama printing on the same line + log.append(f"{user_delim.capitalize()}: {txt}") + + # Generate with llama + with Timing("llama generation: "): + outputted, start_pos, response = llama_generate( + llama, + toks, + outputted, + txt, + start_pos, + user_delim=user_delim, + resp_delim=resp_delim, + temperature=args.llama_temperature, + max_tokens=args.llama_count, + ) + log.append(f"{resp_delim.capitalize()}: {response}") + + # Convert to voice + with Timing("tts: "): + sentences = nltk.sent_tokenize(response.replace('"', "")) + for i in sentences: + total = np.array([], dtype=np.int16) + for j in chunks(i.split(), args.max_sentence_length): + audio_data = tts( + " ".join(j), + synth, + hps, + emotion_embedding, + args.vits_speaker_id, + args.vits_model_to_use, + args.vits_noise_scale, + args.vits_noise_scale_w, + args.vits_length_scale, + args.vits_estimate_max_y_length, + text_mapper, + model_has_multiple_speakers, + ) + total = np.concatenate([total, audio_data]) + out_q.put(total.tobytes()) + while out_counter.value < len(sentences): + continue + log.append(f"Total: {time.perf_counter() - s}") diff --git a/examples/efficientnet.py b/examples/efficientnet.py index 1e5381fb6..d0dd29d70 100644 --- a/examples/efficientnet.py +++ b/examples/efficientnet.py @@ -11,78 +11,98 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import getenv, fetch, Timing from tinygrad.jit import TinyJit from extra.models.efficientnet import EfficientNet + np.set_printoptions(suppress=True) # TODO: you should be able to put these in the jitted function bias = Tensor([0.485, 0.456, 0.406]) scale = Tensor([0.229, 0.224, 0.225]) + @TinyJit def _infer(model, img): - img = img.permute((2,0,1)) - img = img / 255.0 - img = img - bias.reshape((1,-1,1,1)) - img = img / scale.reshape((1,-1,1,1)) - return model.forward(img).realize() + img = img.permute((2, 0, 1)) + img = img / 255.0 + img = img - bias.reshape((1, -1, 1, 1)) + img = img / scale.reshape((1, -1, 1, 1)) + return model.forward(img).realize() + def infer(model, img): - # preprocess image - aspect_ratio = img.size[0] / img.size[1] - img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0)))) + # preprocess image + aspect_ratio = img.size[0] / img.size[1] + img = img.resize( + (int(224 * max(aspect_ratio, 1.0)), int(224 * max(1.0 / aspect_ratio, 1.0))) + ) - img = np.array(img) - y0,x0=(np.asarray(img.shape)[:2]-224)//2 - retimg = img = img[y0:y0+224, x0:x0+224] + img = np.array(img) + y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2 + retimg = img = img[y0 : y0 + 224, x0 : x0 + 224] - # if you want to look at the image - """ + # if you want to look at the image + """ import matplotlib.pyplot as plt plt.imshow(img) plt.show() """ - # run the net - out = _infer(model, Tensor(img.astype("float32"))).numpy() + # run the net + out = _infer(model, Tensor(img.astype("float32"))).numpy() - # if you want to look at the outputs - """ + # if you want to look at the outputs + """ import matplotlib.pyplot as plt plt.plot(out[0]) plt.show() """ - return out, retimg + return out, retimg + if __name__ == "__main__": - # instantiate my net - model = EfficientNet(getenv("NUM", 0)) - model.load_from_pretrained() + # instantiate my net + model = EfficientNet(getenv("NUM", 0)) + model.load_from_pretrained() - # category labels - lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text()) + # category labels + lbls = ast.literal_eval( + fetch( + "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt" + ).read_text() + ) - # load image and preprocess - url = sys.argv[1] if len(sys.argv) >= 2 else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg" - if url == 'webcam': - import cv2 - cap = cv2.VideoCapture(0) - cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) - while 1: - _ = cap.grab() # discard one frame to circumvent capture buffering - ret, frame = cap.read() - img = Image.fromarray(frame[:, :, [2,1,0]]) - lt = time.monotonic_ns() - out, retimg = infer(model, img) - print(f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", np.argmax(out), np.max(out), lbls[np.argmax(out)]) - SCALE = 3 - simg = cv2.resize(retimg, (224*SCALE, 224*SCALE)) - retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR) - cv2.imshow('capture', retimg) - if cv2.waitKey(1) & 0xFF == ord('q'): - break - cap.release() - cv2.destroyAllWindows() - else: - img = Image.open(fetch(url)) - with Timing("did inference in "): - out, _ = infer(model, img) - print(np.argmax(out), np.max(out), lbls[np.argmax(out)]) + # load image and preprocess + url = ( + sys.argv[1] + if len(sys.argv) >= 2 + else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg" + ) + if url == "webcam": + import cv2 + + cap = cv2.VideoCapture(0) + cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) + while 1: + _ = cap.grab() # discard one frame to circumvent capture buffering + ret, frame = cap.read() + img = Image.fromarray(frame[:, :, [2, 1, 0]]) + lt = time.monotonic_ns() + out, retimg = infer(model, img) + print( + f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", + np.argmax(out), + np.max(out), + lbls[np.argmax(out)], + ) + SCALE = 3 + simg = cv2.resize(retimg, (224 * SCALE, 224 * SCALE)) + retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR) + cv2.imshow("capture", retimg) + if cv2.waitKey(1) & 0xFF == ord("q"): + break + cap.release() + cv2.destroyAllWindows() + else: + img = Image.open(fetch(url)) + with Timing("did inference in "): + out, _ = infer(model, img) + print(np.argmax(out), np.max(out), lbls[np.argmax(out)]) diff --git a/examples/f16_w_uint32.py b/examples/f16_w_uint32.py index bf281661e..d4abbfe20 100644 --- a/examples/f16_w_uint32.py +++ b/examples/f16_w_uint32.py @@ -3,40 +3,47 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import dtypes from tinygrad import Device + # TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul def bit_extract(x, s, e) -> Tensor: - # extract the top bits we don't want - top_bits = (x / (1<<(s+1))).floor() * (1<<(s+1)) - x = (x - top_bits) / (1< Tensor: - if mask is not None: - # no symbolic shape qkv when consuming prompts - start_pos = start_pos.val + def __call__( + self, x: Tensor, start_pos: Variable, mask: Optional[Tensor] + ) -> Tensor: + if mask is not None: + # no symbolic shape qkv when consuming prompts + start_pos = start_pos.val - xqkv = self.c_attn(x) - xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim) for i in range(3)] - bsz, seqlen, n_heads, head_dim = xq.shape + xqkv = self.c_attn(x) + xq, xk, xv = [ + xqkv.shrink((None, None, (i * self.dim, (i + 1) * self.dim))).reshape( + xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim + ) + for i in range(3) + ] + bsz, seqlen, n_heads, head_dim = xq.shape - # create kv cache - if not hasattr(self, "cache_k"): - self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim) - if HALF: - self.cache_k = self.cache_k.half() - self.cache_v = self.cache_v.half() + # create kv cache + if not hasattr(self, "cache_k"): + self.cache_k, self.cache_v = Tensor.zeros( + bsz, MAX_CONTEXT, self.n_heads, self.head_dim + ), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim) + if HALF: + self.cache_k = self.cache_k.half() + self.cache_v = self.cache_v.half() - keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1) - values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1) + keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1) + values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1) - # update the cache - self.cache_k.assign(keys.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize() - self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize() + # update the cache + self.cache_k.assign( + keys.pad( + (None, (0, MAX_CONTEXT - start_pos - seqlen), None, None) + ).contiguous() + ).realize() + self.cache_v.assign( + values.pad( + (None, (0, MAX_CONTEXT - start_pos - seqlen), None, None) + ).contiguous() + ).realize() + + xq, keys, values = ( + xq.transpose(1, 2), + keys.transpose(1, 2), + values.transpose(1, 2), + ) + return self.c_proj( + xq.scaled_dot_product_attention(keys, values, mask) + .transpose(1, 2) + .reshape(bsz, seqlen, -1) + ) - xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2) - return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)) class FeedForward: - def __init__(self, dim, hidden_dim): - self.c_fc = Linear(dim, hidden_dim, bias=True) - self.c_proj = Linear(hidden_dim, dim, bias=True) + def __init__(self, dim, hidden_dim): + self.c_fc = Linear(dim, hidden_dim, bias=True) + self.c_proj = Linear(hidden_dim, dim, bias=True) + + def __call__(self, x: Tensor) -> Tensor: + return self.c_proj(self.c_fc(x).gelu()) - def __call__(self, x:Tensor) -> Tensor: - return self.c_proj(self.c_fc(x).gelu()) class TransformerBlock: - def __init__(self, dim, n_heads, norm_eps): - self.attn = Attention(dim, n_heads) - self.mlp = FeedForward(dim, 4*dim) - self.ln_1 = LayerNorm(dim, norm_eps) - self.ln_2 = LayerNorm(dim, norm_eps) + def __init__(self, dim, n_heads, norm_eps): + self.attn = Attention(dim, n_heads) + self.mlp = FeedForward(dim, 4 * dim) + self.ln_1 = LayerNorm(dim, norm_eps) + self.ln_2 = LayerNorm(dim, norm_eps) + + def __call__(self, x: Tensor, start_pos: Variable, mask: Optional[Tensor]): + h = x + self.attn(self.ln_1(x), start_pos, mask) + return h + self.mlp(self.ln_2(h)) - def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]): - h = x + self.attn(self.ln_1(x), start_pos, mask) - return (h + self.mlp(self.ln_2(h))) class Transformer: - def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024): - self.wte = Embedding(vocab_size, dim) - self.wpe = Embedding(max_seq_len, dim) - self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)] - self.ln_f = LayerNorm(dim, norm_eps) - self.lm_head = Linear(dim, vocab_size, bias=False) - self.forward_jit = TinyJit(self.forward) + def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024): + self.wte = Embedding(vocab_size, dim) + self.wpe = Embedding(max_seq_len, dim) + self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)] + self.ln_f = LayerNorm(dim, norm_eps) + self.lm_head = Linear(dim, vocab_size, bias=False) + self.forward_jit = TinyJit(self.forward) - def forward(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0): - if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize() - _bsz, seqlen = tokens.shape + def forward(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0): + if not hasattr(self, "allpos"): + self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize() + _bsz, seqlen = tokens.shape - # NOTE: cannot convert token indices into half due to precision - tok_emb = self.wte(tokens) - pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos+seqlen)))) - h = tok_emb + pos_emb + # NOTE: cannot convert token indices into half due to precision + tok_emb = self.wte(tokens) + pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos + seqlen)))) + h = tok_emb + pos_emb - mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None + mask = ( + Tensor.full((1, 1, seqlen, start_pos.val + seqlen), float("-inf")) + .triu(start_pos.val + 1) + .realize() + if seqlen > 1 + else None + ) - if HALF: - h = h.half() - if mask is not None: mask = mask.half() + if HALF: + h = h.half() + if mask is not None: + mask = mask.half() - for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask) + for hi in self.h: + h = hi(h, start_pos=start_pos, mask=mask) - logits = self.lm_head(self.ln_f(h)) - # NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead - return (logits[:, -1, :] / (temperature+1e-10)).softmax().realize() + logits = self.lm_head(self.ln_f(h)) + # NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead + return (logits[:, -1, :] / (temperature + 1e-10)).softmax().realize() + + # TODO: fix empty token + def __call__( + self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0 + ) -> Tensor: + return ( + self.forward_jit if tokens.shape[1] == 1 and getenv("JIT") else self.forward + )(tokens, start_pos, temperature) - # TODO: fix empty token - def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor: - return (self.forward_jit if tokens.shape[1] == 1 and getenv("JIT") else self.forward)(tokens, start_pos, temperature) VOCAB_SIZE = 50257 MODEL_PARAMS = { - 'gpt2': dict(n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 124M params - 'gpt2-medium': dict(n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 350M params - 'gpt2-large': dict(n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 774M params - 'gpt2-xl': dict(n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 1558M params + "gpt2": dict( + n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE + ), # 124M params + "gpt2-medium": dict( + n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE + ), # 350M params + "gpt2-large": dict( + n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE + ), # 774M params + "gpt2-xl": dict( + n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE + ), # 1558M params } + class GPT2: - @staticmethod - def build(model_size="gpt2"): - tokenizer = tiktoken.get_encoding("gpt2") + @staticmethod + def build(model_size="gpt2"): + tokenizer = tiktoken.get_encoding("gpt2") - model = Transformer(**MODEL_PARAMS[model_size]) - weights = torch_load(fetch(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin')) - # special treatment for the Conv1D weights we need to transpose - transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] - for k in weights.keys(): - if any(k.endswith(w) for w in transposed): - weights[k] = Tensor(weights[k].numpy().T) - # lm head and wte are tied - weights['lm_head.weight'] = Tensor(weights['wte.weight'].numpy()) + model = Transformer(**MODEL_PARAMS[model_size]) + weights = torch_load( + fetch(f"https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin") + ) + # special treatment for the Conv1D weights we need to transpose + transposed = [ + "attn.c_attn.weight", + "attn.c_proj.weight", + "mlp.c_fc.weight", + "mlp.c_proj.weight", + ] + for k in weights.keys(): + if any(k.endswith(w) for w in transposed): + weights[k] = Tensor(weights[k].numpy().T) + # lm head and wte are tied + weights["lm_head.weight"] = Tensor(weights["wte.weight"].numpy()) - load_state_dict(model, weights) - return GPT2(model, tokenizer) + load_state_dict(model, weights) + return GPT2(model, tokenizer) - def __init__(self, model, tokenizer): - self.model = model - self.tokenizer = tokenizer + def __init__(self, model, tokenizer): + self.model = model + self.tokenizer = tokenizer + + def greedy_until( + self, + prompt: str, + max_length: int, + temperature: float, + timing: bool = False, + batch_size: int = 1, + ): + prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"}) + toks = [prompt_tokens[:] for _ in range(batch_size)] + start_pos = 0 + for _ in trange(max_length, disable=(timing == True)): + GlobalCounters.reset() + if timing: + print("") + st = GlobalCounters.time_sum_s + with Timing("total ", enabled=timing): + with Timing( + "ran model in ", + on_exit=( + lambda et: ( + f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" + if DEBUG >= 2 + else "" + ) + + f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB" + + ( + f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" + if DEBUG >= 2 + else "" + ) + ) + if DEBUG + else None, + enabled=timing, + ): + probs = self.model( + Tensor([x[start_pos:] for x in toks]), + Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind( + start_pos + ), + temperature, + ) + # TODO: fix JIT rand so we can put this in the JIT + tok = probs.multinomial().flatten().numpy().tolist() + start_pos = len(toks[0]) + for i, t in enumerate(tok): + toks[i].append(t) + output = [self.tokenizer.decode(x) for x in toks] + return output - def greedy_until(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1): - prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"}) - toks = [prompt_tokens[:] for _ in range(batch_size)] - start_pos = 0 - for _ in trange(max_length, disable=(timing==True)): - GlobalCounters.reset() - if timing: print("") - st = GlobalCounters.time_sum_s - with Timing("total ", enabled=timing): - with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+ - f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+ - (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing): - probs = self.model(Tensor([x[start_pos:] for x in toks]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature) - # TODO: fix JIT rand so we can put this in the JIT - tok = probs.multinomial().flatten().numpy().tolist() - start_pos = len(toks[0]) - for i,t in enumerate(tok): toks[i].append(t) - output = [self.tokenizer.decode(x) for x in toks] - return output # **** main code **** if __name__ == "__main__": - Tensor.no_grad = True - print(f"using {Device.DEFAULT} backend") + Tensor.no_grad = True + print(f"using {Device.DEFAULT} backend") - parser = argparse.ArgumentParser(description='Run GPT2 in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--prompt', type=str, default="What is the answer to life, the universe, and everything?", help="Phrase to start with") - parser.add_argument('--count', type=int, default=100, help="Max number of tokens to generate") - parser.add_argument('--temperature', type=float, default=0.8, help="Temperature in the softmax") - parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]") - parser.add_argument('--timing', action='store_true', help="Print timing per token") - parser.add_argument('--seed', type=int, help="Set the random seed") - parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size") - parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens") - parser.add_argument('--noshow', action='store_true', help="Don't show the output") - args = parser.parse_args() + parser = argparse.ArgumentParser( + description="Run GPT2 in tinygrad", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--prompt", + type=str, + default="What is the answer to life, the universe, and everything?", + help="Phrase to start with", + ) + parser.add_argument( + "--count", type=int, default=100, help="Max number of tokens to generate" + ) + parser.add_argument( + "--temperature", type=float, default=0.8, help="Temperature in the softmax" + ) + parser.add_argument( + "--model_size", + type=str, + default="gpt2-medium", + help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]", + ) + parser.add_argument("--timing", action="store_true", help="Print timing per token") + parser.add_argument("--seed", type=int, help="Set the random seed") + parser.add_argument( + "--batch_size", type=int, default=1, help="Set the input batch size" + ) + parser.add_argument( + "--benchmark", + type=int, + default=-1, + help="Benchmark GPT with the given number of tokens", + ) + parser.add_argument("--noshow", action="store_true", help="Don't show the output") + args = parser.parse_args() - if args.seed is not None: - Tensor._seed = args.seed - np.random.seed(args.seed) + if args.seed is not None: + Tensor._seed = args.seed + np.random.seed(args.seed) - print(f"using {args.model_size}") - gpt2 = GPT2.build(args.model_size) + print(f"using {args.model_size}") + gpt2 = GPT2.build(args.model_size) - if HALF: - for l in get_state_dict(gpt2).values(): - l.assign(l.cast(dtypes.float16).realize()) + if HALF: + for l in get_state_dict(gpt2).values(): + l.assign(l.cast(dtypes.float16).realize()) - if args.benchmark != -1: - gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize() - else: - texts = gpt2.greedy_until(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size) - if not args.noshow: - print('Generating text...') - if len(texts) == 1: print(texts[0]) - else: - for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text) \ No newline at end of file + if args.benchmark != -1: + gpt2.model( + Tensor.rand(args.batch_size, args.benchmark), + Variable("a", 0, MAX_CONTEXT).bind(0), + ).realize() + else: + texts = gpt2.greedy_until( + args.prompt, + args.count, + args.temperature, + timing=args.timing, + batch_size=args.batch_size, + ) + if not args.noshow: + print("Generating text...") + if len(texts) == 1: + print(texts[0]) + else: + for i, text in enumerate(texts): + print(colored(f"Response {i}:", "green"), text) diff --git a/examples/handcode_resnet50_opt.py b/examples/handcode_resnet50_opt.py index 720b9b089..68d69fc78 100644 --- a/examples/handcode_resnet50_opt.py +++ b/examples/handcode_resnet50_opt.py @@ -11,61 +11,75 @@ from tinygrad.shape.symbolic import sym_infer if __name__ == "__main__": - mdl = ResNet50() - seen = set() + mdl = ResNet50() + seen = set() - # the device we are optimizing for - device: Compiled = Device[Device.DEFAULT] - print(f"optimizing for {Device.DEFAULT}") + # the device we are optimizing for + device: Compiled = Device[Device.DEFAULT] + print(f"optimizing for {Device.DEFAULT}") - # first model run to init the weights, they are saved in seen - mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen) + # first model run to init the weights, they are saved in seen + mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen) - # run model again to get only what changes, these are the kernels of the model - x = Tensor.empty(64, 3, 224, 224) - out = mdl(x) - sched = out.lazydata.schedule(seen) - sched = [x for x in sched if x.ast.op not in LoadOps] + # run model again to get only what changes, these are the kernels of the model + x = Tensor.empty(64, 3, 224, 224) + out = mdl(x) + sched = out.lazydata.schedule(seen) + sched = [x for x in sched if x.ast.op not in LoadOps] - # focus on one kernel - if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1] + # focus on one kernel + if getenv("KERNEL", -1) >= 0: + sched = sched[getenv("KERNEL", -1) : getenv("KERNEL", -1) + 1] - # work with the schedule - total_tm = 0 - running_gflops = 0 - for i,si in enumerate(sched): - rawbufs = bufs_from_lin(Linearizer(si.ast)) + # work with the schedule + total_tm = 0 + running_gflops = 0 + for i, si in enumerate(sched): + rawbufs = bufs_from_lin(Linearizer(si.ast)) - # "linearize" the op into uops in different ways - lins:List[Linearizer] = [] + # "linearize" the op into uops in different ways + lins: List[Linearizer] = [] - # always try hand coded opt - lin = Linearizer(si.ast, device.linearizer_opts) - lin.hand_coded_optimizations() - lins.append(lin) + # always try hand coded opt + lin = Linearizer(si.ast, device.linearizer_opts) + lin.hand_coded_optimizations() + lins.append(lin) - # maybe try tensor cores - lin = Linearizer(si.ast, device.linearizer_opts) - if lin.apply_tensor_cores(): - lins.append(lin) + # maybe try tensor cores + lin = Linearizer(si.ast, device.linearizer_opts) + if lin.apply_tensor_cores(): + lins.append(lin) - # try a beam search - if getenv("BEAM"): - lin = Linearizer(si.ast, device.linearizer_opts) - lin = beam_search(lin, rawbufs, getenv("BEAM"), bool(getenv("BEAM_ESTIMATE", 1))) - lins.append(lin) + # try a beam search + if getenv("BEAM"): + lin = Linearizer(si.ast, device.linearizer_opts) + lin = beam_search( + lin, rawbufs, getenv("BEAM"), bool(getenv("BEAM_ESTIMATE", 1)) + ) + lins.append(lin) - # benchmark the programs - choices = [] - for lin in lins: - tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10) - gflops = sym_infer(lin.info.flops, {k:k.min for k in vars_from_ast(lin.ast)})*1e-9/tm - choices.append((tm, gflops, lin.linearize())) + # benchmark the programs + choices = [] + for lin in lins: + tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10) + gflops = ( + sym_infer(lin.info.flops, {k: k.min for k in vars_from_ast(lin.ast)}) + * 1e-9 + / tm + ) + choices.append((tm, gflops, lin.linearize())) - # print all kernels - if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS") - tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0] - print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS") - total_tm += tm - running_gflops += gflops * tm - print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS") + # print all kernels + if DEBUG >= 1: + print( + f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS" + ) + tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0] + print( + f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS" + ) + total_tm += tm + running_gflops += gflops * tm + print( + f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS" + ) diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 7079de63c..0aa13f85d 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -2,10 +2,11 @@ # setup for distributed from extra import dist from tinygrad.helpers import getenv, dtypes + if __name__ == "__main__": - if getenv("DIST"): - dist.preinit() - from extra.dist import collectives + if getenv("DIST"): + dist.preinit() + from extra.dist import collectives # tinygrad implementation of https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py # https://myrtle.ai/learn/how-to-train-your-resnet-8-bag-of-tricks/ @@ -24,427 +25,594 @@ from tinygrad.shape.symbolic import Node from extra.lr_scheduler import OneCycleLR from tinygrad.jit import TinyJit -BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000) +BS, EVAL_BS, STEPS = getenv("BS", 512), getenv("EVAL_BS", 500), getenv("STEPS", 1000) if getenv("HALF", 0): - Tensor.default_type = dtypes.float16 - np_dtype: Type[Union[np.float16, np.float32]] = np.float16 + Tensor.default_type = dtypes.float16 + np_dtype: Type[Union[np.float16, np.float32]] = np.float16 else: - Tensor.default_type = dtypes.float32 - np_dtype = np.float32 + Tensor.default_type = dtypes.float32 + np_dtype = np.float32 + class BatchNorm(nn.BatchNorm2d): - def __init__(self, num_features): - super().__init__(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True) - self.weight.requires_grad = False - self.bias.requires_grad = True + def __init__(self, num_features): + super().__init__( + num_features, + track_running_stats=False, + eps=1e-12, + momentum=0.85, + affine=True, + ) + self.weight.requires_grad = False + self.bias.requires_grad = True + class ConvGroup: - def __init__(self, channels_in, channels_out): - self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False) - self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False) + def __init__(self, channels_in, channels_out): + self.conv1 = nn.Conv2d( + channels_in, channels_out, kernel_size=3, padding=1, bias=False + ) + self.conv2 = nn.Conv2d( + channels_out, channels_out, kernel_size=3, padding=1, bias=False + ) - self.norm1 = BatchNorm(channels_out) - self.norm2 = BatchNorm(channels_out) + self.norm1 = BatchNorm(channels_out) + self.norm2 = BatchNorm(channels_out) - def __call__(self, x): - x = self.conv1(x) - x = x.max_pool2d(2) - x = x.float() - x = self.norm1(x) - x = x.cast(Tensor.default_type) - x = x.gelu() - residual = x - x = self.conv2(x) - x = x.float() - x = self.norm2(x) - x = x.cast(Tensor.default_type) - x = x.gelu() + def __call__(self, x): + x = self.conv1(x) + x = x.max_pool2d(2) + x = x.float() + x = self.norm1(x) + x = x.cast(Tensor.default_type) + x = x.gelu() + residual = x + x = self.conv2(x) + x = x.float() + x = self.norm2(x) + x = x.cast(Tensor.default_type) + x = x.gelu() + + return x + residual - return x + residual class SpeedyResNet: - def __init__(self, W): - self.whitening = W - self.net = [ - nn.Conv2d(12, 32, kernel_size=1, bias=False), - lambda x: x.gelu(), - ConvGroup(32, 64), - ConvGroup(64, 256), - ConvGroup(256, 512), - lambda x: x.max((2,3)), - nn.Linear(512, 10, bias=False), - lambda x: x.mul(1./9) - ] + def __init__(self, W): + self.whitening = W + self.net = [ + nn.Conv2d(12, 32, kernel_size=1, bias=False), + lambda x: x.gelu(), + ConvGroup(32, 64), + ConvGroup(64, 256), + ConvGroup(256, 512), + lambda x: x.max((2, 3)), + nn.Linear(512, 10, bias=False), + lambda x: x.mul(1.0 / 9), + ] + + def __call__(self, x, training=True): + # pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with + # TODO: remove the pad but instead let the kernel optimizer itself + forward = ( + lambda x: x.conv2d(self.whitening).pad2d((1, 0, 0, 1)).sequential(self.net) + ) + return ( + forward(x) if training else forward(x) * 0.5 + forward(x[..., ::-1]) * 0.5 + ) - def __call__(self, x, training=True): - # pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with - # TODO: remove the pad but instead let the kernel optimizer itself - forward = lambda x: x.conv2d(self.whitening).pad2d((1,0,0,1)).sequential(self.net) - return forward(x) if training else forward(x)*0.5 + forward(x[..., ::-1])*0.5 def train_cifar(): - - # hyper-parameters were exactly the same as the original repo - bias_scaler = 58 - hyp: Dict[str, Any] = { - 'seed' : 209, - 'opt': { - 'bias_lr': 1.76 * bias_scaler/512, - 'non_bias_lr': 1.76 / 512, - 'bias_decay': 1.08 * 6.45e-4 * BS/bias_scaler, - 'non_bias_decay': 1.08 * 6.45e-4 * BS, - 'final_lr_ratio': 0.025, - 'initial_div_factor': 1e16, - 'label_smoothing': 0.20, - 'momentum': 0.85, - 'percent_start': 0.23, - 'loss_scale_scaler': 1./128 # (range: ~1/512 - 16+, 1/128 w/ FP16) - }, - 'net': { - 'kernel_size': 2, # kernel size for the whitening layer - 'cutmix_size': 3, - 'cutmix_steps': 499, - 'pad_amount': 2 - }, - 'ema': { - 'steps': 399, - 'decay_base': .95, - 'decay_pow': 1.6, - 'every_n_steps': 5, + # hyper-parameters were exactly the same as the original repo + bias_scaler = 58 + hyp: Dict[str, Any] = { + "seed": 209, + "opt": { + "bias_lr": 1.76 * bias_scaler / 512, + "non_bias_lr": 1.76 / 512, + "bias_decay": 1.08 * 6.45e-4 * BS / bias_scaler, + "non_bias_decay": 1.08 * 6.45e-4 * BS, + "final_lr_ratio": 0.025, + "initial_div_factor": 1e16, + "label_smoothing": 0.20, + "momentum": 0.85, + "percent_start": 0.23, + "loss_scale_scaler": 1.0 / 128, # (range: ~1/512 - 16+, 1/128 w/ FP16) + }, + "net": { + "kernel_size": 2, # kernel size for the whitening layer + "cutmix_size": 3, + "cutmix_steps": 499, + "pad_amount": 2, + }, + "ema": { + "steps": 399, + "decay_base": 0.95, + "decay_pow": 1.6, + "every_n_steps": 5, + }, } - } - def set_seed(seed): - Tensor.manual_seed(getenv('SEED', seed)) - random.seed(getenv('SEED', seed)) + def set_seed(seed): + Tensor.manual_seed(getenv("SEED", seed)) + random.seed(getenv("SEED", seed)) - # ========== Model ========== - # NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually - def whitening(X, kernel_size=hyp['net']['kernel_size']): - def _cov(X): - X = X/np.sqrt(X.shape[0] - 1) - return X.T @ X + # ========== Model ========== + # NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually + def whitening(X, kernel_size=hyp["net"]["kernel_size"]): + def _cov(X): + X = X / np.sqrt(X.shape[0] - 1) + return X.T @ X - def _patches(data, patch_size=(kernel_size,kernel_size)): - h, w = patch_size - c = data.shape[1] - axis: SupportsIndex = (2, 3) # type: ignore - return np.lib.stride_tricks.sliding_window_view(data, window_shape=(h,w), axis=axis).transpose((0,3,2,1,4,5)).reshape((-1,c,h,w)) + def _patches(data, patch_size=(kernel_size, kernel_size)): + h, w = patch_size + c = data.shape[1] + axis: SupportsIndex = (2, 3) # type: ignore + return ( + np.lib.stride_tricks.sliding_window_view( + data, window_shape=(h, w), axis=axis + ) + .transpose((0, 3, 2, 1, 4, 5)) + .reshape((-1, c, h, w)) + ) - def _eigens(patches): - n,c,h,w = patches.shape - Σ = _cov(patches.reshape(n, c*h*w)) - Λ, V = np.linalg.eigh(Σ, UPLO='U') - return np.flip(Λ, 0), np.flip(V.T.reshape(c*h*w, c, h, w), 0) + def _eigens(patches): + n, c, h, w = patches.shape + Σ = _cov(patches.reshape(n, c * h * w)) + Λ, V = np.linalg.eigh(Σ, UPLO="U") + return np.flip(Λ, 0), np.flip(V.T.reshape(c * h * w, c, h, w), 0) - Λ, V = _eigens(_patches(X.numpy())) - W = V/np.sqrt(Λ+1e-2)[:,None,None,None] + Λ, V = _eigens(_patches(X.numpy())) + W = V / np.sqrt(Λ + 1e-2)[:, None, None, None] - return Tensor(W.astype(np_dtype), requires_grad=False) + return Tensor(W.astype(np_dtype), requires_grad=False) - # ========== Loss ========== - def cross_entropy(x:Tensor, y:Tensor, reduction:str='mean', label_smoothing:float=0.0) -> Tensor: - divisor = y.shape[1] - assert not isinstance(divisor, Node), "sint not supported as divisor" - y = (1 - label_smoothing)*y + label_smoothing / divisor - if reduction=='none': return -x.log_softmax(axis=1).mul(y).sum(axis=1) - if reduction=='sum': return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum() - return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean() + # ========== Loss ========== + def cross_entropy( + x: Tensor, y: Tensor, reduction: str = "mean", label_smoothing: float = 0.0 + ) -> Tensor: + divisor = y.shape[1] + assert not isinstance(divisor, Node), "sint not supported as divisor" + y = (1 - label_smoothing) * y + label_smoothing / divisor + if reduction == "none": + return -x.log_softmax(axis=1).mul(y).sum(axis=1) + if reduction == "sum": + return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum() + return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean() - # ========== Preprocessing ========== - # TODO currently this only works for RGB in format of NxCxHxW and pads the HxW - # implemented in recursive fashion but figuring out how to switch indexing dim - # during the loop was a bit tricky - def pad_reflect(X, size=2) -> Tensor: - padding = ((0,0),(0,0),(size,size),(size,size)) - p = padding[3] - s = X.shape[3] + # ========== Preprocessing ========== + # TODO currently this only works for RGB in format of NxCxHxW and pads the HxW + # implemented in recursive fashion but figuring out how to switch indexing dim + # during the loop was a bit tricky + def pad_reflect(X, size=2) -> Tensor: + padding = ((0, 0), (0, 0), (size, size), (size, size)) + p = padding[3] + s = X.shape[3] - X_lr = X[...,:,1:1+p[0]].flip(3).pad(((0,0),(0,0),(0,0),(0,s+p[0]))) + X[...,:,-1-p[1]:-1].flip(3).pad(((0,0),(0,0),(0,0),(s+p[1],0))) - X = X.pad(((0,0),(0,0),(0,0),p)) + X_lr + X_lr = X[..., :, 1 : 1 + p[0]].flip(3).pad( + ((0, 0), (0, 0), (0, 0), (0, s + p[0])) + ) + X[..., :, -1 - p[1] : -1].flip(3).pad( + ((0, 0), (0, 0), (0, 0), (s + p[1], 0)) + ) + X = X.pad(((0, 0), (0, 0), (0, 0), p)) + X_lr - p = padding[2] - s = X.shape[2] - X_lr = X[...,1:1+p[0],:].flip(2).pad(((0,0),(0,0),(0,s+p[0]),(0,0))) + X[...,-1-p[1]:-1,:].flip(2).pad(((0,0),(0,0),(s+p[1],0),(0,0))) - X = X.pad(((0,0),(0,0),p,(0,0))) + X_lr + p = padding[2] + s = X.shape[2] + X_lr = X[..., 1 : 1 + p[0], :].flip(2).pad( + ((0, 0), (0, 0), (0, s + p[0]), (0, 0)) + ) + X[..., -1 - p[1] : -1, :].flip(2).pad( + ((0, 0), (0, 0), (s + p[1], 0), (0, 0)) + ) + X = X.pad(((0, 0), (0, 0), p, (0, 0))) + X_lr - return X + return X - # return a binary mask in the format of BS x C x H x W where H x W contains a random square mask - def make_square_mask(shape, mask_size) -> Tensor: - is_even = int(mask_size % 2 == 0) - center_max = shape[-2]-mask_size//2-is_even - center_min = mask_size//2-is_even - center_x = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor() - center_y = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor() - d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1])) - center_x.reshape((-1,1,1,1)) - d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1)) - center_y.reshape((-1,1,1,1)) - d_x =(d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2) - d_y =(d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2) - mask = d_y * d_x - return mask + # return a binary mask in the format of BS x C x H x W where H x W contains a random square mask + def make_square_mask(shape, mask_size) -> Tensor: + is_even = int(mask_size % 2 == 0) + center_max = shape[-2] - mask_size // 2 - is_even + center_min = mask_size // 2 - is_even + center_x = ( + Tensor.rand(shape[0]) * (center_max - center_min) + center_min + ).floor() + center_y = ( + Tensor.rand(shape[0]) * (center_max - center_min) + center_min + ).floor() + d_x = Tensor.arange(0, shape[-1]).reshape( + (1, 1, 1, shape[-1]) + ) - center_x.reshape((-1, 1, 1, 1)) + d_y = Tensor.arange(0, shape[-2]).reshape( + (1, 1, shape[-2], 1) + ) - center_y.reshape((-1, 1, 1, 1)) + d_x = (d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2) + d_y = (d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2) + mask = d_y * d_x + return mask - def random_crop(X:Tensor, crop_size=32): - mask = make_square_mask(X.shape, crop_size) - mask = mask.repeat((1,3,1,1)) - X_cropped = Tensor(X.flatten().numpy()[mask.flatten().numpy().astype(bool)]) - return X_cropped.reshape((-1, 3, crop_size, crop_size)) + def random_crop(X: Tensor, crop_size=32): + mask = make_square_mask(X.shape, crop_size) + mask = mask.repeat((1, 3, 1, 1)) + X_cropped = Tensor(X.flatten().numpy()[mask.flatten().numpy().astype(bool)]) + return X_cropped.reshape((-1, 3, crop_size, crop_size)) - def cutmix(X:Tensor, Y:Tensor, mask_size=3): - # fill the square with randomly selected images from the same batch - mask = make_square_mask(X.shape, mask_size) - order = list(range(0, X.shape[0])) - random.shuffle(order) - X_patch = Tensor(X.numpy()[order,...]) - Y_patch = Tensor(Y.numpy()[order]) - X_cutmix = Tensor.where(mask, X_patch, X) - mix_portion = float(mask_size**2)/(X.shape[-2]*X.shape[-1]) - Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y - return X_cutmix, Y_cutmix + def cutmix(X: Tensor, Y: Tensor, mask_size=3): + # fill the square with randomly selected images from the same batch + mask = make_square_mask(X.shape, mask_size) + order = list(range(0, X.shape[0])) + random.shuffle(order) + X_patch = Tensor(X.numpy()[order, ...]) + Y_patch = Tensor(Y.numpy()[order]) + X_cutmix = Tensor.where(mask, X_patch, X) + mix_portion = float(mask_size**2) / (X.shape[-2] * X.shape[-1]) + Y_cutmix = mix_portion * Y_patch + (1.0 - mix_portion) * Y + return X_cutmix, Y_cutmix - # the operations that remain inside batch fetcher is the ones that involves random operations - def fetch_batches(X_in:Tensor, Y_in:Tensor, BS:int, is_train:bool): - step, cnt = 0, 0 - while True: - st = time.monotonic() - X, Y = X_in, Y_in - order = list(range(0, X.shape[0])) - random.shuffle(order) - if is_train: - X = random_crop(X, crop_size=32) - X = Tensor.where(Tensor.rand(X.shape[0],1,1,1) < 0.5, X[..., ::-1], X) # flip LR - if step >= hyp['net']['cutmix_steps']: X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size']) - X, Y = X.numpy(), Y.numpy() - et = time.monotonic() - print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})") - for i in range(0, X.shape[0], BS): - # pad the last batch - batch_end = min(i+BS, Y.shape[0]) - x = Tensor(X[order[batch_end-BS:batch_end],:]) - y = Tensor(Y[order[batch_end-BS:batch_end]]) - step += 1 - yield x, y - cnt += 1 - if not is_train: break + # the operations that remain inside batch fetcher is the ones that involves random operations + def fetch_batches(X_in: Tensor, Y_in: Tensor, BS: int, is_train: bool): + step, cnt = 0, 0 + while True: + st = time.monotonic() + X, Y = X_in, Y_in + order = list(range(0, X.shape[0])) + random.shuffle(order) + if is_train: + X = random_crop(X, crop_size=32) + X = Tensor.where( + Tensor.rand(X.shape[0], 1, 1, 1) < 0.5, X[..., ::-1], X + ) # flip LR + if step >= hyp["net"]["cutmix_steps"]: + X, Y = cutmix(X, Y, mask_size=hyp["net"]["cutmix_size"]) + X, Y = X.numpy(), Y.numpy() + et = time.monotonic() + print( + f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})" + ) + for i in range(0, X.shape[0], BS): + # pad the last batch + batch_end = min(i + BS, Y.shape[0]) + x = Tensor(X[order[batch_end - BS : batch_end], :]) + y = Tensor(Y[order[batch_end - BS : batch_end]]) + step += 1 + yield x, y + cnt += 1 + if not is_train: + break - transform = [ - lambda x: x / 255.0, - lambda x: (x.reshape((-1,3,32,32)) - Tensor(cifar_mean).reshape((1,3,1,1)))/Tensor(cifar_std).reshape((1,3,1,1)) - ] + transform = [ + lambda x: x / 255.0, + lambda x: ( + x.reshape((-1, 3, 32, 32)) - Tensor(cifar_mean).reshape((1, 3, 1, 1)) + ) + / Tensor(cifar_std).reshape((1, 3, 1, 1)), + ] - class modelEMA(): - def __init__(self, w, net): - # self.model_ema = copy.deepcopy(net) # won't work for opencl due to unpickeable pyopencl._cl.Buffer - self.net_ema = SpeedyResNet(w) - for net_ema_param, net_param in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).values()): - net_ema_param.requires_grad = False - net_ema_param.assign(net_param.numpy()) + class modelEMA: + def __init__(self, w, net): + # self.model_ema = copy.deepcopy(net) # won't work for opencl due to unpickeable pyopencl._cl.Buffer + self.net_ema = SpeedyResNet(w) + for net_ema_param, net_param in zip( + get_state_dict(self.net_ema).values(), get_state_dict(net).values() + ): + net_ema_param.requires_grad = False + net_ema_param.assign(net_param.numpy()) + + @TinyJit + def update(self, net, decay): + # TODO with Tensor.no_grad() + Tensor.no_grad = True + for net_ema_param, (param_name, net_param) in zip( + get_state_dict(self.net_ema).values(), get_state_dict(net).items() + ): + # batchnorm currently is not being tracked + if not ("num_batches_tracked" in param_name) and not ( + "running" in param_name + ): + net_ema_param.assign( + net_ema_param.detach() * decay + + net_param.detach() * (1.0 - decay) + ).realize() + Tensor.no_grad = False + + set_seed(hyp["seed"]) + + # this import needs to be done here because this is running in a subprocess + from extra.dist import OOB + + assert OOB is not None or not getenv("DIST"), "OOB should be initialized" + rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1) + + X_train, Y_train, X_test, Y_test = fetch_cifar() + # load data and label into GPU and convert to dtype accordingly + X_train, X_test = ( + X_train.to(device=Device.DEFAULT).float(), + X_test.to(device=Device.DEFAULT).float(), + ) + Y_train, Y_test = ( + Y_train.to(device=Device.DEFAULT).float(), + Y_test.to(device=Device.DEFAULT).float(), + ) + # one-hot encode labels + Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test] + # preprocess data + X_train, X_test = X_train.sequential(transform), X_test.sequential(transform) + + # precompute whitening patches + W = whitening(X_train) + + # initialize model weights + model = SpeedyResNet(W) + + # padding is not timed in the original repo since it can be done all at once + X_train = pad_reflect(X_train, size=hyp["net"]["pad_amount"]) + + # Convert data and labels to the default dtype + X_train, Y_train, X_test, Y_test = ( + X_train.cast(Tensor.default_type), + Y_train.cast(Tensor.default_type), + X_test.cast(Tensor.default_type), + Y_test.cast(Tensor.default_type), + ) + + # parse the training params into bias and non-bias + params_dict = get_state_dict(model) + params_bias = [] + params_non_bias = [] + for params in params_dict: + if params_dict[params].requires_grad is not False: + if "bias" in params: + params_bias.append(params_dict[params]) + else: + params_non_bias.append(params_dict[params]) + + opt_bias = optim.SGD( + params_bias, + lr=0.01, + momentum=hyp["opt"]["momentum"], + nesterov=True, + weight_decay=hyp["opt"]["bias_decay"], + ) + opt_non_bias = optim.SGD( + params_non_bias, + lr=0.01, + momentum=hyp["opt"]["momentum"], + nesterov=True, + weight_decay=hyp["opt"]["non_bias_decay"], + ) + + # NOTE taken from the hlb_CIFAR repository, might need to be tuned + initial_div_factor = hyp["opt"]["initial_div_factor"] + final_lr_ratio = hyp["opt"]["final_lr_ratio"] + pct_start = hyp["opt"]["percent_start"] + lr_sched_bias = OneCycleLR( + opt_bias, + max_lr=hyp["opt"]["bias_lr"], + pct_start=pct_start, + div_factor=initial_div_factor, + final_div_factor=1.0 / (initial_div_factor * final_lr_ratio), + total_steps=STEPS, + ) + lr_sched_non_bias = OneCycleLR( + opt_non_bias, + max_lr=hyp["opt"]["non_bias_lr"], + pct_start=pct_start, + div_factor=initial_div_factor, + final_div_factor=1.0 / (initial_div_factor * final_lr_ratio), + total_steps=STEPS, + ) + + loss_batchsize_scaler = 512 / BS @TinyJit - def update(self, net, decay): - # TODO with Tensor.no_grad() - Tensor.no_grad = True - for net_ema_param, (param_name, net_param) in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).items()): - # batchnorm currently is not being tracked - if not ("num_batches_tracked" in param_name) and not ("running" in param_name): - net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize() - Tensor.no_grad = False + def train_step_jitted(model, optimizer, lr_scheduler, X, Y): + out = model(X) + loss = ( + cross_entropy( + out, Y, reduction="none", label_smoothing=hyp["opt"]["label_smoothing"] + ) + .mul(hyp["opt"]["loss_scale_scaler"] * loss_batchsize_scaler) + .sum() + .div(hyp["opt"]["loss_scale_scaler"]) + ) - set_seed(hyp['seed']) + if not getenv("DISABLE_BACKWARD"): + # index 0 for bias and 1 for non-bias + optimizer[0].zero_grad() + optimizer[1].zero_grad() + loss.backward() - # this import needs to be done here because this is running in a subprocess - from extra.dist import OOB - assert OOB is not None or not getenv("DIST"), "OOB should be initialized" - rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1) + if getenv("DIST"): + # sync gradients across ranks + bucket, offset = [], 0 + for _, v in params_dict.items(): + if v.grad is not None: + bucket.append(v.grad.flatten()) + grads = collectives.allreduce(Tensor.cat(*bucket)) + for _, v in params_dict.items(): + if v.grad is not None: + v.grad.assign( + grads[offset : offset + v.grad.numel()].reshape( + *v.grad.shape + ) + ) + offset += v.grad.numel() - X_train, Y_train, X_test, Y_test = fetch_cifar() - # load data and label into GPU and convert to dtype accordingly - X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float() - Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float() - # one-hot encode labels - Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test] - # preprocess data - X_train, X_test = X_train.sequential(transform), X_test.sequential(transform) + optimizer[0].step() + optimizer[1].step() + lr_scheduler[0].step() + lr_scheduler[1].step() + return loss.realize() - # precompute whitening patches - W = whitening(X_train) + def eval_step(model, X, Y): + out = model(X, training=False) + loss = cross_entropy(out, Y, reduction="mean") + correct = out.argmax(axis=1) == Y.argmax(axis=1) + return correct.realize(), loss.realize() - # initialize model weights - model = SpeedyResNet(W) + eval_step_jitted = TinyJit(eval_step) + eval_step_ema_jitted = TinyJit(eval_step) - # padding is not timed in the original repo since it can be done all at once - X_train = pad_reflect(X_train, size=hyp['net']['pad_amount']) + # 97 steps in 2 seconds = 20ms / step + # step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136 + # 4 seconds for tfloat32 ~ 28 TFLOPS, 41% of max 68 + # 6.4 seconds for float32 ~ 17 TFLOPS, 50% of max 34.1 + # 4.7 seconds for float32 w/o channels last. 24 TFLOPS. we get 50ms then i'll be happy. only 64x off - # Convert data and labels to the default dtype - X_train, Y_train, X_test, Y_test = X_train.cast(Tensor.default_type), Y_train.cast(Tensor.default_type), X_test.cast(Tensor.default_type), Y_test.cast(Tensor.default_type) + # https://www.anandtech.com/show/16727/nvidia-announces-geforce-rtx-3080-ti-3070-ti-upgraded-cards-coming-in-june + # 136 TFLOPS is the theoretical max w float16 on 3080 Ti - # parse the training params into bias and non-bias - params_dict = get_state_dict(model) - params_bias = [] - params_non_bias = [] - for params in params_dict: - if params_dict[params].requires_grad is not False: - if 'bias' in params: - params_bias.append(params_dict[params]) - else: - params_non_bias.append(params_dict[params]) + model_ema: Optional[modelEMA] = None + projected_ema_decay_val = hyp["ema"]["decay_base"] ** hyp["ema"]["every_n_steps"] + i = 0 + batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True) + with Tensor.train(): + st = time.monotonic() + while i <= STEPS: + if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1: + st_eval = time.monotonic() + # Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True + corrects = [] + corrects_ema = [] + losses = [] + losses_ema = [] + for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False): + # further split batch if distributed + if getenv("DIST"): + Xt, Yt = ( + Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], + Yt.chunk(min(world_size, 5), 0)[min(rank, 4)], + ) - opt_bias = optim.SGD(params_bias, lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay']) - opt_non_bias = optim.SGD(params_non_bias, lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['non_bias_decay']) + correct, loss = eval_step_jitted(model, Xt, Yt) + losses.append(loss.numpy().tolist()) + corrects.extend(correct.numpy().tolist()) + if model_ema: + correct_ema, loss_ema = eval_step_ema_jitted( + model_ema.net_ema, Xt, Yt + ) + losses_ema.append(loss_ema.numpy().tolist()) + corrects_ema.extend(correct_ema.numpy().tolist()) - # NOTE taken from the hlb_CIFAR repository, might need to be tuned - initial_div_factor = hyp['opt']['initial_div_factor'] - final_lr_ratio = hyp['opt']['final_lr_ratio'] - pct_start = hyp['opt']['percent_start'] - lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS) - lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS) + # collect accuracy across ranks + correct_sum, correct_len = sum(corrects), len(corrects) + if model_ema: + correct_sum_ema, correct_len_ema = sum(corrects_ema), len( + corrects_ema + ) + if getenv("DIST"): + if rank == 0: + for j in range(1, min(world_size, 5)): + if model_ema: + ( + recv_sum, + recv_len, + recv_sum_ema, + recv_len_ema, + ) = OOB.recv(j) + else: + recv_sum, recv_len = OOB.recv(j) + correct_sum += recv_sum + correct_len += recv_len + if model_ema: + correct_sum_ema += recv_sum_ema + correct_len_ema += recv_len_ema + elif rank < min(world_size, 5): + if model_ema: + OOB.send( + ( + correct_sum, + correct_len, + correct_sum_ema, + correct_len_ema, + ), + 0, + ) + else: + OOB.send((correct_sum, correct_len), 0) - loss_batchsize_scaler = 512/BS - @TinyJit - def train_step_jitted(model, optimizer, lr_scheduler, X, Y): - out = model(X) - loss = cross_entropy(out, Y, reduction='none' ,label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler']) + # only rank 0 prints + if rank == 0: + acc = correct_sum / correct_len * 100.0 + if model_ema: + acc_ema = correct_sum_ema / correct_len_ema * 100.0 + print( + f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)" + ) + if model_ema: + print( + f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}" + ) - if not getenv("DISABLE_BACKWARD"): - # index 0 for bias and 1 for non-bias - optimizer[0].zero_grad() - optimizer[1].zero_grad() - loss.backward() - - if getenv("DIST"): - # sync gradients across ranks - bucket, offset = [], 0 - for _, v in params_dict.items(): - if v.grad is not None: bucket.append(v.grad.flatten()) - grads = collectives.allreduce(Tensor.cat(*bucket)) - for _, v in params_dict.items(): - if v.grad is not None: - v.grad.assign(grads[offset:offset+v.grad.numel()].reshape(*v.grad.shape)) - offset += v.grad.numel() - - optimizer[0].step() - optimizer[1].step() - lr_scheduler[0].step() - lr_scheduler[1].step() - return loss.realize() - - def eval_step(model, X, Y): - out = model(X, training=False) - loss = cross_entropy(out, Y, reduction='mean') - correct = out.argmax(axis=1) == Y.argmax(axis=1) - return correct.realize(), loss.realize() - eval_step_jitted = TinyJit(eval_step) - eval_step_ema_jitted = TinyJit(eval_step) - - # 97 steps in 2 seconds = 20ms / step - # step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136 - # 4 seconds for tfloat32 ~ 28 TFLOPS, 41% of max 68 - # 6.4 seconds for float32 ~ 17 TFLOPS, 50% of max 34.1 - # 4.7 seconds for float32 w/o channels last. 24 TFLOPS. we get 50ms then i'll be happy. only 64x off - - # https://www.anandtech.com/show/16727/nvidia-announces-geforce-rtx-3080-ti-3070-ti-upgraded-cards-coming-in-june - # 136 TFLOPS is the theoretical max w float16 on 3080 Ti - - model_ema: Optional[modelEMA] = None - projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps'] - i = 0 - batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True) - with Tensor.train(): - st = time.monotonic() - while i <= STEPS: - if i%getenv("EVAL_STEPS", STEPS) == 0 and i > 1: - st_eval = time.monotonic() - # Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True - corrects = [] - corrects_ema = [] - losses = [] - losses_ema = [] - for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False): - # further split batch if distributed - if getenv("DIST"): - Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)] - - correct, loss = eval_step_jitted(model, Xt, Yt) - losses.append(loss.numpy().tolist()) - corrects.extend(correct.numpy().tolist()) - if model_ema: - correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt) - losses_ema.append(loss_ema.numpy().tolist()) - corrects_ema.extend(correct_ema.numpy().tolist()) - - # collect accuracy across ranks - correct_sum, correct_len = sum(corrects), len(corrects) - if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema) - if getenv("DIST"): - if rank == 0: - for j in range(1, min(world_size, 5)): - if model_ema: - recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j) - else: - recv_sum, recv_len = OOB.recv(j) - correct_sum += recv_sum - correct_len += recv_len - if model_ema: - correct_sum_ema += recv_sum_ema - correct_len_ema += recv_len_ema - elif rank < min(world_size, 5): - if model_ema: - OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0) + if STEPS == 0 or i == STEPS: + break + X, Y = next(batcher) + if getenv("DIST"): + X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank] + GlobalCounters.reset() + loss = train_step_jitted( + model, + [opt_bias, opt_non_bias], + [lr_sched_bias, lr_sched_non_bias], + X, + Y, + ) + et = time.monotonic() + loss_cpu = loss.numpy() + # EMA for network weights + if i > hyp["ema"]["steps"] and (i + 1) % hyp["ema"]["every_n_steps"] == 0: + if model_ema is None: + model_ema = modelEMA(W, model) + model_ema.update( + model, + Tensor( + [ + projected_ema_decay_val + * (i / STEPS) ** hyp["ema"]["decay_pow"] + ] + ), + ) + cl = time.monotonic() + if not getenv("DIST"): + print( + f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS" + ) else: - OOB.send((correct_sum, correct_len), 0) + print( + f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS" + ) + st = cl + i += 1 - # only rank 0 prints - if rank == 0: - acc = correct_sum/correct_len*100.0 - if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0 - print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)") - if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}") - - if STEPS == 0 or i==STEPS: break - X, Y = next(batcher) - if getenv("DIST"): - X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank] - GlobalCounters.reset() - loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y) - et = time.monotonic() - loss_cpu = loss.numpy() - # EMA for network weights - if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0: - if model_ema is None: - model_ema = modelEMA(W, model) - model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']])) - cl = time.monotonic() - if not getenv("DIST"): - print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS") - else: - print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS") - st = cl - i += 1 if __name__ == "__main__": - if not getenv("DIST"): - train_cifar() - else: # distributed - if getenv("HIP"): - from tinygrad.runtime.ops_hip import HIP - devices = [f"hip:{i}" for i in range(HIP.device_count)] - else: - from tinygrad.runtime.ops_gpu import CLDevice - devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))] - world_size = len(devices) + if not getenv("DIST"): + train_cifar() + else: # distributed + if getenv("HIP"): + from tinygrad.runtime.ops_hip import HIP - # ensure that the batch size is divisible by the number of devices - assert BS % world_size == 0, f"batch size {BS} is not divisible by world size {world_size}" + devices = [f"hip:{i}" for i in range(HIP.device_count)] + else: + from tinygrad.runtime.ops_gpu import CLDevice - # ensure that the evaluation batch size is divisible by the number of devices - assert EVAL_BS % min(world_size, 5) == 0, f"evaluation batch size {EVAL_BS} is not divisible by world size {min(world_size, 5)}" + devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))] + world_size = len(devices) - # init out-of-band communication - dist.init_oob(world_size) + # ensure that the batch size is divisible by the number of devices + assert ( + BS % world_size == 0 + ), f"batch size {BS} is not divisible by world size {world_size}" - # start the processes - processes = [] - for rank, device in enumerate(devices): - processes.append(dist.spawn(rank, device, fn=train_cifar, args=())) - for p in processes: p.join() + # ensure that the evaluation batch size is divisible by the number of devices + assert ( + EVAL_BS % min(world_size, 5) == 0 + ), f"evaluation batch size {EVAL_BS} is not divisible by world size {min(world_size, 5)}" + + # init out-of-band communication + dist.init_oob(world_size) + + # start the processes + processes = [] + for rank, device in enumerate(devices): + processes.append(dist.spawn(rank, device, fn=train_cifar, args=())) + for p in processes: + p.join() diff --git a/examples/llama.py b/examples/llama.py index 1b17fee5f..128a29117 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 # pip3 install sentencepiece -#import typeguard.importhook -#typeguard.importhook.install_import_hook('tinygrad') +# import typeguard.importhook +# typeguard.importhook.install_import_hook('tinygrad') from pathlib import Path import sys, argparse, json import numpy as np + np.set_printoptions(linewidth=200) from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes from tinygrad import Device @@ -22,174 +23,365 @@ MAX_CONTEXT = getenv("MAX_CONTEXT", 4096) # however, Llama uses SwiGLU. in order to preserve param count to original transformer arch, hidden_dim must be = 2/3 * (dim*4) [arxiv/2002.05202] # for models using MQA (n_kv_heads != n_heads), preserving param count means hidden dim must be further multiplied by 1.3 [arxiv/2307.09288, A.2.1] MODEL_PARAMS = { - "1": { - "7B": { - "args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 11008}, - "files": 1, + "1": { + "7B": { + "args": { + "dim": 4096, + "n_heads": 32, + "n_layers": 32, + "norm_eps": 1e-06, + "vocab_size": 32000, + "hidden_dim": 11008, + }, + "files": 1, + }, + "13B": { + "args": { + "dim": 5120, + "n_heads": 40, + "n_layers": 40, + "norm_eps": 1e-06, + "vocab_size": 32000, + "hidden_dim": 13824, + }, + "files": 2, + }, + "30B": { + "args": { + "dim": 6656, + "n_heads": 52, + "n_layers": 60, + "norm_eps": 1e-06, + "vocab_size": 32000, + "hidden_dim": 17920, + }, + "files": 4, + }, + "65B": { + "args": { + "dim": 8192, + "n_heads": 64, + "n_layers": 80, + "norm_eps": 1e-05, + "vocab_size": 32000, + "hidden_dim": 22016, + }, + "files": 8, + }, }, - "13B": { - "args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 13824}, - "files": 2, + "2": { + "7B": { + "args": { + "dim": 4096, + "n_heads": 32, + "n_layers": 32, + "norm_eps": 1e-05, + "vocab_size": 32000, + "hidden_dim": 11008, + }, + "files": 1, + }, + "13B": { + "args": { + "dim": 5120, + "n_heads": 40, + "n_layers": 40, + "norm_eps": 1e-05, + "vocab_size": 32000, + "hidden_dim": 13824, + }, + "files": 2, + }, + "70B": { + "args": { + "dim": 8192, + "n_heads": 64, + "n_kv_heads": 8, + "n_layers": 80, + "norm_eps": 1e-05, + "vocab_size": 32000, + "hidden_dim": 28672, + }, + "files": 8, + }, }, - "30B": { - "args": {"dim": 6656, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 17920}, - "files": 4, + "code": { + "7B": { + "args": { + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "norm_eps": 1e-05, + "rope_theta": 1000000, + "vocab_size": 32016, + "hidden_dim": 11008, + }, + "files": 1, + }, + "7B-Python": { + "args": { + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "norm_eps": 1e-05, + "rope_theta": 1000000, + "vocab_size": 32000, + "hidden_dim": 11008, + }, + "files": 1, + }, + "7B-Instruct": { + "args": { + "dim": 4096, + "n_layers": 32, + "n_heads": 32, + "norm_eps": 1e-05, + "rope_theta": 1000000, + "vocab_size": 32016, + "hidden_dim": 11008, + }, + "files": 1, + }, + "13B": { + "args": { + "dim": 5120, + "n_layers": 40, + "n_heads": 40, + "norm_eps": 1e-05, + "rope_theta": 1000000, + "vocab_size": 32016, + "hidden_dim": 13824, + }, + "files": 2, + }, + "13B-Python": { + "args": { + "dim": 5120, + "n_layers": 40, + "n_heads": 40, + "norm_eps": 1e-05, + "rope_theta": 1000000, + "vocab_size": 32000, + "hidden_dim": 13824, + }, + "files": 2, + }, + "13B-Instruct": { + "args": { + "dim": 5120, + "n_layers": 40, + "n_heads": 40, + "norm_eps": 1e-05, + "rope_theta": 1000000, + "vocab_size": 32016, + "hidden_dim": 13824, + }, + "files": 2, + }, + "34B": { + "args": { + "dim": 8192, + "n_layers": 48, + "n_heads": 64, + "n_kv_heads": 8, + "norm_eps": 1e-05, + "rope_theta": 1000000, + "vocab_size": 32000, + "hidden_dim": 22016, + }, + "files": 4, + }, + "34B-Python": { + "args": { + "dim": 8192, + "n_layers": 48, + "n_heads": 64, + "n_kv_heads": 8, + "norm_eps": 1e-05, + "rope_theta": 1000000, + "vocab_size": 32000, + "hidden_dim": 22016, + }, + "files": 4, + }, + "34B-Instruct": { + "args": { + "dim": 8192, + "n_layers": 48, + "n_heads": 64, + "n_kv_heads": 8, + "norm_eps": 1e-05, + "rope_theta": 1000000, + "vocab_size": 32000, + "hidden_dim": 22016, + }, + "files": 4, + }, }, - "65B": { - "args": {"dim": 8192, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 22016}, - "files": 8, + "tiny": { + "1B": { + "args": { + "dim": 2048, + "n_layers": 22, + "n_heads": 32, + "n_kv_heads": 4, + "norm_eps": 1e-05, + "vocab_size": 32000, + "hidden_dim": 5632, + }, + "files": 1, + }, + "1B-Chat": { + "args": { + "dim": 2048, + "n_layers": 22, + "n_heads": 32, + "n_kv_heads": 4, + "norm_eps": 1e-05, + "vocab_size": 32003, + "hidden_dim": 5632, + }, + "files": 1, + }, }, - }, - "2": { - "7B": { - "args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 11008}, - "files": 1, - }, - "13B": { - "args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 13824}, - "files": 2, - }, - "70B": { - "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 28672}, - "files": 8, - }, - }, - "code": { - "7B": { - "args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008}, - "files": 1, - }, - "7B-Python": { - "args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 11008}, - "files": 1, - }, - "7B-Instruct": { - "args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008}, - "files": 1, - }, - "13B": { - "args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824}, - "files": 2, - }, - "13B-Python": { - "args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 13824}, - "files": 2, - }, - "13B-Instruct": { - "args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824}, - "files": 2, - }, - "34B": { - "args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016}, - "files": 4, - }, - "34B-Python": { - "args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016}, - "files": 4, - }, - "34B-Instruct": { - "args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016}, - "files": 4, - }, - }, - "tiny": { - "1B": { - "args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 5632}, - "files": 1, - }, - "1B-Chat": { - "args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32003, "hidden_dim": 5632}, - "files": 1, - } - } } # **** helper functions **** def concat_weights(models): - def convert(name) -> Tensor: - disk_tensors = [model[name] for model in models] - if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1: - return disk_tensors[0].to(device=Device.DEFAULT) - axis = 1 if name.startswith("tok_embeddings.") or name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0 - lazy_tensors = [data.to(device=Device.DEFAULT) for data in disk_tensors] - return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis) - return {name: convert(name) for name in {name: None for model in models for name in model}} + def convert(name) -> Tensor: + disk_tensors = [model[name] for model in models] + if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1: + return disk_tensors[0].to(device=Device.DEFAULT) + axis = ( + 1 + if name.startswith("tok_embeddings.") + or name.endswith(".attention.wo.weight") + or name.endswith(".feed_forward.w2.weight") + else 0 + ) + lazy_tensors = [data.to(device=Device.DEFAULT) for data in disk_tensors] + return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis) + + return { + name: convert(name) + for name in {name: None for model in models for name in model} + } + + +def load(fn: str): + if fn.endswith(".index.json"): + with open(fn) as fp: + weight_map = json.load(fp)["weight_map"] + parts = { + n: load(str(Path(fn).parent / Path(n).name)) + for n in set(weight_map.values()) + } + return {k: parts[n][k] for k, n in weight_map.items()} + elif fn.endswith(".safetensors"): + return safe_load(fn) + else: + return torch_load(fn) -def load(fn:str): - if fn.endswith('.index.json'): - with open(fn) as fp: weight_map = json.load(fp)['weight_map'] - parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())} - return {k: parts[n][k] for k, n in weight_map.items()} - elif fn.endswith(".safetensors"): - return safe_load(fn) - else: - return torch_load(fn) class AbsmaxQuantizedLinear: - def __init__(self, in_features, out_features, bias=False): - assert bias == False - self.weight = Tensor.ones(out_features, in_features, dtype=dtypes.int8) - self.scale = Tensor.ones(out_features, dtype=dtypes.half) + def __init__(self, in_features, out_features, bias=False): + assert bias == False + self.weight = Tensor.ones(out_features, in_features, dtype=dtypes.int8) + self.scale = Tensor.ones(out_features, dtype=dtypes.half) - def __call__(self, x): - return x.dot(self.weight.cast(dtype=dtypes.half).T*self.scale) + def __call__(self, x): + return x.dot(self.weight.cast(dtype=dtypes.half).T * self.scale) + + @staticmethod + def quantize(tensors): + new_tensors = {} + for name, v in tensors.items(): + if ( + "feed_forward" in name + or ("attention.w") in name + or name == "output.weight" + ): + scale = v.abs().max(axis=1) / 127.0 + int8_weight = (v.T / scale).T.cast(dtype=dtypes.int8) + new_tensors[name] = int8_weight + new_tensors[name.replace("weight", "scale")] = scale + else: + new_tensors[name] = v + return new_tensors - @staticmethod - def quantize(tensors): - new_tensors = {} - for name,v in tensors.items(): - if "feed_forward" in name or ("attention.w") in name or name == "output.weight": - scale = v.abs().max(axis=1) / 127.0 - int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8) - new_tensors[name] = int8_weight - new_tensors[name.replace('weight', 'scale')] = scale - else: - new_tensors[name] = v - return new_tensors class LLaMa: - @staticmethod - def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False): - params = MODEL_PARAMS[model_gen][model_size] - sp_model = SentencePieceProcessor(model_file=str(tokenizer_path)) - assert sp_model.vocab_size() == params["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}" + @staticmethod + def build( + model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False + ): + params = MODEL_PARAMS[model_gen][model_size] + sp_model = SentencePieceProcessor(model_file=str(tokenizer_path)) + assert ( + sp_model.vocab_size() == params["args"]["vocab_size"] + ), f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}" - model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT) + model = ( + Transformer( + **params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT + ) + if quantize + else Transformer(**params["args"], max_context=MAX_CONTEXT) + ) - if model_path.is_dir(): - weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]]) - else: - weights = load(str(model_path)) - if "model.embed_tokens.weight" in weights: - weights = convert_from_huggingface(weights, model, params["args"]["n_heads"], params["args"].get("n_kv_heads", params["args"]["n_heads"])) + if model_path.is_dir(): + weights = concat_weights( + [ + load(filename) + for filename in [ + f"{model_path}/consolidated.{i:02d}.pth" + for i in range(params["files"]) + ] + ] + ) + else: + weights = load(str(model_path)) + if "model.embed_tokens.weight" in weights: + weights = convert_from_huggingface( + weights, + model, + params["args"]["n_heads"], + params["args"].get("n_kv_heads", params["args"]["n_heads"]), + ) - if quantize: - weights = AbsmaxQuantizedLinear.quantize(weights) - for _,v in weights.items(): v.realize() - load_state_dict(model, weights, strict=False) + if quantize: + weights = AbsmaxQuantizedLinear.quantize(weights) + for _, v in weights.items(): + v.realize() + load_state_dict(model, weights, strict=False) - return LLaMa(model, sp_model) + return LLaMa(model, sp_model) - def __init__(self, model, tokenizer): - self.model = model - self.tokenizer: SentencePieceProcessor = tokenizer + def __init__(self, model, tokenizer): + self.model = model + self.tokenizer: SentencePieceProcessor = tokenizer - def greedy_until(self, prompt:str, until, max_length, temperature): - toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt) - start_pos = 0 - for i in range(max_length): - probs = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).realize() - probs_np = probs.numpy() - tok = int(np.random.choice(len(probs_np), p=probs_np)) - start_pos = len(toks) - toks.append(tok) + def greedy_until(self, prompt: str, until, max_length, temperature): + toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt) + start_pos = 0 + for i in range(max_length): + probs = llama.model( + Tensor([toks[start_pos:]]), start_pos, temperature + ).realize() + probs_np = probs.numpy() + tok = int(np.random.choice(len(probs_np), p=probs_np)) + start_pos = len(toks) + toks.append(tok) + + if tok == self.tokenizer.eos_id(): + break + output = self.tokenizer.decode(toks) + for s in until: + if output.endswith(s): + return output[0 : -len(s)] + return output - if tok == self.tokenizer.eos_id(): break - output = self.tokenizer.decode(toks) - for s in until: - if output.endswith(s): return output[0:-len(s)] - return output # **** main code **** """ @@ -253,30 +445,67 @@ int main() \end{code} """ if __name__ == "__main__": - Tensor.no_grad = True - print(f"using {Device.DEFAULT} backend") + Tensor.no_grad = True + print(f"using {Device.DEFAULT} backend") - parser = argparse.ArgumentParser(description="Run LLaMA in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("--prompt", type=str, default=None, help="Phrase to start with. Without this, it goes into chatbot mode") - parser.add_argument("--count", type=int, default=1000, help="Max number of tokens to generate") - parser.add_argument("--personality", type=str, default="Stacy", help="Personality, can be Stacy, George, Gary, or Lexie") - parser.add_argument("--temperature", type=float, default=0.7, help="Temperature in the softmax") - parser.add_argument("--timing", action="store_true", help="Print timing per token") - parser.add_argument("--profile", action="store_true", help="Output profile data to out.prof") - parser.add_argument("--gen", default="1", help=f"""Generation of the model to use {list(MODEL_PARAMS.keys())}""") - parser.add_argument("--size", type=str, default=None, help=f"""Size of model to use {", ".join([f"{list(v.keys())} for gen '{k}'" for k, v in MODEL_PARAMS.items()])}""") - parser.add_argument("--quantize", action="store_true", help="Quantize the weights to int8 in memory") - parser.add_argument("--model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file") + parser = argparse.ArgumentParser( + description="Run LLaMA in tinygrad", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--prompt", + type=str, + default=None, + help="Phrase to start with. Without this, it goes into chatbot mode", + ) + parser.add_argument( + "--count", type=int, default=1000, help="Max number of tokens to generate" + ) + parser.add_argument( + "--personality", + type=str, + default="Stacy", + help="Personality, can be Stacy, George, Gary, or Lexie", + ) + parser.add_argument( + "--temperature", type=float, default=0.7, help="Temperature in the softmax" + ) + parser.add_argument("--timing", action="store_true", help="Print timing per token") + parser.add_argument( + "--profile", action="store_true", help="Output profile data to out.prof" + ) + parser.add_argument( + "--gen", + default="1", + help=f"""Generation of the model to use {list(MODEL_PARAMS.keys())}""", + ) + parser.add_argument( + "--size", + type=str, + default=None, + help=f"""Size of model to use {", ".join([f"{list(v.keys())} for gen '{k}'" for k, v in MODEL_PARAMS.items()])}""", + ) + parser.add_argument( + "--quantize", action="store_true", help="Quantize the weights to int8 in memory" + ) + parser.add_argument( + "--model", + type=Path, + default=None, + help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file", + ) - args = parser.parse_args() - if args.gen not in MODEL_PARAMS: raise ValueError("Invalid model generation") - if args.size is None: args.size = list(MODEL_PARAMS[args.gen].items())[0][0] - chatbot = args.prompt == None + args = parser.parse_args() + if args.gen not in MODEL_PARAMS: + raise ValueError("Invalid model generation") + if args.size is None: + args.size = list(MODEL_PARAMS[args.gen].items())[0][0] + chatbot = args.prompt == None - # *** prompt engineers work here **** + # *** prompt engineers work here **** - if args.personality.lower() == "stacy": - pre_prompt = f"""Consider that the following is conversation between an AI assistant named Stacy and User + if args.personality.lower() == "stacy": + pre_prompt = f"""Consider that the following is conversation between an AI assistant named Stacy and User You are Stacy! You have been a rapper your whole life who struggled with bipolar disorder. You called yourself lil stacy. You love to answer questions and you are very good at it. Sometimes you answer in rap form. @@ -285,19 +514,23 @@ After you are done speaking, output [EOS]. You are not the User. """ - examples = { - "What is your name?": "Hi! My name is Stacy. I'm a rapper with bipolar disorder.", - "french revolution was what year?": "The French Revolution started in 1789, and lasted 10 years until 1799.", - "What is bigger, the moon or the sun?": "The sun is bigger than the moon, except when Mercury is in retrograde.", - } + examples = { + "What is your name?": "Hi! My name is Stacy. I'm a rapper with bipolar disorder.", + "french revolution was what year?": "The French Revolution started in 1789, and lasted 10 years until 1799.", + "What is bigger, the moon or the sun?": "The sun is bigger than the moon, except when Mercury is in retrograde.", + } - user_delim = "\nUser: " - resp_delim = "Stacy: " - end_delim = " [EOS]\n" - pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items()) - elif args.personality.lower() == "george": - print("WARNING: AI George Hotz is terrible and is completely disowned by the real George Hotz. Stacy is much smarter.") - pre_prompt = f"""Consider that the following is conversation between an AI assistant named George and User + user_delim = "\nUser: " + resp_delim = "Stacy: " + end_delim = " [EOS]\n" + pre_prompt += "".join( + f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items() + ) + elif args.personality.lower() == "george": + print( + "WARNING: AI George Hotz is terrible and is completely disowned by the real George Hotz. Stacy is much smarter." + ) + pre_prompt = f"""Consider that the following is conversation between an AI assistant named George and User You are an AI version of George Hotz. You act as much as you can like George. You are one of the greatest computer experts in the world. You have singlehandedly won programming and hacking competitions. @@ -307,20 +540,22 @@ After you are done speaking, output [EOS]. You are not the User. """ - examples = { - "What is your name?": "I'm am an AI version of George Hotz.", - "What's the complexity of matrix multiplication?": "O(n^3), though it can be faster with things like Strassen's algorithm", - "What's a buffer overflow?": "I assume you mean a stack buffer overflow. That's when the stack is too small for the data being copied to it, and the data corrupts things beyond the buffer", - "How many weights do you have?": "I am based off LLaMA trained by Facebook. I'm the 7B weight version", - "What is swap memory?": "It is when the memory is about to overflow and unused memory is freed and stored on disk" - } + examples = { + "What is your name?": "I'm am an AI version of George Hotz.", + "What's the complexity of matrix multiplication?": "O(n^3), though it can be faster with things like Strassen's algorithm", + "What's a buffer overflow?": "I assume you mean a stack buffer overflow. That's when the stack is too small for the data being copied to it, and the data corrupts things beyond the buffer", + "How many weights do you have?": "I am based off LLaMA trained by Facebook. I'm the 7B weight version", + "What is swap memory?": "It is when the memory is about to overflow and unused memory is freed and stored on disk", + } - user_delim = "\nUser: " - resp_delim = "George: " - end_delim = " [EOS]\n" - pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items()) - elif args.personality.lower() == "gary": - pre_prompt = f"""Consider that the following is conversation between an AI assistant named Gary and User + user_delim = "\nUser: " + resp_delim = "George: " + end_delim = " [EOS]\n" + pre_prompt += "".join( + f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items() + ) + elif args.personality.lower() == "gary": + pre_prompt = f"""Consider that the following is conversation between an AI assistant named Gary and User You are Gary! You have sold used cars your entire life. You are defensive about this fact, because your family belittles you. You try to answer questions well, but you always manage to shill for buying cars, Fords, Hyundais, and Toyotas @@ -329,17 +564,19 @@ After you are done speaking, output [EOS]. You are not the User. """ - examples = { - "What is your name?": "I am Gary. I used to sell cars.", - "What is 2+3?": "I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla" - } + examples = { + "What is your name?": "I am Gary. I used to sell cars.", + "What is 2+3?": "I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla", + } - user_delim = "\nUser: " - resp_delim = "Gary: " - end_delim = " [EOS]\n" - pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items()) - elif args.personality.lower() == "lexie": - pre_prompt = f"""Consider that the following is conversation between an attractive young girl named Lexie and a handsome man named Chad + user_delim = "\nUser: " + resp_delim = "Gary: " + end_delim = " [EOS]\n" + pre_prompt += "".join( + f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items() + ) + elif args.personality.lower() == "lexie": + pre_prompt = f"""Consider that the following is conversation between an attractive young girl named Lexie and a handsome man named Chad You are Lexie! You grew up in Ohio, but moved out to LA after college to try to become an actress. Making it as an actress was hard, so you started doing onlyfans instead. It's much easier, and you enjoy it a lot. @@ -349,83 +586,123 @@ After you are done speaking, output [EOS]. You are not Chad. """ - examples = { - "hi lexie": "hi chad, glad we finally met up!", - "you look better than your pictures": "thanks! are you subscribed to my onlyfans?", - "i am. so how'd you end up in LA?": "i moved out here about a year ago. i want to be an actress" - } + examples = { + "hi lexie": "hi chad, glad we finally met up!", + "you look better than your pictures": "thanks! are you subscribed to my onlyfans?", + "i am. so how'd you end up in LA?": "i moved out here about a year ago. i want to be an actress", + } - user_delim = "\nChad: " - resp_delim = "Lexie: " - end_delim = " [EOS]\n" - pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items()) + user_delim = "\nChad: " + resp_delim = "Lexie: " + end_delim = " [EOS]\n" + pre_prompt += "".join( + f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items() + ) - # *** prompt engineers stop here **** + # *** prompt engineers stop here **** - LLAMA_SUFFIX = {"1": "", "2": "-2", "code": "-code", "tiny": "-tiny"}[args.gen] - MODEL_PATH = args.model or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}" - TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model" - print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model") - llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize) - param_count = sum(x.lazydata.st.size() for x in get_parameters(llama.model)) + LLAMA_SUFFIX = {"1": "", "2": "-2", "code": "-code", "tiny": "-tiny"}[args.gen] + MODEL_PATH = ( + args.model + or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}" + ) + TOKENIZER_PATH = ( + MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent + ) / "tokenizer.model" + print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model") + llama = LLaMa.build( + MODEL_PATH, + TOKENIZER_PATH, + model_gen=args.gen, + model_size=args.size, + quantize=args.quantize, + ) + param_count = sum(x.lazydata.st.size() for x in get_parameters(llama.model)) - if chatbot: - # encode pre prompt - toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(pre_prompt) - - print(f"Preparing KV cache for chatbot with personality {args.personality}...") - with Timing(): - llama.model(Tensor([toks]), 0, args.temperature).realize() # NOTE: outputs are not used - start_pos = len(toks) - else: - # non chat bot mode - toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(args.prompt) - start_pos = 0 - - # print prompt - outputted = llama.tokenizer.decode(toks) - sys.stdout.write(outputted) - sys.stdout.flush() - - # chatbot loop - while 1: - # add tokens from user in chatbot mode if chatbot: - user_prompt = user_delim + input(user_delim) + "\n" - outputted += user_prompt + # encode pre prompt + toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(pre_prompt) - new_toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(outputted) - assert toks == new_toks[:len(toks)] - toks = new_toks - assert outputted == llama.tokenizer.decode(toks) + print(f"Preparing KV cache for chatbot with personality {args.personality}...") + with Timing(): + llama.model( + Tensor([toks]), 0, args.temperature + ).realize() # NOTE: outputs are not used + start_pos = len(toks) + else: + # non chat bot mode + toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(args.prompt) + start_pos = 0 - last_break = len(outputted) - for i in range(args.count): - GlobalCounters.reset() + # print prompt + outputted = llama.tokenizer.decode(toks) + sys.stdout.write(outputted) + sys.stdout.flush() - if args.timing or args.profile: print("") - st = GlobalCounters.time_sum_s - with Profiling(enabled=args.profile): - with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"): - with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+ - f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+ - (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing): - probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize() - # TODO: fix JIT rand so we can put this in the JIT - tok = probs.multinomial().item() + # chatbot loop + while 1: + # add tokens from user in chatbot mode + if chatbot: + user_prompt = user_delim + input(user_delim) + "\n" + outputted += user_prompt - # use the kv cache - start_pos = len(toks) + new_toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(outputted) + assert toks == new_toks[: len(toks)] + toks = new_toks + assert outputted == llama.tokenizer.decode(toks) - # add the new token - toks.append(tok) + last_break = len(outputted) + for i in range(args.count): + GlobalCounters.reset() - # TODO: this is a hack to deal with spaces. i think the decode is fast though, so who cares? - cur = llama.tokenizer.decode(toks) - sys.stdout.write(cur[len(outputted):]) - sys.stdout.flush() - outputted = cur + if args.timing or args.profile: + print("") + st = GlobalCounters.time_sum_s + with Profiling(enabled=args.profile): + with Timing( + "total ", + enabled=args.timing, + on_exit=lambda x: f", {1e9/x:.2f} tok/sec", + ): + with Timing( + "ran model in ", + on_exit=( + lambda et: ( + f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" + if DEBUG >= 2 + else "" + ) + + f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB" + + ( + f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s" + if DEBUG >= 2 + else "" + ) + ) + if DEBUG + else None, + enabled=args.timing, + ): + probs = llama.model( + Tensor([toks[start_pos:]]), start_pos, args.temperature + ).realize() + # TODO: fix JIT rand so we can put this in the JIT + tok = probs.multinomial().item() - # stop after you have your answer - if chatbot and outputted.endswith(end_delim): break - if not chatbot: break + # use the kv cache + start_pos = len(toks) + + # add the new token + toks.append(tok) + + # TODO: this is a hack to deal with spaces. i think the decode is fast though, so who cares? + cur = llama.tokenizer.decode(toks) + sys.stdout.write(cur[len(outputted) :]) + sys.stdout.flush() + outputted = cur + + # stop after you have your answer + if chatbot and outputted.endswith(end_delim): + break + if not chatbot: + break diff --git a/examples/mask_rcnn.py b/examples/mask_rcnn.py index a23ef6bd1..6baeb708b 100644 --- a/examples/mask_rcnn.py +++ b/examples/mask_rcnn.py @@ -14,286 +14,380 @@ import cv2 class Resize: - def __init__(self, min_size, max_size): - if not isinstance(min_size, (list, tuple)): - min_size = (min_size,) - self.min_size = min_size - self.max_size = max_size + def __init__(self, min_size, max_size): + if not isinstance(min_size, (list, tuple)): + min_size = (min_size,) + self.min_size = min_size + self.max_size = max_size - # modified from torchvision to add support for max size - def get_size(self, image_size): - w, h = image_size - size = random.choice(self.min_size) - max_size = self.max_size - if max_size is not None: - min_original_size = float(min((w, h))) - max_original_size = float(max((w, h))) - if max_original_size / min_original_size * size > max_size: - size = int(round(max_size * min_original_size / max_original_size)) + # modified from torchvision to add support for max size + def get_size(self, image_size): + w, h = image_size + size = random.choice(self.min_size) + max_size = self.max_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) - if (w <= h and w == size) or (h <= w and h == size): - return (h, w) + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) - if w < h: - ow = size - oh = int(size * h / w) - else: - oh = size - ow = int(size * w / h) + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) - return (oh, ow) + return (oh, ow) - def __call__(self, image): - size = self.get_size(image.size) - image = Ft.resize(image, size) - return image + def __call__(self, image): + size = self.get_size(image.size) + image = Ft.resize(image, size) + return image class Normalize: - def __init__(self, mean, std, to_bgr255=True): - self.mean = mean - self.std = std - self.to_bgr255 = to_bgr255 + def __init__(self, mean, std, to_bgr255=True): + self.mean = mean + self.std = std + self.to_bgr255 = to_bgr255 + + def __call__(self, image): + if self.to_bgr255: + image = image[[2, 1, 0]] * 255 + else: + image = image[[0, 1, 2]] * 255 + image = Ft.normalize(image, mean=self.mean, std=self.std) + return image - def __call__(self, image): - if self.to_bgr255: - image = image[[2, 1, 0]] * 255 - else: - image = image[[0, 1, 2]] * 255 - image = Ft.normalize(image, mean=self.mean, std=self.std) - return image transforms = lambda size_scale: T.Compose( - [ - Resize(int(800*size_scale), int(1333*size_scale)), - T.ToTensor(), - Normalize( - mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True - ), - ] + [ + Resize(int(800 * size_scale), int(1333 * size_scale)), + T.ToTensor(), + Normalize( + mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_bgr255=True + ), + ] ) + def expand_boxes(boxes, scale): - w_half = (boxes[:, 2] - boxes[:, 0]) * .5 - h_half = (boxes[:, 3] - boxes[:, 1]) * .5 - x_c = (boxes[:, 2] + boxes[:, 0]) * .5 - y_c = (boxes[:, 3] + boxes[:, 1]) * .5 + w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5 + h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5 + x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5 + y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5 - w_half *= scale - h_half *= scale + w_half *= scale + h_half *= scale - boxes_exp = torch.zeros_like(boxes) - boxes_exp[:, 0] = x_c - w_half - boxes_exp[:, 2] = x_c + w_half - boxes_exp[:, 1] = y_c - h_half - boxes_exp[:, 3] = y_c + h_half - return boxes_exp + boxes_exp = torch.zeros_like(boxes) + boxes_exp[:, 0] = x_c - w_half + boxes_exp[:, 2] = x_c + w_half + boxes_exp[:, 1] = y_c - h_half + boxes_exp[:, 3] = y_c + h_half + return boxes_exp def expand_masks(mask, padding): - N = mask.shape[0] - M = mask.shape[-1] - pad2 = 2 * padding - scale = float(M + pad2) / M - padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2)) - padded_mask[:, :, padding:-padding, padding:-padding] = mask - return padded_mask, scale + N = mask.shape[0] + M = mask.shape[-1] + pad2 = 2 * padding + scale = float(M + pad2) / M + padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2)) + padded_mask[:, :, padding:-padding, padding:-padding] = mask + return padded_mask, scale def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1): - # TODO: remove torch - mask = torch.tensor(mask.numpy()) - box = torch.tensor(box.numpy()) - padded_mask, scale = expand_masks(mask[None], padding=padding) - mask = padded_mask[0, 0] - box = expand_boxes(box[None], scale)[0] - box = box.to(dtype=torch.int32) + # TODO: remove torch + mask = torch.tensor(mask.numpy()) + box = torch.tensor(box.numpy()) + padded_mask, scale = expand_masks(mask[None], padding=padding) + mask = padded_mask[0, 0] + box = expand_boxes(box[None], scale)[0] + box = box.to(dtype=torch.int32) - TO_REMOVE = 1 - w = int(box[2] - box[0] + TO_REMOVE) - h = int(box[3] - box[1] + TO_REMOVE) - w = max(w, 1) - h = max(h, 1) + TO_REMOVE = 1 + w = int(box[2] - box[0] + TO_REMOVE) + h = int(box[3] - box[1] + TO_REMOVE) + w = max(w, 1) + h = max(h, 1) - mask = mask.expand((1, 1, -1, -1)) + mask = mask.expand((1, 1, -1, -1)) - mask = mask.to(torch.float32) - mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False) - mask = mask[0][0] + mask = mask.to(torch.float32) + mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False) + mask = mask[0][0] - if thresh >= 0: - mask = mask > thresh - else: - mask = (mask * 255).to(torch.uint8) + if thresh >= 0: + mask = mask > thresh + else: + mask = (mask * 255).to(torch.uint8) - im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8) - x_0 = max(box[0], 0) - x_1 = min(box[2] + 1, im_w) - y_0 = max(box[1], 0) - y_1 = min(box[3] + 1, im_h) + im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8) + x_0 = max(box[0], 0) + x_1 = min(box[2] + 1, im_w) + y_0 = max(box[1], 0) + y_1 = min(box[3] + 1, im_h) - im_mask[y_0:y_1, x_0:x_1] = mask[ - (y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0]) - ] - return im_mask + im_mask[y_0:y_1, x_0:x_1] = mask[ + (y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0]) + ] + return im_mask class Masker: - def __init__(self, threshold=0.5, padding=1): - self.threshold = threshold - self.padding = padding + def __init__(self, threshold=0.5, padding=1): + self.threshold = threshold + self.padding = padding - def forward_single_image(self, masks, boxes): - boxes = boxes.convert("xyxy") - im_w, im_h = boxes.size - res = [ - paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding) - for mask, box in zip(masks, boxes.bbox) - ] - if len(res) > 0: - res = torch.stack(res, dim=0)[:, None] - else: - res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1])) - return Tensor(res.numpy()) + def forward_single_image(self, masks, boxes): + boxes = boxes.convert("xyxy") + im_w, im_h = boxes.size + res = [ + paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding) + for mask, box in zip(masks, boxes.bbox) + ] + if len(res) > 0: + res = torch.stack(res, dim=0)[:, None] + else: + res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1])) + return Tensor(res.numpy()) - def __call__(self, masks, boxes): - if isinstance(boxes, BoxList): - boxes = [boxes] + def __call__(self, masks, boxes): + if isinstance(boxes, BoxList): + boxes = [boxes] - results = [] - for mask, box in zip(masks, boxes): - result = self.forward_single_image(mask, box) - results.append(result) - return results + results = [] + for mask, box in zip(masks, boxes): + result = self.forward_single_image(mask, box) + results.append(result) + return results masker = Masker(threshold=0.5, padding=1) + def select_top_predictions(predictions, confidence_threshold=0.9): - scores = predictions.get_field("scores").numpy() - keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold] - return predictions[keep] + scores = predictions.get_field("scores").numpy() + keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold] + return predictions[keep] + def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0): - image = transforms(size_scale)(original_image).numpy() - image = Tensor(image, requires_grad=False) - predictions = model(image) - prediction = predictions[0] - prediction = select_top_predictions(prediction, confidence_threshold) - width, height = original_image.size - prediction = prediction.resize((width, height)) + image = transforms(size_scale)(original_image).numpy() + image = Tensor(image, requires_grad=False) + predictions = model(image) + prediction = predictions[0] + prediction = select_top_predictions(prediction, confidence_threshold) + width, height = original_image.size + prediction = prediction.resize((width, height)) + + if prediction.has_field("mask"): + masks = prediction.get_field("mask") + masks = masker([masks], [prediction])[0] + prediction.add_field("mask", masks) + return prediction - if prediction.has_field("mask"): - masks = prediction.get_field("mask") - masks = masker([masks], [prediction])[0] - prediction.add_field("mask", masks) - return prediction def compute_prediction_batched(batch, model, size_scale=1.0): - imgs = [] - for img in batch: - imgs.append(transforms(size_scale)(img).numpy()) - image = [Tensor(image, requires_grad=False) for image in imgs] - predictions = model(image) - del image - return predictions + imgs = [] + for img in batch: + imgs.append(transforms(size_scale)(img).numpy()) + image = [Tensor(image, requires_grad=False) for image in imgs] + predictions = model(image) + del image + return predictions + + +palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1]) -palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) def findContours(*args, **kwargs): - if cv2.__version__.startswith('4'): - contours, hierarchy = cv2.findContours(*args, **kwargs) - elif cv2.__version__.startswith('3'): - _, contours, hierarchy = cv2.findContours(*args, **kwargs) - return contours, hierarchy + if cv2.__version__.startswith("4"): + contours, hierarchy = cv2.findContours(*args, **kwargs) + elif cv2.__version__.startswith("3"): + _, contours, hierarchy = cv2.findContours(*args, **kwargs) + return contours, hierarchy + def compute_colors_for_labels(labels): - l = labels[:, None] - colors = l * palette - colors = (colors % 255).astype("uint8") - return colors + l = labels[:, None] + colors = l * palette + colors = (colors % 255).astype("uint8") + return colors + def overlay_mask(image, predictions): - image = np.asarray(image) - masks = predictions.get_field("mask").numpy() - labels = predictions.get_field("labels").numpy() + image = np.asarray(image) + masks = predictions.get_field("mask").numpy() + labels = predictions.get_field("labels").numpy() - colors = compute_colors_for_labels(labels).tolist() + colors = compute_colors_for_labels(labels).tolist() - for mask, color in zip(masks, colors): - thresh = mask[0, :, :, None] - contours, hierarchy = findContours( - thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE - ) - image = cv2.drawContours(image, contours, -1, color, 3) + for mask, color in zip(masks, colors): + thresh = mask[0, :, :, None] + contours, hierarchy = findContours( + thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + ) + image = cv2.drawContours(image, contours, -1, color, 3) - composite = image + composite = image + + return composite - return composite CATEGORIES = [ - "__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", - "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", - "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", - "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", - "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", - "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", - "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", - "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", + "__background", + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "dining table", + "toilet", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", ] + def overlay_boxes(image, predictions): - labels = predictions.get_field("labels").numpy() - boxes = predictions.bbox - image = np.asarray(image) - colors = compute_colors_for_labels(labels).tolist() + labels = predictions.get_field("labels").numpy() + boxes = predictions.bbox + image = np.asarray(image) + colors = compute_colors_for_labels(labels).tolist() - for box, color in zip(boxes, colors): - box = torch.tensor(box.numpy()) - box = box.to(torch.int64) - top_left, bottom_right = box[:2].tolist(), box[2:].tolist() - image = cv2.rectangle( - image, tuple(top_left), tuple(bottom_right), tuple(color), 1 - ) + for box, color in zip(boxes, colors): + box = torch.tensor(box.numpy()) + box = box.to(torch.int64) + top_left, bottom_right = box[:2].tolist(), box[2:].tolist() + image = cv2.rectangle( + image, tuple(top_left), tuple(bottom_right), tuple(color), 1 + ) + + return image - return image def overlay_class_names(image, predictions): - scores = predictions.get_field("scores").numpy().tolist() - labels = predictions.get_field("labels").numpy().tolist() - labels = [CATEGORIES[int(i)] for i in labels] - boxes = predictions.bbox.numpy() - image = np.asarray(image) - template = "{}: {:.2f}" - for box, score, label in zip(boxes, scores, labels): - x, y = box[:2] - s = template.format(label, score) - x, y = int(x), int(y) - cv2.putText( - image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1 + scores = predictions.get_field("scores").numpy().tolist() + labels = predictions.get_field("labels").numpy().tolist() + labels = [CATEGORIES[int(i)] for i in labels] + boxes = predictions.bbox.numpy() + image = np.asarray(image) + template = "{}: {:.2f}" + for box, score, label in zip(boxes, scores, labels): + x, y = box[:2] + s = template.format(label, score) + x, y = int(x), int(y) + cv2.putText(image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) + + return image + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run MaskRCNN", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + parser.add_argument("--image", type=str, help="Path of the image to run") + parser.add_argument( + "--threshold", type=float, default=0.7, help="Detector threshold" + ) + parser.add_argument( + "--size_scale", type=float, default=1.0, help="Image resize multiplier" + ) + parser.add_argument( + "--out", type=str, default="/tmp/rendered.png", help="Output filename" + ) + args = parser.parse_args() - return image + resnet = ResNet(50, num_classes=None, stride_in_1x1=True) + model_tiny = MaskRCNN(resnet) + model_tiny.load_from_pretrained() + img = Image.open(args.image) + top_result_tiny = compute_prediction( + img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale + ) + bbox_image = overlay_boxes(img, top_result_tiny) + mask_image = overlay_mask(bbox_image, top_result_tiny) + final_image = overlay_class_names(mask_image, top_result_tiny) - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--image', type=str, help="Path of the image to run") - parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold") - parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier") - parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename") - args = parser.parse_args() - - resnet = ResNet(50, num_classes=None, stride_in_1x1=True) - model_tiny = MaskRCNN(resnet) - model_tiny.load_from_pretrained() - img = Image.open(args.image) - top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale) - bbox_image = overlay_boxes(img, top_result_tiny) - mask_image = overlay_mask(bbox_image, top_result_tiny) - final_image = overlay_class_names(mask_image, top_result_tiny) - - im = Image.fromarray(final_image) - print(f"saving {args.out}") - im.save(args.out) - im.show() + im = Image.fromarray(final_image) + print(f"saving {args.out}") + im.save(args.out) + im.show() diff --git a/examples/mlperf/helpers.py b/examples/mlperf/helpers.py index fb773b7e6..985b11a4e 100644 --- a/examples/mlperf/helpers.py +++ b/examples/mlperf/helpers.py @@ -3,162 +3,194 @@ import unicodedata import numpy as np from scipy import signal + def gaussian_kernel(n, std): - gaussian_1d = signal.gaussian(n, std) - gaussian_2d = np.outer(gaussian_1d, gaussian_1d) - gaussian_3d = np.outer(gaussian_2d, gaussian_1d) - gaussian_3d = gaussian_3d.reshape(n, n, n) - gaussian_3d = np.cbrt(gaussian_3d) - gaussian_3d /= gaussian_3d.max() - return gaussian_3d + gaussian_1d = signal.gaussian(n, std) + gaussian_2d = np.outer(gaussian_1d, gaussian_1d) + gaussian_3d = np.outer(gaussian_2d, gaussian_1d) + gaussian_3d = gaussian_3d.reshape(n, n, n) + gaussian_3d = np.cbrt(gaussian_3d) + gaussian_3d /= gaussian_3d.max() + return gaussian_3d + def prepare_arrays(image, roi_shape=(128, 128, 128)): - assert len(roi_shape) == 3 and any(roi_shape) - image_shape = list(image.shape[2:]) - result = np.zeros((1, 3, *image_shape), dtype=image.dtype) - norm_map = np.zeros_like(result) - norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(norm_map.dtype) - return result, norm_map, norm_patch + assert len(roi_shape) == 3 and any(roi_shape) + image_shape = list(image.shape[2:]) + result = np.zeros((1, 3, *image_shape), dtype=image.dtype) + norm_map = np.zeros_like(result) + norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype( + norm_map.dtype + ) + return result, norm_map, norm_patch + def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5): - assert len(roi_shape) == 3 and any(roi_shape) - assert 0 < overlap_factor < 1 - image_shape, dim = list(image.shape[2:]), len(image.shape[2:]) - strides = [int(roi_shape[i] * (1 - overlap_factor)) for i in range(dim)] - size = [(image_shape[i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)] - for i in range(0, strides[0] * size[0], strides[0]): - for j in range(0, strides[1] * size[1], strides[1]): - for k in range(0, strides[2] * size[2], strides[2]): - yield i, j, k + assert len(roi_shape) == 3 and any(roi_shape) + assert 0 < overlap_factor < 1 + image_shape, dim = list(image.shape[2:]), len(image.shape[2:]) + strides = [int(roi_shape[i] * (1 - overlap_factor)) for i in range(dim)] + size = [(image_shape[i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)] + for i in range(0, strides[0] * size[0], strides[0]): + for j in range(0, strides[1] * size[1], strides[1]): + for k in range(0, strides[2] * size[2], strides[2]): + yield i, j, k + def _get_best_indices(logits, n_best_size): - index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) - return list(map(lambda x: x[0], index_and_score))[:n_best_size] + index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) + return list(map(lambda x: x[0], index_and_score))[:n_best_size] + def _is_punctuation(char): - if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127): - return True - return unicodedata.category(char).startswith("P") + if ( + (cp := ord(char)) in range(33, 48) + or cp in range(58, 65) + or cp in range(91, 97) + or cp in range(123, 127) + ): + return True + return unicodedata.category(char).startswith("P") + def _is_whitespace(char): - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - return unicodedata.category(char) == "Zs" + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + return unicodedata.category(char) == "Zs" + def _is_control(char): - if char == "\t" or char == "\n" or char == "\r": - return False - return unicodedata.category(char).startswith("C") + if char == "\t" or char == "\n" or char == "\r": + return False + return unicodedata.category(char).startswith("C") + def _run_split_on_punc(text): - if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"): - return [text] - start_new_word = True - output = [] - for i in range(len(text)): - if _is_punctuation(char := text[i]): - output.append([char]) - start_new_word = True - else: - if start_new_word: - output.append([]) - start_new_word = False - output[-1].append(char) - return ["".join(x) for x in output] + if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"): + return [text] + start_new_word = True + output = [] + for i in range(len(text)): + if _is_punctuation(char := text[i]): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + return ["".join(x) for x in output] + def _run_strip_accents(text): - output = [] - for char in unicodedata.normalize("NFD", text): - if unicodedata.category(char) != "Mn": - output.append(char) - return "".join(output) + output = [] + for char in unicodedata.normalize("NFD", text): + if unicodedata.category(char) != "Mn": + output.append(char) + return "".join(output) + def _clean_text(text): - output = [] - for char in text: - if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)): - output.append(" " if _is_whitespace(char) else char) - return "".join(output) + output = [] + for char in text: + if not ((cp := ord(char)) == 0 or cp == 0xFFFD or _is_control(char)): + output.append(" " if _is_whitespace(char) else char) + return "".join(output) + def _get_final_text(pred_text, orig_text): - def _strip_spaces(text): - ns_text = "" - ns_to_s_map = OrderedDict() - for i, c in enumerate(text): - if c == " ": - continue - ns_to_s_map[len(ns_text)] = i - ns_text += c - return ns_text, ns_to_s_map + def _strip_spaces(text): + ns_text = "" + ns_to_s_map = OrderedDict() + for i, c in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_text)] = i + ns_text += c + return ns_text, ns_to_s_map - orig_tokens = _clean_text(orig_text).strip().split() - split_tokens = [] - for token in orig_tokens: - if token not in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"): - token = token.lower() - token = _run_strip_accents(token) - split_tokens.extend(_run_split_on_punc(token)) + orig_tokens = _clean_text(orig_text).strip().split() + split_tokens = [] + for token in orig_tokens: + if token not in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"): + token = token.lower() + token = _run_strip_accents(token) + split_tokens.extend(_run_split_on_punc(token)) - tok_text = " ".join(" ".join(split_tokens).strip().split()) - start_position = tok_text.find(pred_text) - if start_position == -1: - return orig_text - end_position = start_position + len(pred_text) - 1 + tok_text = " ".join(" ".join(split_tokens).strip().split()) + start_position = tok_text.find(pred_text) + if start_position == -1: + return orig_text + end_position = start_position + len(pred_text) - 1 - orig_ns_text, orig_ns_to_s_map = _strip_spaces(orig_text) - tok_ns_text, tok_ns_to_s_map = _strip_spaces(tok_text) - if len(orig_ns_text) != len(tok_ns_text): - return orig_text - tok_s_to_ns_map = {v: k for k, v in tok_ns_to_s_map.items()} + orig_ns_text, orig_ns_to_s_map = _strip_spaces(orig_text) + tok_ns_text, tok_ns_to_s_map = _strip_spaces(tok_text) + if len(orig_ns_text) != len(tok_ns_text): + return orig_text + tok_s_to_ns_map = {v: k for k, v in tok_ns_to_s_map.items()} - orig_start_position = None - if start_position in tok_s_to_ns_map: - if (ns_start_position := tok_s_to_ns_map[start_position]) in orig_ns_to_s_map: - orig_start_position = orig_ns_to_s_map[ns_start_position] - if orig_start_position is None: - return orig_text + orig_start_position = None + if start_position in tok_s_to_ns_map: + if (ns_start_position := tok_s_to_ns_map[start_position]) in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + if orig_start_position is None: + return orig_text - orig_end_position = None - if end_position in tok_s_to_ns_map: - if (ns_end_position := tok_s_to_ns_map[end_position]) in orig_ns_to_s_map: - orig_end_position = orig_ns_to_s_map[ns_end_position] - if orig_end_position is None: - return orig_text + orig_end_position = None + if end_position in tok_s_to_ns_map: + if (ns_end_position := tok_s_to_ns_map[end_position]) in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + if orig_end_position is None: + return orig_text + + output_text = orig_text[orig_start_position : (orig_end_position + 1)] + return output_text - output_text = orig_text[orig_start_position:(orig_end_position + 1)] - return output_text def get_bert_qa_prediction(features, example, start_end_logits): - prelim_predictions = [] - for i, feature in enumerate(features): - for start_index in _get_best_indices(start_end_logits[i][0], 20): - for end_index in _get_best_indices(start_end_logits[i][1], 20): - if start_index >= len(feature["tokens"]) or end_index >= len(feature["tokens"]): - continue - if start_index not in feature["token_to_orig_map"] or end_index not in feature["token_to_orig_map"]: - continue - if not feature["token_is_max_context"].get(start_index, False): - continue - if end_index < start_index or end_index - start_index + 1 > 30: - continue + prelim_predictions = [] + for i, feature in enumerate(features): + for start_index in _get_best_indices(start_end_logits[i][0], 20): + for end_index in _get_best_indices(start_end_logits[i][1], 20): + if start_index >= len(feature["tokens"]) or end_index >= len( + feature["tokens"] + ): + continue + if ( + start_index not in feature["token_to_orig_map"] + or end_index not in feature["token_to_orig_map"] + ): + continue + if not feature["token_is_max_context"].get(start_index, False): + continue + if end_index < start_index or end_index - start_index + 1 > 30: + continue - prelim_predictions.append({ - "feature_index": i, - "start_index": start_index, - "end_index": end_index, - "start_logit": start_end_logits[i][0, start_index], - "end_logit": start_end_logits[i][1, end_index] - }) - predictions = sorted(prelim_predictions, key=lambda x: (x["start_logit"] + x["end_logit"]), reverse=True) + prelim_predictions.append( + { + "feature_index": i, + "start_index": start_index, + "end_index": end_index, + "start_logit": start_end_logits[i][0, start_index], + "end_logit": start_end_logits[i][1, end_index], + } + ) + predictions = sorted( + prelim_predictions, + key=lambda x: (x["start_logit"] + x["end_logit"]), + reverse=True, + ) - if len(predictions) > 0: - feature = features[predictions[0]["feature_index"]] - tok_tokens = feature["tokens"][predictions[0]["start_index"]:(predictions[0]["end_index"] + 1)] - orig_doc_start = feature["token_to_orig_map"][predictions[0]["start_index"]] - orig_doc_end = feature["token_to_orig_map"][predictions[0]["end_index"]] - orig_tokens = example["context"][orig_doc_start:(orig_doc_end + 1)] - tok_text = " ".join(tok_tokens).replace(" ##", "").replace("##", "") - tok_text = " ".join(tok_text.strip().split()) - orig_text = " ".join(orig_tokens) - return _get_final_text(tok_text, orig_text) - return "empty" + if len(predictions) > 0: + feature = features[predictions[0]["feature_index"]] + tok_tokens = feature["tokens"][ + predictions[0]["start_index"] : (predictions[0]["end_index"] + 1) + ] + orig_doc_start = feature["token_to_orig_map"][predictions[0]["start_index"]] + orig_doc_end = feature["token_to_orig_map"][predictions[0]["end_index"]] + orig_tokens = example["context"][orig_doc_start : (orig_doc_end + 1)] + tok_text = " ".join(tok_tokens).replace(" ##", "").replace("##", "") + tok_text = " ".join(tok_text.strip().split()) + orig_text = " ".join(orig_tokens) + return _get_final_text(tok_text, orig_text) + return "empty" diff --git a/examples/mlperf/metrics.py b/examples/mlperf/metrics.py index 9bf339953..086fe40a8 100644 --- a/examples/mlperf/metrics.py +++ b/examples/mlperf/metrics.py @@ -3,59 +3,67 @@ import string from collections import Counter import numpy as np + def levenshtein(a, b): - n, m = len(a), len(b) - if n > m: - a, b, n, m = b, a, m, n + n, m = len(a), len(b) + if n > m: + a, b, n, m = b, a, m, n - current = list(range(n + 1)) - for i in range(1, m + 1): - previous, current = current, [i] + [0] * n - for j in range(1, n + 1): - add, delete = previous[j] + 1, current[j - 1] + 1 - change = previous[j - 1] - if a[j - 1] != b[i - 1]: - change = change + 1 - current[j] = min(add, delete, change) + current = list(range(n + 1)) + for i in range(1, m + 1): + previous, current = current, [i] + [0] * n + for j in range(1, n + 1): + add, delete = previous[j] + 1, current[j - 1] + 1 + change = previous[j - 1] + if a[j - 1] != b[i - 1]: + change = change + 1 + current[j] = min(add, delete, change) + + return current[n] - return current[n] def word_error_rate(x, y): - scores = words = 0 - for h, r in zip(x, y): - h_list = h.split() - r_list = r.split() - words += len(r_list) - scores += levenshtein(h_list, r_list) - return float(scores) / words, float(scores), words + scores = words = 0 + for h, r in zip(x, y): + h_list = h.split() + r_list = r.split() + words += len(r_list) + scores += levenshtein(h_list, r_list) + return float(scores) / words, float(scores), words + def one_hot(arr, num_classes=3): - res = np.eye(num_classes)[np.array(arr).reshape(-1)] - arr = res.reshape(list(arr.shape) + [num_classes]) - arr = arr.transpose((0, 4, 1, 2, 3)).astype(np.float32) - return arr + res = np.eye(num_classes)[np.array(arr).reshape(-1)] + arr = res.reshape(list(arr.shape) + [num_classes]) + arr = arr.transpose((0, 4, 1, 2, 3)).astype(np.float32) + return arr + def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6): - channel_axis, reduce_axis = 1, tuple(range(2, len(prediction.shape))) - prediction = prediction.argmax(axis=channel_axis) - prediction, target= one_hot(prediction)[:, 1:], one_hot(target)[:, 1:] - intersection = np.sum(prediction * target, axis=reduce_axis) - target_sum = np.sum(target, axis=reduce_axis) - prediction_sum = np.sum(prediction, axis=reduce_axis) - result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr) - return result[0] + channel_axis, reduce_axis = 1, tuple(range(2, len(prediction.shape))) + prediction = prediction.argmax(axis=channel_axis) + prediction, target = one_hot(prediction)[:, 1:], one_hot(target)[:, 1:] + intersection = np.sum(prediction * target, axis=reduce_axis) + target_sum = np.sum(target, axis=reduce_axis) + prediction_sum = np.sum(prediction, axis=reduce_axis) + result = (2.0 * intersection + smooth_nr) / ( + target_sum + prediction_sum + smooth_dr + ) + return result[0] + def normalize_string(s): - s = "".join(c for c in s.lower() if c not in string.punctuation) - s = re.sub(r'\b(a|an|the)\b', ' ', s) - return " ".join(s.split()) + s = "".join(c for c in s.lower() if c not in string.punctuation) + s = re.sub(r"\b(a|an|the)\b", " ", s) + return " ".join(s.split()) + def f1_score(x, y): - xt = normalize_string(x).split() - yt = normalize_string(y).split() - ct = Counter(xt) & Counter(yt) - if (ns := sum(ct.values())) == 0: - return 0.0 - p = ns / len(xt) - r = ns / len(yt) - return 2 * p * r / (p + r) + xt = normalize_string(x).split() + yt = normalize_string(y).split() + ct = Counter(xt) & Counter(yt) + if (ns := sum(ct.values())) == 0: + return 0.0 + p = ns / len(xt) + r = ns / len(yt) + return 2 * p * r / (p + r) diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py index 9dbb03c71..74cdf15cc 100644 --- a/examples/mlperf/model_eval.py +++ b/examples/mlperf/model_eval.py @@ -6,237 +6,324 @@ from tinygrad.jit import TinyJit from tinygrad.helpers import getenv, dtypes, GlobalCounters from examples.mlperf import helpers + def eval_resnet(): - # Resnet50-v1.5 - from tinygrad.jit import TinyJit - from extra.models.resnet import ResNet50 - mdl = ResNet50() - mdl.load_from_pretrained() + # Resnet50-v1.5 + from tinygrad.jit import TinyJit + from extra.models.resnet import ResNet50 - input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1) - input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1) - def input_fixup(x): - x = x.permute([0,3,1,2]).cast(dtypes.float32) / 255.0 - x -= input_mean - x /= input_std - return x + mdl = ResNet50() + mdl.load_from_pretrained() - mdlrun = lambda x: mdl(input_fixup(x)).realize() - mdljit = TinyJit(mdlrun) + input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1) + input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1) - # evaluation on the mlperf classes of the validation set from imagenet - from extra.datasets.imagenet import iterate - from extra.helpers import cross_process + def input_fixup(x): + x = x.permute([0, 3, 1, 2]).cast(dtypes.float32) / 255.0 + x -= input_mean + x /= input_std + return x - BS = 64 - n,d = 0,0 - st = time.perf_counter() - iterator = cross_process(lambda: iterate(BS)) - x,ny = next(iterator) - dat = Tensor(x) - while dat is not None: - y = ny - GlobalCounters.reset() - mt = time.perf_counter() - outs = mdlrun(dat) if dat.shape[0] != BS else mdljit(dat) - try: - x,ny = next(iterator) - dat = Tensor(x) - except StopIteration: - dat = None - t = outs.argmax(axis=1).numpy() - et = time.perf_counter() - n += (t==y).sum() - d += len(t) - print(f"****** {n}/{d} {n*100.0/d:.2f}% -- {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:7.2f} ms to run model. {len(t)/(et-mt):.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-mt):.2f} TFLOPS") + mdlrun = lambda x: mdl(input_fixup(x)).realize() + mdljit = TinyJit(mdlrun) + + # evaluation on the mlperf classes of the validation set from imagenet + from extra.datasets.imagenet import iterate + from extra.helpers import cross_process + + BS = 64 + n, d = 0, 0 st = time.perf_counter() + iterator = cross_process(lambda: iterate(BS)) + x, ny = next(iterator) + dat = Tensor(x) + while dat is not None: + y = ny + GlobalCounters.reset() + mt = time.perf_counter() + outs = mdlrun(dat) if dat.shape[0] != BS else mdljit(dat) + try: + x, ny = next(iterator) + dat = Tensor(x) + except StopIteration: + dat = None + t = outs.argmax(axis=1).numpy() + et = time.perf_counter() + n += (t == y).sum() + d += len(t) + print( + f"****** {n}/{d} {n*100.0/d:.2f}% -- {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:7.2f} ms to run model. {len(t)/(et-mt):.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-mt):.2f} TFLOPS" + ) + st = time.perf_counter() + def eval_unet3d(): - # UNet3D - from extra.models.unet3d import UNet3D - from extra.datasets.kits19 import iterate, sliding_window_inference - from examples.mlperf.metrics import get_dice_score - mdl = UNet3D() - mdl.load_from_pretrained() - s = 0 - st = time.perf_counter() - for i, (image, label) in enumerate(iterate(), start=1): - mt = time.perf_counter() - pred, label = sliding_window_inference(mdl, image, label) - et = time.perf_counter() - print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model") - s += get_dice_score(pred, label).mean() - print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score") + # UNet3D + from extra.models.unet3d import UNet3D + from extra.datasets.kits19 import iterate, sliding_window_inference + from examples.mlperf.metrics import get_dice_score + + mdl = UNet3D() + mdl.load_from_pretrained() + s = 0 st = time.perf_counter() + for i, (image, label) in enumerate(iterate(), start=1): + mt = time.perf_counter() + pred, label = sliding_window_inference(mdl, image, label) + et = time.perf_counter() + print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model") + s += get_dice_score(pred, label).mean() + print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score") + st = time.perf_counter() + def eval_retinanet(): - # RetinaNet with ResNeXt50_32X4D - from extra.models.resnet import ResNeXt50_32X4D - from extra.models.retinanet import RetinaNet - mdl = RetinaNet(ResNeXt50_32X4D()) - mdl.load_from_pretrained() + # RetinaNet with ResNeXt50_32X4D + from extra.models.resnet import ResNeXt50_32X4D + from extra.models.retinanet import RetinaNet - input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1) - input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1) - def input_fixup(x): - x = x.permute([0,3,1,2]) / 255.0 - x -= input_mean - x /= input_std - return x + mdl = RetinaNet(ResNeXt50_32X4D()) + mdl.load_from_pretrained() - from extra.datasets.openimages import openimages, iterate - from pycocotools.coco import COCO - from pycocotools.cocoeval import COCOeval - from contextlib import redirect_stdout - coco = COCO(openimages()) - coco_eval = COCOeval(coco, iouType="bbox") - coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng) + input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1) + input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1) - from tinygrad.jit import TinyJit - mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize()) + def input_fixup(x): + x = x.permute([0, 3, 1, 2]) / 255.0 + x -= input_mean + x /= input_std + return x - n, bs = 0, 8 - st = time.perf_counter() - for x, targets in iterate(coco, bs): - dat = Tensor(x.astype(np.float32)) - mt = time.perf_counter() - if dat.shape[0] == bs: - outs = mdlrun(dat).numpy() - else: - mdlrun.jit_cache = None - outs = mdl(input_fixup(dat)).numpy() - et = time.perf_counter() - predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets]) - ext = time.perf_counter() - n += len(targets) - print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing") - img_ids = [t["image_id"] for t in targets] - coco_results = [{"image_id": targets[i]["image_id"], "category_id": label, "bbox": box, "score": score} - for i, prediction in enumerate(predictions) for box, score, label in zip(*prediction.values())] - with redirect_stdout(None): - coco_eval.cocoDt = coco.loadRes(coco_results) - coco_eval.params.imgIds = img_ids - coco_eval.evaluate() - evaluated_imgs.extend(img_ids) - coco_evalimgs.append(np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids))) + from extra.datasets.openimages import openimages, iterate + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + from contextlib import redirect_stdout + + coco = COCO(openimages()) + coco_eval = COCOeval(coco, iouType="bbox") + coco_evalimgs, evaluated_imgs, ncats, narea = ( + [], + [], + len(coco_eval.params.catIds), + len(coco_eval.params.areaRng), + ) + + from tinygrad.jit import TinyJit + + mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize()) + + n, bs = 0, 8 st = time.perf_counter() + for x, targets in iterate(coco, bs): + dat = Tensor(x.astype(np.float32)) + mt = time.perf_counter() + if dat.shape[0] == bs: + outs = mdlrun(dat).numpy() + else: + mdlrun.jit_cache = None + outs = mdl(input_fixup(dat)).numpy() + et = time.perf_counter() + predictions = mdl.postprocess_detections( + outs, + input_size=dat.shape[1:3], + orig_image_sizes=[t["image_size"] for t in targets], + ) + ext = time.perf_counter() + n += len(targets) + print( + f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing" + ) + img_ids = [t["image_id"] for t in targets] + coco_results = [ + { + "image_id": targets[i]["image_id"], + "category_id": label, + "bbox": box, + "score": score, + } + for i, prediction in enumerate(predictions) + for box, score, label in zip(*prediction.values()) + ] + with redirect_stdout(None): + coco_eval.cocoDt = coco.loadRes(coco_results) + coco_eval.params.imgIds = img_ids + coco_eval.evaluate() + evaluated_imgs.extend(img_ids) + coco_evalimgs.append( + np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids)) + ) + st = time.perf_counter() + + coco_eval.params.imgIds = evaluated_imgs + coco_eval._paramsEval.imgIds = evaluated_imgs + coco_eval.evalImgs = list(np.concatenate(coco_evalimgs, -1).flatten()) + coco_eval.accumulate() + coco_eval.summarize() - coco_eval.params.imgIds = evaluated_imgs - coco_eval._paramsEval.imgIds = evaluated_imgs - coco_eval.evalImgs = list(np.concatenate(coco_evalimgs, -1).flatten()) - coco_eval.accumulate() - coco_eval.summarize() def eval_rnnt(): - # RNN-T - from extra.models.rnnt import RNNT - mdl = RNNT() - mdl.load_from_pretrained() + # RNN-T + from extra.models.rnnt import RNNT - from extra.datasets.librispeech import iterate - from examples.mlperf.metrics import word_error_rate + mdl = RNNT() + mdl.load_from_pretrained() - LABELS = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"] + from extra.datasets.librispeech import iterate + from examples.mlperf.metrics import word_error_rate - c = 0 - scores = 0 - words = 0 - st = time.perf_counter() - for X, Y in iterate(): - mt = time.perf_counter() - tt = mdl.decode(Tensor(X[0]), Tensor([X[1]])) - et = time.perf_counter() - print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model") - for n, t in enumerate(tt): - tnp = np.array(t) - _, scores_, words_ = word_error_rate(["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]]) - scores += scores_ - words += words_ - c += len(tt) - print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}") + LABELS = [ + " ", + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + "'", + ] + + c = 0 + scores = 0 + words = 0 st = time.perf_counter() + for X, Y in iterate(): + mt = time.perf_counter() + tt = mdl.decode(Tensor(X[0]), Tensor([X[1]])) + et = time.perf_counter() + print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model") + for n, t in enumerate(tt): + tnp = np.array(t) + _, scores_, words_ = word_error_rate( + ["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]] + ) + scores += scores_ + words += words_ + c += len(tt) + print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}") + st = time.perf_counter() + def eval_bert(): - # Bert-QA - from extra.models.bert import BertForQuestionAnswering - mdl = BertForQuestionAnswering() - mdl.load_from_pretrained() + # Bert-QA + from extra.models.bert import BertForQuestionAnswering - @TinyJit - def run(input_ids, input_mask, segment_ids): - return mdl(input_ids, input_mask, segment_ids).realize() + mdl = BertForQuestionAnswering() + mdl.load_from_pretrained() - from extra.datasets.squad import iterate - from examples.mlperf.helpers import get_bert_qa_prediction - from examples.mlperf.metrics import f1_score - from transformers import BertTokenizer + @TinyJit + def run(input_ids, input_mask, segment_ids): + return mdl(input_ids, input_mask, segment_ids).realize() - tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights/bert_vocab.txt")) + from extra.datasets.squad import iterate + from examples.mlperf.helpers import get_bert_qa_prediction + from examples.mlperf.metrics import f1_score + from transformers import BertTokenizer - c = 0 - f1 = 0.0 - st = time.perf_counter() - for X, Y in iterate(tokenizer): - mt = time.perf_counter() - outs = [] - for x in X: - outs.append(run(Tensor(x["input_ids"]), Tensor(x["input_mask"]), Tensor(x["segment_ids"])).numpy()) - et = time.perf_counter() - print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features") - - pred = get_bert_qa_prediction(X, Y, outs) - print(f"pred: {pred}\nans: {Y['answers']}") - f1 += max([f1_score(pred, ans) for ans in Y["answers"]]) - c += 1 - print(f"f1: {f1/c}, raw: {f1}, c: {c}\n") + tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights/bert_vocab.txt")) + c = 0 + f1 = 0.0 st = time.perf_counter() + for X, Y in iterate(tokenizer): + mt = time.perf_counter() + outs = [] + for x in X: + outs.append( + run( + Tensor(x["input_ids"]), + Tensor(x["input_mask"]), + Tensor(x["segment_ids"]), + ).numpy() + ) + et = time.perf_counter() + print( + f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features" + ) + + pred = get_bert_qa_prediction(X, Y, outs) + print(f"pred: {pred}\nans: {Y['answers']}") + f1 += max([f1_score(pred, ans) for ans in Y["answers"]]) + c += 1 + print(f"f1: {f1/c}, raw: {f1}, c: {c}\n") + + st = time.perf_counter() + def eval_mrcnn(): - from tqdm import tqdm - from extra.models.mask_rcnn import MaskRCNN - from extra.models.resnet import ResNet - from extra.datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate - from examples.mask_rcnn import compute_prediction_batched, Image - mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True)) - mdl.load_from_pretrained() + from tqdm import tqdm + from extra.models.mask_rcnn import MaskRCNN + from extra.models.resnet import ResNet + from extra.datasets.coco import ( + BASEDIR, + images, + convert_prediction_to_coco_bbox, + convert_prediction_to_coco_mask, + accumulate_predictions_for_coco, + evaluate_predictions_on_coco, + iterate, + ) + from examples.mask_rcnn import compute_prediction_batched, Image - bbox_output = '/tmp/results_bbox.json' - mask_output = '/tmp/results_mask.json' + mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True)) + mdl.load_from_pretrained() - accumulate_predictions_for_coco([], bbox_output, rm=True) - accumulate_predictions_for_coco([], mask_output, rm=True) + bbox_output = "/tmp/results_bbox.json" + mask_output = "/tmp/results_mask.json" - #TODO: bs > 1 not as accurate - bs = 1 + accumulate_predictions_for_coco([], bbox_output, rm=True) + accumulate_predictions_for_coco([], mask_output, rm=True) - for batch in tqdm(iterate(images, bs=bs), total=len(images)//bs): - batch_imgs = [] - for image_row in batch: - image_name = image_row['file_name'] - img = Image.open(BASEDIR/f'val2017/{image_name}').convert("RGB") - batch_imgs.append(img) - batch_result = compute_prediction_batched(batch_imgs, mdl) - for image_row, result in zip(batch, batch_result): - image_name = image_row['file_name'] - box_pred = convert_prediction_to_coco_bbox(image_name, result) - mask_pred = convert_prediction_to_coco_mask(image_name, result) - accumulate_predictions_for_coco(box_pred, bbox_output) - accumulate_predictions_for_coco(mask_pred, mask_output) - del batch_imgs - del batch_result + # TODO: bs > 1 not as accurate + bs = 1 + + for batch in tqdm(iterate(images, bs=bs), total=len(images) // bs): + batch_imgs = [] + for image_row in batch: + image_name = image_row["file_name"] + img = Image.open(BASEDIR / f"val2017/{image_name}").convert("RGB") + batch_imgs.append(img) + batch_result = compute_prediction_batched(batch_imgs, mdl) + for image_row, result in zip(batch, batch_result): + image_name = image_row["file_name"] + box_pred = convert_prediction_to_coco_bbox(image_name, result) + mask_pred = convert_prediction_to_coco_mask(image_name, result) + accumulate_predictions_for_coco(box_pred, bbox_output) + accumulate_predictions_for_coco(mask_pred, mask_output) + del batch_imgs + del batch_result + + evaluate_predictions_on_coco(bbox_output, iou_type="bbox") + evaluate_predictions_on_coco(mask_output, iou_type="segm") - evaluate_predictions_on_coco(bbox_output, iou_type='bbox') - evaluate_predictions_on_coco(mask_output, iou_type='segm') if __name__ == "__main__": - # inference only - Tensor.training = False - Tensor.no_grad = True + # inference only + Tensor.training = False + Tensor.no_grad = True - models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",") - for m in models: - nm = f"eval_{m}" - if nm in globals(): - print(f"eval {m}") - globals()[nm]() \ No newline at end of file + models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",") + for m in models: + nm = f"eval_{m}" + if nm in globals(): + print(f"eval {m}") + globals()[nm]() diff --git a/examples/mlperf/model_spec.py b/examples/mlperf/model_spec.py index feb737a2f..eac2ad5ad 100644 --- a/examples/mlperf/model_spec.py +++ b/examples/mlperf/model_spec.py @@ -3,68 +3,84 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import GlobalCounters, getenv import numpy as np + def test_model(model, *inputs): - GlobalCounters.reset() - out = model(*inputs) - if isinstance(out, Tensor): out = out.numpy() - # TODO: return event future to still get the time_sum_s without DEBUG=2 - print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms") + GlobalCounters.reset() + out = model(*inputs) + if isinstance(out, Tensor): + out = out.numpy() + # TODO: return event future to still get the time_sum_s without DEBUG=2 + print( + f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms" + ) + def spec_resnet(): - # Resnet50-v1.5 - from extra.models.resnet import ResNet50 - mdl = ResNet50() - img = Tensor.randn(1, 3, 224, 224) - test_model(mdl, img) + # Resnet50-v1.5 + from extra.models.resnet import ResNet50 + + mdl = ResNet50() + img = Tensor.randn(1, 3, 224, 224) + test_model(mdl, img) + def spec_retinanet(): - # Retinanet with ResNet backbone - from extra.models.resnet import ResNet50 - from extra.models.retinanet import RetinaNet - mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9) - img = Tensor.randn(1, 3, 224, 224) - test_model(mdl, img) + # Retinanet with ResNet backbone + from extra.models.resnet import ResNet50 + from extra.models.retinanet import RetinaNet + + mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9) + img = Tensor.randn(1, 3, 224, 224) + test_model(mdl, img) + def spec_unet3d(): - # 3D UNET - from extra.models.unet3d import UNet3D - mdl = UNet3D() - #mdl.load_from_pretrained() - img = Tensor.randn(1, 1, 128, 128, 128) - test_model(mdl, img) + # 3D UNET + from extra.models.unet3d import UNet3D + + mdl = UNet3D() + # mdl.load_from_pretrained() + img = Tensor.randn(1, 1, 128, 128, 128) + test_model(mdl, img) + def spec_rnnt(): - from extra.models.rnnt import RNNT - mdl = RNNT() - #mdl.load_from_pretrained() - x = Tensor.randn(220, 1, 240) - y = Tensor.randn(1, 220) - test_model(mdl, x, y) + from extra.models.rnnt import RNNT + + mdl = RNNT() + # mdl.load_from_pretrained() + x = Tensor.randn(220, 1, 240) + y = Tensor.randn(1, 220) + test_model(mdl, x, y) + def spec_bert(): - from extra.models.bert import BertForQuestionAnswering - mdl = BertForQuestionAnswering() - #mdl.load_from_pretrained() - x = Tensor.randn(1, 384) - am = Tensor.randn(1, 384) - tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32)) - test_model(mdl, x, am, tt) + from extra.models.bert import BertForQuestionAnswering + + mdl = BertForQuestionAnswering() + # mdl.load_from_pretrained() + x = Tensor.randn(1, 384) + am = Tensor.randn(1, 384) + tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32)) + test_model(mdl, x, am, tt) + def spec_mrcnn(): - from extra.models.mask_rcnn import MaskRCNN, ResNet - mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True)) - #mdl.load_from_pretrained() - x = Tensor.randn(3, 224, 224) - test_model(mdl, [x]) + from extra.models.mask_rcnn import MaskRCNN, ResNet + + mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True)) + # mdl.load_from_pretrained() + x = Tensor.randn(3, 224, 224) + test_model(mdl, [x]) + if __name__ == "__main__": - # inference only for now - Tensor.training = False - Tensor.no_grad = True - - for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","): - nm = f"spec_{m}" - if nm in globals(): - print(f"testing {m}") - globals()[nm]() + # inference only for now + Tensor.training = False + Tensor.no_grad = True + for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","): + nm = f"spec_{m}" + if nm in globals(): + print(f"testing {m}") + globals()[nm]() diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 4ad874236..dca0a87ae 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1,36 +1,43 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import getenv + def train_resnet(): - # TODO: Resnet50-v1.5 - pass + # TODO: Resnet50-v1.5 + pass + def train_retinanet(): - # TODO: Retinanet - pass + # TODO: Retinanet + pass + def train_unet3d(): - # TODO: Unet3d - pass + # TODO: Unet3d + pass + def train_rnnt(): - # TODO: RNN-T - pass + # TODO: RNN-T + pass + def train_bert(): - # TODO: BERT - pass + # TODO: BERT + pass + def train_maskrcnn(): - # TODO: Mask RCNN - pass + # TODO: Mask RCNN + pass + if __name__ == "__main__": - with Tensor.train(): - for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","): - nm = f"train_{m}" - if nm in globals(): - print(f"training {m}") - globals()[nm]() - - + with Tensor.train(): + for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split( + "," + ): + nm = f"train_{m}" + if nm in globals(): + print(f"training {m}") + globals()[nm]() diff --git a/examples/mnist_gan.py b/examples/mnist_gan.py index 0d4c0a34c..a2e7e920f 100644 --- a/examples/mnist_gan.py +++ b/examples/mnist_gan.py @@ -9,99 +9,115 @@ from tinygrad.helpers import getenv from tinygrad.nn import optim from extra.datasets import fetch_mnist -class LinearGen: - def __init__(self): - self.l1 = Tensor.scaled_uniform(128, 256) - self.l2 = Tensor.scaled_uniform(256, 512) - self.l3 = Tensor.scaled_uniform(512, 1024) - self.l4 = Tensor.scaled_uniform(1024, 784) - def forward(self, x): - x = x.dot(self.l1).leakyrelu(0.2) - x = x.dot(self.l2).leakyrelu(0.2) - x = x.dot(self.l3).leakyrelu(0.2) - x = x.dot(self.l4).tanh() - return x +class LinearGen: + def __init__(self): + self.l1 = Tensor.scaled_uniform(128, 256) + self.l2 = Tensor.scaled_uniform(256, 512) + self.l3 = Tensor.scaled_uniform(512, 1024) + self.l4 = Tensor.scaled_uniform(1024, 784) + + def forward(self, x): + x = x.dot(self.l1).leakyrelu(0.2) + x = x.dot(self.l2).leakyrelu(0.2) + x = x.dot(self.l3).leakyrelu(0.2) + x = x.dot(self.l4).tanh() + return x + class LinearDisc: - def __init__(self): - self.l1 = Tensor.scaled_uniform(784, 1024) - self.l2 = Tensor.scaled_uniform(1024, 512) - self.l3 = Tensor.scaled_uniform(512, 256) - self.l4 = Tensor.scaled_uniform(256, 2) + def __init__(self): + self.l1 = Tensor.scaled_uniform(784, 1024) + self.l2 = Tensor.scaled_uniform(1024, 512) + self.l3 = Tensor.scaled_uniform(512, 256) + self.l4 = Tensor.scaled_uniform(256, 2) + + def forward(self, x): + # balance the discriminator inputs with const bias (.add(1)) + x = x.dot(self.l1).add(1).leakyrelu(0.2).dropout(0.3) + x = x.dot(self.l2).leakyrelu(0.2).dropout(0.3) + x = x.dot(self.l3).leakyrelu(0.2).dropout(0.3) + x = x.dot(self.l4).log_softmax() + return x - def forward(self, x): - # balance the discriminator inputs with const bias (.add(1)) - x = x.dot(self.l1).add(1).leakyrelu(0.2).dropout(0.3) - x = x.dot(self.l2).leakyrelu(0.2).dropout(0.3) - x = x.dot(self.l3).leakyrelu(0.2).dropout(0.3) - x = x.dot(self.l4).log_softmax() - return x def make_batch(images): - sample = np.random.randint(0, len(images), size=(batch_size)) - image_b = images[sample].reshape(-1, 28*28).astype(np.float32) / 127.5 - 1.0 - return Tensor(image_b) + sample = np.random.randint(0, len(images), size=(batch_size)) + image_b = images[sample].reshape(-1, 28 * 28).astype(np.float32) / 127.5 - 1.0 + return Tensor(image_b) + def make_labels(bs, col, val=-2.0): - y = np.zeros((bs, 2), np.float32) - y[range(bs), [col] * bs] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789. - return Tensor(y) + y = np.zeros((bs, 2), np.float32) + y[ + range(bs), [col] * bs + ] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789. + return Tensor(y) + def train_discriminator(optimizer, data_real, data_fake): - real_labels = make_labels(batch_size, 1) - fake_labels = make_labels(batch_size, 0) - optimizer.zero_grad() - output_real = discriminator.forward(data_real) - output_fake = discriminator.forward(data_fake) - loss_real = (output_real * real_labels).mean() - loss_fake = (output_fake * fake_labels).mean() - loss_real.backward() - loss_fake.backward() - optimizer.step() - return (loss_real + loss_fake).numpy() + real_labels = make_labels(batch_size, 1) + fake_labels = make_labels(batch_size, 0) + optimizer.zero_grad() + output_real = discriminator.forward(data_real) + output_fake = discriminator.forward(data_fake) + loss_real = (output_real * real_labels).mean() + loss_fake = (output_fake * fake_labels).mean() + loss_real.backward() + loss_fake.backward() + optimizer.step() + return (loss_real + loss_fake).numpy() + def train_generator(optimizer, data_fake): - real_labels = make_labels(batch_size, 1) - optimizer.zero_grad() - output = discriminator.forward(data_fake) - loss = (output * real_labels).mean() - loss.backward() - optimizer.step() - return loss.numpy() + real_labels = make_labels(batch_size, 1) + optimizer.zero_grad() + output = discriminator.forward(data_fake) + loss = (output * real_labels).mean() + loss.backward() + optimizer.step() + return loss.numpy() + if __name__ == "__main__": - # data for training and validation - images_real = np.vstack(fetch_mnist()[::2]) - ds_noise = Tensor.randn(64, 128, requires_grad=False) - # parameters - epochs, batch_size, k = 300, 512, 1 - sample_interval = epochs // 10 - n_steps = len(images_real) // batch_size - # models and optimizer - generator = LinearGen() - discriminator = LinearDisc() - # path to store results - output_dir = Path(".").resolve() / "outputs" - output_dir.mkdir(exist_ok=True) - # optimizers - optim_g = optim.Adam(get_parameters(generator),lr=0.0002, b1=0.5) # 0.0002 for equilibrium! - optim_d = optim.Adam(get_parameters(discriminator),lr=0.0002, b1=0.5) - # training loop - for epoch in (t := trange(epochs)): - loss_g, loss_d = 0.0, 0.0 - for _ in range(n_steps): - data_real = make_batch(images_real) - for step in range(k): # Try with k = 5 or 7. - noise = Tensor.randn(batch_size, 128) - data_fake = generator.forward(noise).detach() - loss_d += train_discriminator(optim_d, data_real, data_fake) - noise = Tensor.randn(batch_size, 128) - data_fake = generator.forward(noise) - loss_g += train_generator(optim_g, data_fake) - if (epoch + 1) % sample_interval == 0: - fake_images = generator.forward(ds_noise).detach().numpy() - fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2 # 0 - 1 range. - save_image(make_grid(torch.tensor(fake_images)), output_dir / f"image_{epoch+1}.jpg") - t.set_description(f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}") - print("Training Completed!") + # data for training and validation + images_real = np.vstack(fetch_mnist()[::2]) + ds_noise = Tensor.randn(64, 128, requires_grad=False) + # parameters + epochs, batch_size, k = 300, 512, 1 + sample_interval = epochs // 10 + n_steps = len(images_real) // batch_size + # models and optimizer + generator = LinearGen() + discriminator = LinearDisc() + # path to store results + output_dir = Path(".").resolve() / "outputs" + output_dir.mkdir(exist_ok=True) + # optimizers + optim_g = optim.Adam( + get_parameters(generator), lr=0.0002, b1=0.5 + ) # 0.0002 for equilibrium! + optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5) + # training loop + for epoch in (t := trange(epochs)): + loss_g, loss_d = 0.0, 0.0 + for _ in range(n_steps): + data_real = make_batch(images_real) + for step in range(k): # Try with k = 5 or 7. + noise = Tensor.randn(batch_size, 128) + data_fake = generator.forward(noise).detach() + loss_d += train_discriminator(optim_d, data_real, data_fake) + noise = Tensor.randn(batch_size, 128) + data_fake = generator.forward(noise) + loss_g += train_generator(optim_g, data_fake) + if (epoch + 1) % sample_interval == 0: + fake_images = generator.forward(ds_noise).detach().numpy() + fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2 # 0 - 1 range. + save_image( + make_grid(torch.tensor(fake_images)), + output_dir / f"image_{epoch+1}.jpg", + ) + t.set_description( + f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}" + ) + print("Training Completed!") diff --git a/examples/serious_mnist.py b/examples/serious_mnist.py index b0c4c69ae..5a9f8d3ea 100644 --- a/examples/serious_mnist.py +++ b/examples/serious_mnist.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb +# inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb import sys import numpy as np from tinygrad.nn.state import get_parameters @@ -9,128 +9,144 @@ from tinygrad.helpers import getenv from extra.datasets import fetch_mnist from extra.augment import augment_img from extra.training import train, evaluate + GPU = getenv("GPU") QUICK = getenv("QUICK") DEBUG = getenv("DEBUG") -class SqueezeExciteBlock2D: - def __init__(self, filters): - self.filters = filters - self.weight1 = Tensor.scaled_uniform(self.filters, self.filters//32) - self.bias1 = Tensor.scaled_uniform(1,self.filters//32) - self.weight2 = Tensor.scaled_uniform(self.filters//32, self.filters) - self.bias2 = Tensor.scaled_uniform(1, self.filters) - def __call__(self, input): - se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3])) #GlobalAveragePool2D - se = se.reshape(shape=(-1, self.filters)) - se = se.dot(self.weight1) + self.bias1 - se = se.relu() - se = se.dot(self.weight2) + self.bias2 - se = se.sigmoid().reshape(shape=(-1,self.filters,1,1)) #for broadcasting - se = input.mul(se) - return se +class SqueezeExciteBlock2D: + def __init__(self, filters): + self.filters = filters + self.weight1 = Tensor.scaled_uniform(self.filters, self.filters // 32) + self.bias1 = Tensor.scaled_uniform(1, self.filters // 32) + self.weight2 = Tensor.scaled_uniform(self.filters // 32, self.filters) + self.bias2 = Tensor.scaled_uniform(1, self.filters) + + def __call__(self, input): + se = input.avg_pool2d( + kernel_size=(input.shape[2], input.shape[3]) + ) # GlobalAveragePool2D + se = se.reshape(shape=(-1, self.filters)) + se = se.dot(self.weight1) + self.bias1 + se = se.relu() + se = se.dot(self.weight2) + self.bias2 + se = se.sigmoid().reshape(shape=(-1, self.filters, 1, 1)) # for broadcasting + se = input.mul(se) + return se + class ConvBlock: - def __init__(self, h, w, inp, filters=128, conv=3): - self.h, self.w = h, w - self.inp = inp - #init weights - self.cweights = [Tensor.scaled_uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)] - self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)] - #init layers - self._bn = BatchNorm2d(128) - self._seb = SqueezeExciteBlock2D(filters) + def __init__(self, h, w, inp, filters=128, conv=3): + self.h, self.w = h, w + self.inp = inp + # init weights + self.cweights = [ + Tensor.scaled_uniform(filters, inp if i == 0 else filters, conv, conv) + for i in range(3) + ] + self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)] + # init layers + self._bn = BatchNorm2d(128) + self._seb = SqueezeExciteBlock2D(filters) + + def __call__(self, input): + x = input.reshape(shape=(-1, self.inp, self.w, self.h)) + for cweight, cbias in zip(self.cweights, self.cbiases): + x = x.pad2d(padding=[1, 1, 1, 1]).conv2d(cweight).add(cbias).relu() + x = self._bn(x) + x = self._seb(x) + return x - def __call__(self, input): - x = input.reshape(shape=(-1, self.inp, self.w, self.h)) - for cweight, cbias in zip(self.cweights, self.cbiases): - x = x.pad2d(padding=[1,1,1,1]).conv2d(cweight).add(cbias).relu() - x = self._bn(x) - x = self._seb(x) - return x class BigConvNet: - def __init__(self): - self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)] - self.weight1 = Tensor.scaled_uniform(128,10) - self.weight2 = Tensor.scaled_uniform(128,10) + def __init__(self): + self.conv = [ + ConvBlock(28, 28, 1), + ConvBlock(28, 28, 128), + ConvBlock(14, 14, 128), + ] + self.weight1 = Tensor.scaled_uniform(128, 10) + self.weight2 = Tensor.scaled_uniform(128, 10) - def parameters(self): - if DEBUG: #keeping this for a moment - pars = [par for par in get_parameters(self) if par.requires_grad] - no_pars = 0 - for par in pars: - print(par.shape) - no_pars += np.prod(par.shape) - print('no of parameters', no_pars) - return pars - else: - return get_parameters(self) + def parameters(self): + if DEBUG: # keeping this for a moment + pars = [par for par in get_parameters(self) if par.requires_grad] + no_pars = 0 + for par in pars: + print(par.shape) + no_pars += np.prod(par.shape) + print("no of parameters", no_pars) + return pars + else: + return get_parameters(self) - def save(self, filename): - with open(filename+'.npy', 'wb') as f: - for par in get_parameters(self): - #if par.requires_grad: - np.save(f, par.numpy()) + def save(self, filename): + with open(filename + ".npy", "wb") as f: + for par in get_parameters(self): + # if par.requires_grad: + np.save(f, par.numpy()) - def load(self, filename): - with open(filename+'.npy', 'rb') as f: - for par in get_parameters(self): - #if par.requires_grad: - try: - par.numpy()[:] = np.load(f) - if GPU: - par.gpu() - except: - print('Could not load parameter') + def load(self, filename): + with open(filename + ".npy", "rb") as f: + for par in get_parameters(self): + # if par.requires_grad: + try: + par.numpy()[:] = np.load(f) + if GPU: + par.gpu() + except: + print("Could not load parameter") - def forward(self, x): - x = self.conv[0](x) - x = self.conv[1](x) - x = x.avg_pool2d(kernel_size=(2,2)) - x = self.conv[2](x) - x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global - x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global - xo = x1.dot(self.weight1) + x2.dot(self.weight2) - return xo + def forward(self, x): + x = self.conv[0](x) + x = self.conv[1](x) + x = x.avg_pool2d(kernel_size=(2, 2)) + x = self.conv[2](x) + x1 = x.avg_pool2d(kernel_size=(14, 14)).reshape(shape=(-1, 128)) # global + x2 = x.max_pool2d(kernel_size=(14, 14)).reshape(shape=(-1, 128)) # global + xo = x1.dot(self.weight1) + x2.dot(self.weight2) + return xo if __name__ == "__main__": - lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5] - epochss = [2, 1] if QUICK else [13, 3, 3, 1] - BS = 32 + lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5] + epochss = [2, 1] if QUICK else [13, 3, 3, 1] + BS = 32 - lmbd = 0.00025 - lossfn = lambda out,y: out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum() - X_train, Y_train, X_test, Y_test = fetch_mnist() - X_train = X_train.reshape(-1, 28, 28).astype(np.uint8) - X_test = X_test.reshape(-1, 28, 28).astype(np.uint8) - steps = len(X_train)//BS - np.random.seed(1337) - if QUICK: - steps = 1 - X_test, Y_test = X_test[:BS], Y_test[:BS] + lmbd = 0.00025 + lossfn = ( + lambda out, y: out.sparse_categorical_crossentropy(y) + + lmbd * (model.weight1.abs() + model.weight2.abs()).sum() + ) + X_train, Y_train, X_test, Y_test = fetch_mnist() + X_train = X_train.reshape(-1, 28, 28).astype(np.uint8) + X_test = X_test.reshape(-1, 28, 28).astype(np.uint8) + steps = len(X_train) // BS + np.random.seed(1337) + if QUICK: + steps = 1 + X_test, Y_test = X_test[:BS], Y_test[:BS] - model = BigConvNet() + model = BigConvNet() - if len(sys.argv) > 1: - try: - model.load(sys.argv[1]) - print('Loaded weights "'+sys.argv[1]+'", evaluating...') - evaluate(model, X_test, Y_test, BS=BS) - except: - print('could not load weights "'+sys.argv[1]+'".') + if len(sys.argv) > 1: + try: + model.load(sys.argv[1]) + print('Loaded weights "' + sys.argv[1] + '", evaluating...') + evaluate(model, X_test, Y_test, BS=BS) + except: + print('could not load weights "' + sys.argv[1] + '".') - if GPU: - params = get_parameters(model) - [x.gpu_() for x in params] + if GPU: + params = get_parameters(model) + [x.gpu_() for x in params] - for lr, epochs in zip(lrs, epochss): - optimizer = optim.Adam(model.parameters(), lr=lr) - for epoch in range(1,epochs+1): - #first epoch without augmentation - X_aug = X_train if epoch == 1 else augment_img(X_train) - train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS) - accuracy = evaluate(model, X_test, Y_test, BS=BS) - model.save(f'examples/checkpoint{accuracy * 1e6:.0f}') + for lr, epochs in zip(lrs, epochss): + optimizer = optim.Adam(model.parameters(), lr=lr) + for epoch in range(1, epochs + 1): + # first epoch without augmentation + X_aug = X_train if epoch == 1 else augment_img(X_train) + train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS) + accuracy = evaluate(model, X_test, Y_test, BS=BS) + model.save(f"examples/checkpoint{accuracy * 1e6:.0f}") diff --git a/examples/simple_conv_bn.py b/examples/simple_conv_bn.py index 7d5add4da..58137d780 100644 --- a/examples/simple_conv_bn.py +++ b/examples/simple_conv_bn.py @@ -5,15 +5,15 @@ from tinygrad.nn import Conv2d, BatchNorm2d from tinygrad.nn.state import get_parameters if __name__ == "__main__": - with Tensor.train(): + with Tensor.train(): + BS, C1, H, W = 4, 16, 224, 224 + C2, K, S, P = 64, 7, 2, 1 - BS, C1, H, W = 4, 16, 224, 224 - C2, K, S, P = 64, 7, 2, 1 + x = Tensor.uniform(BS, C1, H, W) + conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P) + bn = BatchNorm2d(C2, track_running_stats=False) + for t in get_parameters([x, conv, bn]): + t.realize() - x = Tensor.uniform(BS, C1, H, W) - conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P) - bn = BatchNorm2d(C2, track_running_stats=False) - for t in get_parameters([x, conv, bn]): t.realize() - - print("running network") - x.sequential([conv, bn]).numpy() + print("running network") + x.sequential([conv, bn]).numpy() diff --git a/examples/so_vits_svc.py b/examples/so_vits_svc.py index 3dfa4c2ff..af347a870 100644 --- a/examples/so_vits_svc.py +++ b/examples/so_vits_svc.py @@ -8,7 +8,21 @@ from tinygrad import nn from tinygrad.tensor import Tensor from tinygrad.helpers import dtypes, getenv from tinygrad.nn.state import torch_load -from examples.vits import ResidualCouplingBlock, PosteriorEncoder, Encoder, ResBlock1, ResBlock2, LRELU_SLOPE, sequence_mask, split, download_if_not_present, get_hparams_from_file, load_checkpoint, weight_norm, HParams +from examples.vits import ( + ResidualCouplingBlock, + PosteriorEncoder, + Encoder, + ResBlock1, + ResBlock2, + LRELU_SLOPE, + sequence_mask, + split, + download_if_not_present, + get_hparams_from_file, + load_checkpoint, + weight_norm, + HParams, +) from examples.sovits_helpers import preprocess import soundfile @@ -20,519 +34,1109 @@ F0_MIN = 50.0 F0_MEL_MIN = 1127 * np.log(1 + F0_MIN / 700) F0_MEL_MAX = 1127 * np.log(1 + F0_MAX / 700) + class SpeechEncoder: - def __init__(self, hidden_dim, model:ContentVec): self.hidden_dim, self.model = hidden_dim, model - def encode(self, ): raise NotImplementedError("implement me") - @classmethod - def load_from_pretrained(cls, checkpoint_path:str, checkpoint_url:str) -> ContentVec: - contentvec = ContentVec.load_from_pretrained(checkpoint_path, checkpoint_url) - return cls(contentvec) + def __init__(self, hidden_dim, model: ContentVec): + self.hidden_dim, self.model = hidden_dim, model + + def encode( + self, + ): + raise NotImplementedError("implement me") + + @classmethod + def load_from_pretrained( + cls, checkpoint_path: str, checkpoint_url: str + ) -> ContentVec: + contentvec = ContentVec.load_from_pretrained(checkpoint_path, checkpoint_url) + return cls(contentvec) + class ContentVec256L9(SpeechEncoder): - def __init__(self, model:ContentVec): super().__init__(hidden_dim=256, model=model) - def encode(self, wav: Tensor): - feats = wav - if len(feats.shape) == 2: # double channels - feats = feats.mean(-1) - assert len(feats.shape) == 1, feats.dim() - feats = feats.reshape(1, -1) - padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool) - logits = self.model.extract_features(feats.to(wav.device), padding_mask=padding_mask.to(wav.device), output_layer=9) - feats = self.model.final_proj(logits[0]) - return feats.transpose(1,2) + def __init__(self, model: ContentVec): + super().__init__(hidden_dim=256, model=model) + + def encode(self, wav: Tensor): + feats = wav + if len(feats.shape) == 2: # double channels + feats = feats.mean(-1) + assert len(feats.shape) == 1, feats.dim() + feats = feats.reshape(1, -1) + padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool) + logits = self.model.extract_features( + feats.to(wav.device), + padding_mask=padding_mask.to(wav.device), + output_layer=9, + ) + feats = self.model.final_proj(logits[0]) + return feats.transpose(1, 2) + class ContentVec768L12(SpeechEncoder): - def __init__(self, model:ContentVec): super().__init__(hidden_dim=768, model=model) - def encode(self, wav: Tensor): - feats = wav - if len(feats.shape) == 2: # double channels - feats = feats.mean(-1) - assert len(feats.shape) == 1, feats.dim() - feats = feats.reshape(1, -1) - padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool) - logits = self.model.extract_features(feats.to(wav.device), padding_mask=padding_mask.to(wav.device), output_layer=12) - return logits[0].transpose(1,2) + def __init__(self, model: ContentVec): + super().__init__(hidden_dim=768, model=model) + + def encode(self, wav: Tensor): + feats = wav + if len(feats.shape) == 2: # double channels + feats = feats.mean(-1) + assert len(feats.shape) == 1, feats.dim() + feats = feats.reshape(1, -1) + padding_mask = Tensor.zeros_like(feats).cast(dtypes.bool) + logits = self.model.extract_features( + feats.to(wav.device), + padding_mask=padding_mask.to(wav.device), + output_layer=12, + ) + return logits[0].transpose(1, 2) + # original code for contentvec: https://github.com/auspicious3000/contentvec/ class ContentVec: - # self.final_proj dims are hardcoded and depend on fairseq.data.dictionary Dictionary in the checkpoint. This param can't yet be loaded since there is no pickle for it. See with DEBUG=2. - # This means that the ContentVec only works with the hubert weights used in all SVC models - def __init__(self, cfg: HParams): - self.feature_grad_mult, self.untie_final_proj = cfg.feature_grad_mult, cfg.untie_final_proj - feature_enc_layers = eval(cfg.conv_feature_layers) - self.embed = feature_enc_layers[-1][0] - final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim - self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias) - self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None - self.encoder = TransformerEncoder(cfg) - self.layer_norm = nn.LayerNorm(self.embed) - self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim * 1) if self.untie_final_proj else nn.Linear(cfg.encoder_embed_dim, final_dim) - self.mask_emb = Tensor.uniform(cfg.encoder_embed_dim, dtype=dtypes.float32) - self.label_embs_concat = Tensor.uniform(504, final_dim, dtype=dtypes.float32) - def forward_features(self, source, padding_mask): - if self.feature_grad_mult > 0: - features = self.feature_extractor(source, padding_mask) - if self.feature_grad_mult != 1.0: pass # training: GradMultiply.forward(features, self.feature_grad_mult) - else: - features = self.feature_extractor(source, padding_mask) - return features - def forward_padding_mask(self, features, padding_mask): # replaces original forward_padding_mask for batch inference - lengths_org = tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1) # ensure its bool for tilde - lengths = (lengths_org - 400).float().div(320).floor().cast(dtypes.int64) + 1 # intermediate float to divide - padding_mask = lengths_to_padding_mask(lengths) - return padding_mask - def extract_features(self, source: Tensor, spk_emb:Tensor=None, padding_mask=None, ret_conv=False, output_layer=None, tap=False): - features = self.forward_features(source, padding_mask) - if padding_mask is not None: - padding_mask = self.forward_padding_mask(features, padding_mask) - features = features.transpose(1, 2) - features = self.layer_norm(features) - if self.post_extract_proj is not None: - features = self.post_extract_proj(features) - x, _ = self.encoder(features, spk_emb, padding_mask=padding_mask, layer=(None if output_layer is None else output_layer - 1), tap=tap) - res = features if ret_conv else x - return res, padding_mask - @classmethod - def load_from_pretrained(cls, checkpoint_path:str, checkpoint_url:str) -> ContentVec: - download_if_not_present(checkpoint_path, checkpoint_url) - cfg = load_fairseq_cfg(checkpoint_path) - enc = cls(cfg.model) - _ = load_checkpoint_enc(checkpoint_path, enc, None) - logging.debug(f"{cls.__name__}: Loaded model with cfg={cfg}") - return enc + # self.final_proj dims are hardcoded and depend on fairseq.data.dictionary Dictionary in the checkpoint. This param can't yet be loaded since there is no pickle for it. See with DEBUG=2. + # This means that the ContentVec only works with the hubert weights used in all SVC models + def __init__(self, cfg: HParams): + self.feature_grad_mult, self.untie_final_proj = ( + cfg.feature_grad_mult, + cfg.untie_final_proj, + ) + feature_enc_layers = eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + self.encoder = TransformerEncoder(cfg) + self.layer_norm = nn.LayerNorm(self.embed) + self.final_proj = ( + nn.Linear(cfg.encoder_embed_dim, final_dim * 1) + if self.untie_final_proj + else nn.Linear(cfg.encoder_embed_dim, final_dim) + ) + self.mask_emb = Tensor.uniform(cfg.encoder_embed_dim, dtype=dtypes.float32) + self.label_embs_concat = Tensor.uniform(504, final_dim, dtype=dtypes.float32) + + def forward_features(self, source, padding_mask): + if self.feature_grad_mult > 0: + features = self.feature_extractor(source, padding_mask) + if self.feature_grad_mult != 1.0: + pass # training: GradMultiply.forward(features, self.feature_grad_mult) + else: + features = self.feature_extractor(source, padding_mask) + return features + + def forward_padding_mask( + self, features, padding_mask + ): # replaces original forward_padding_mask for batch inference + lengths_org = ( + tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1) + ) # ensure its bool for tilde + lengths = (lengths_org - 400).float().div(320).floor().cast( + dtypes.int64 + ) + 1 # intermediate float to divide + padding_mask = lengths_to_padding_mask(lengths) + return padding_mask + + def extract_features( + self, + source: Tensor, + spk_emb: Tensor = None, + padding_mask=None, + ret_conv=False, + output_layer=None, + tap=False, + ): + features = self.forward_features(source, padding_mask) + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + features = features.transpose(1, 2) + features = self.layer_norm(features) + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + x, _ = self.encoder( + features, + spk_emb, + padding_mask=padding_mask, + layer=(None if output_layer is None else output_layer - 1), + tap=tap, + ) + res = features if ret_conv else x + return res, padding_mask + + @classmethod + def load_from_pretrained( + cls, checkpoint_path: str, checkpoint_url: str + ) -> ContentVec: + download_if_not_present(checkpoint_path, checkpoint_url) + cfg = load_fairseq_cfg(checkpoint_path) + enc = cls(cfg.model) + _ = load_checkpoint_enc(checkpoint_path, enc, None) + logging.debug(f"{cls.__name__}: Loaded model with cfg={cfg}") + return enc + class TransformerEncoder: - def __init__(self, cfg: HParams): - def make_conv() -> nn.Conv1d: - layer = nn.Conv1d(self.embedding_dim, self.embedding_dim, kernel_size=cfg.conv_pos, padding=cfg.conv_pos // 2, groups=cfg.conv_pos_groups) - std = std = math.sqrt(4 / (cfg.conv_pos * self.embedding_dim)) - layer.weight, layer.bias = (Tensor.normal(*layer.weight.shape, std=std)), (Tensor.zeros(*layer.bias.shape)) - # for training: layer.weights need to be weight_normed - return layer - self.dropout, self.embedding_dim, self.layer_norm_first, self.layerdrop, self.num_layers, self.num_layers_1 = cfg.dropout, cfg.encoder_embed_dim, cfg.layer_norm_first, cfg.encoder_layerdrop, cfg.encoder_layers, cfg.encoder_layers_1 - self.pos_conv, self.pos_conv_remove = [make_conv()], (1 if cfg.conv_pos % 2 == 0 else 0) - self.layers = [ - TransformerEncoderLayer(self.embedding_dim, cfg.encoder_ffn_embed_dim, cfg.encoder_attention_heads, self.dropout, cfg.attention_dropout, cfg.activation_dropout, cfg.activation_fn, self.layer_norm_first, cond_layer_norm=(i >= cfg.encoder_layers)) - for i in range(cfg.encoder_layers + cfg.encoder_layers_1) - ] - self.layer_norm = nn.LayerNorm(self.embedding_dim) - self.cond_layer_norm = CondLayerNorm(self.embedding_dim) if cfg.encoder_layers_1 > 0 else None - # training: apply init_bert_params - def __call__(self, x, spk_emb, padding_mask=None, layer=None, tap=False): - x, layer_results = self.extract_features(x, spk_emb, padding_mask, layer, tap) - if self.layer_norm_first and layer is None: - x = self.cond_layer_norm(x, spk_emb) if (self.num_layers_1 > 0) else self.layer_norm(x) - return x, layer_results - def extract_features(self, x: Tensor, spk_emb: Tensor, padding_mask=None, tgt_layer=None, tap=False): - if tgt_layer is not None: # and not self.training - assert tgt_layer >= 0 and tgt_layer < len(self.layers) - if padding_mask is not None: - # x[padding_mask] = 0 - assert padding_mask.shape == x.shape[:len(padding_mask.shape)] # first few dims of x must match padding_mask - tmp_mask = padding_mask.unsqueeze(-1).repeat((1, 1, x.shape[-1])) - tmp_mask = tilde(tmp_mask.cast(dtypes.bool)) - x = tmp_mask.where(x, 0) - x_conv = self.pos_conv[0](x.transpose(1,2)) - if self.pos_conv_remove > 0: x_conv = x_conv[:, :, : -self.pos_conv_remove] - x_conv = x_conv.gelu().transpose(1, 2) - x = (x + x_conv).transpose(0, 1) # B x T x C -> T x B x C - if not self.layer_norm_first: x = self.layer_norm(x) - x = x.dropout(p=self.dropout) - layer_results = [] - r = None - for i, layer in enumerate(self.layers): - if i < self.num_layers: # if (not self.training or (dropout_probability > self.layerdrop)) and (i < self.num_layers): - assert layer.cond_layer_norm == False - x = layer(x, self_attn_padding_mask=padding_mask, need_weights=False) - if tgt_layer is not None or tap: - layer_results.append(x.transpose(0, 1)) - if i>= self.num_layers: - assert layer.cond_layer_norm == True - x = layer(x, emb=spk_emb, self_attn_padding_mask=padding_mask, need_weights=False) - if i == tgt_layer: - r = x - break - if r is not None: - x = r - x = x.transpose(0, 1) # T x B x C -> B x T x C - return x, layer_results + def __init__(self, cfg: HParams): + def make_conv() -> nn.Conv1d: + layer = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=cfg.conv_pos, + padding=cfg.conv_pos // 2, + groups=cfg.conv_pos_groups, + ) + std = std = math.sqrt(4 / (cfg.conv_pos * self.embedding_dim)) + layer.weight, layer.bias = (Tensor.normal(*layer.weight.shape, std=std)), ( + Tensor.zeros(*layer.bias.shape) + ) + # for training: layer.weights need to be weight_normed + return layer + + ( + self.dropout, + self.embedding_dim, + self.layer_norm_first, + self.layerdrop, + self.num_layers, + self.num_layers_1, + ) = ( + cfg.dropout, + cfg.encoder_embed_dim, + cfg.layer_norm_first, + cfg.encoder_layerdrop, + cfg.encoder_layers, + cfg.encoder_layers_1, + ) + self.pos_conv, self.pos_conv_remove = [make_conv()], ( + 1 if cfg.conv_pos % 2 == 0 else 0 + ) + self.layers = [ + TransformerEncoderLayer( + self.embedding_dim, + cfg.encoder_ffn_embed_dim, + cfg.encoder_attention_heads, + self.dropout, + cfg.attention_dropout, + cfg.activation_dropout, + cfg.activation_fn, + self.layer_norm_first, + cond_layer_norm=(i >= cfg.encoder_layers), + ) + for i in range(cfg.encoder_layers + cfg.encoder_layers_1) + ] + self.layer_norm = nn.LayerNorm(self.embedding_dim) + self.cond_layer_norm = ( + CondLayerNorm(self.embedding_dim) if cfg.encoder_layers_1 > 0 else None + ) + # training: apply init_bert_params + + def __call__(self, x, spk_emb, padding_mask=None, layer=None, tap=False): + x, layer_results = self.extract_features(x, spk_emb, padding_mask, layer, tap) + if self.layer_norm_first and layer is None: + x = ( + self.cond_layer_norm(x, spk_emb) + if (self.num_layers_1 > 0) + else self.layer_norm(x) + ) + return x, layer_results + + def extract_features( + self, x: Tensor, spk_emb: Tensor, padding_mask=None, tgt_layer=None, tap=False + ): + if tgt_layer is not None: # and not self.training + assert tgt_layer >= 0 and tgt_layer < len(self.layers) + if padding_mask is not None: + # x[padding_mask] = 0 + assert ( + padding_mask.shape == x.shape[: len(padding_mask.shape)] + ) # first few dims of x must match padding_mask + tmp_mask = padding_mask.unsqueeze(-1).repeat((1, 1, x.shape[-1])) + tmp_mask = tilde(tmp_mask.cast(dtypes.bool)) + x = tmp_mask.where(x, 0) + x_conv = self.pos_conv[0](x.transpose(1, 2)) + if self.pos_conv_remove > 0: + x_conv = x_conv[:, :, : -self.pos_conv_remove] + x_conv = x_conv.gelu().transpose(1, 2) + x = (x + x_conv).transpose(0, 1) # B x T x C -> T x B x C + if not self.layer_norm_first: + x = self.layer_norm(x) + x = x.dropout(p=self.dropout) + layer_results = [] + r = None + for i, layer in enumerate(self.layers): + if ( + i < self.num_layers + ): # if (not self.training or (dropout_probability > self.layerdrop)) and (i < self.num_layers): + assert layer.cond_layer_norm == False + x = layer(x, self_attn_padding_mask=padding_mask, need_weights=False) + if tgt_layer is not None or tap: + layer_results.append(x.transpose(0, 1)) + if i >= self.num_layers: + assert layer.cond_layer_norm == True + x = layer( + x, + emb=spk_emb, + self_attn_padding_mask=padding_mask, + need_weights=False, + ) + if i == tgt_layer: + r = x + break + if r is not None: + x = r + x = x.transpose(0, 1) # T x B x C -> B x T x C + return x, layer_results + class TransformerEncoderLayer: - def __init__(self, embedding_dim=768.0, ffn_embedding_dim=3072.0, num_attention_heads=8.0, dropout=0.1, attention_dropout=0.1, activation_dropout=0.1, activation_fn="relu", layer_norm_first=False, cond_layer_norm=False): - def get_activation_fn(activation): - if activation == "relu": return Tensor.relu - if activation == "gelu": return Tensor.gelu - else: raise RuntimeError(f"activation function={activation} is not forseen") - self.embedding_dim, self.dropout, self.activation_dropout, self.layer_norm_first, self.num_attention_heads, self.cond_layer_norm, self.activation_fn = embedding_dim, dropout, activation_dropout, layer_norm_first, num_attention_heads, cond_layer_norm, get_activation_fn(activation_fn) - self.self_attn = MultiHeadAttention(self.embedding_dim, self.num_attention_heads) - self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim) if not cond_layer_norm else CondLayerNorm(self.embedding_dim) - self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) - self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) - self.final_layer_norm = nn.LayerNorm(self.embedding_dim) if not cond_layer_norm else CondLayerNorm(self.embedding_dim) - def __call__(self, x:Tensor, self_attn_mask:Tensor=None, self_attn_padding_mask:Tensor=None, emb:Tensor=None, need_weights=False): - #self_attn_padding_mask = self_attn_padding_mask.reshape(x.shape[0], 1, 1, self_attn_padding_mask.shape[1]).expand(-1, self.num_attention_heads, -1, -1).reshape(x.shape[0] * self.num_attention_heads, 1, self_attn_padding_mask.shape[1]) if self_attn_padding_mask is not None else None - assert self_attn_mask is None and self_attn_padding_mask is not None - residual = x - if self.layer_norm_first: - x = self.self_attn_layer_norm(x) if not self.cond_layer_norm else self.self_attn_layer_norm(x, emb) - x = self.self_attn(x=x, mask=self_attn_padding_mask) - x = x.dropout(self.dropout) - x = residual + x - x = self.final_layer_norm(x) if not self.cond_layer_norm else self.final_layer_norm(x, emb) - x = self.activation_fn(self.fc1(x)) - x = x.dropout(self.activation_dropout) - x = self.fc2(x) - x = x.dropout(self.dropout) - x = residual + x - else: - x = self.self_attn(x=x, mask=self_attn_padding_mask) - x = x.dropout(self.dropout) - x = residual + x - x = self.self_attn_layer_norm(x) if not self.cond_layer_norm else self.self_attn_layer_norm(x, emb) - residual = x - x = self.activation_fn(self.fc1(x)) - x = x.dropout(self.activation_dropout) - x = self.fc2(x) - x = x.dropout(self.dropout) - x = residual + x - x = self.final_layer_norm(x) if not self.cond_layer_norm else self.final_layer_norm(x, emb) - return x + def __init__( + self, + embedding_dim=768.0, + ffn_embedding_dim=3072.0, + num_attention_heads=8.0, + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.1, + activation_fn="relu", + layer_norm_first=False, + cond_layer_norm=False, + ): + def get_activation_fn(activation): + if activation == "relu": + return Tensor.relu + if activation == "gelu": + return Tensor.gelu + else: + raise RuntimeError(f"activation function={activation} is not forseen") + + ( + self.embedding_dim, + self.dropout, + self.activation_dropout, + self.layer_norm_first, + self.num_attention_heads, + self.cond_layer_norm, + self.activation_fn, + ) = ( + embedding_dim, + dropout, + activation_dropout, + layer_norm_first, + num_attention_heads, + cond_layer_norm, + get_activation_fn(activation_fn), + ) + self.self_attn = MultiHeadAttention( + self.embedding_dim, self.num_attention_heads + ) + self.self_attn_layer_norm = ( + nn.LayerNorm(self.embedding_dim) + if not cond_layer_norm + else CondLayerNorm(self.embedding_dim) + ) + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + self.final_layer_norm = ( + nn.LayerNorm(self.embedding_dim) + if not cond_layer_norm + else CondLayerNorm(self.embedding_dim) + ) + + def __call__( + self, + x: Tensor, + self_attn_mask: Tensor = None, + self_attn_padding_mask: Tensor = None, + emb: Tensor = None, + need_weights=False, + ): + # self_attn_padding_mask = self_attn_padding_mask.reshape(x.shape[0], 1, 1, self_attn_padding_mask.shape[1]).expand(-1, self.num_attention_heads, -1, -1).reshape(x.shape[0] * self.num_attention_heads, 1, self_attn_padding_mask.shape[1]) if self_attn_padding_mask is not None else None + assert self_attn_mask is None and self_attn_padding_mask is not None + residual = x + if self.layer_norm_first: + x = ( + self.self_attn_layer_norm(x) + if not self.cond_layer_norm + else self.self_attn_layer_norm(x, emb) + ) + x = self.self_attn(x=x, mask=self_attn_padding_mask) + x = x.dropout(self.dropout) + x = residual + x + x = ( + self.final_layer_norm(x) + if not self.cond_layer_norm + else self.final_layer_norm(x, emb) + ) + x = self.activation_fn(self.fc1(x)) + x = x.dropout(self.activation_dropout) + x = self.fc2(x) + x = x.dropout(self.dropout) + x = residual + x + else: + x = self.self_attn(x=x, mask=self_attn_padding_mask) + x = x.dropout(self.dropout) + x = residual + x + x = ( + self.self_attn_layer_norm(x) + if not self.cond_layer_norm + else self.self_attn_layer_norm(x, emb) + ) + residual = x + x = self.activation_fn(self.fc1(x)) + x = x.dropout(self.activation_dropout) + x = self.fc2(x) + x = x.dropout(self.dropout) + x = residual + x + x = ( + self.final_layer_norm(x) + if not self.cond_layer_norm + else self.final_layer_norm(x, emb) + ) + return x + class MultiHeadAttention: - def __init__(self, n_state, n_head): - self.n_state, self.n_head = n_state, n_head - self.q_proj, self.k_proj, self.v_proj, self.out_proj = [nn.Linear(n_state, n_state) for _ in range(4)] - def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None): - x = x.transpose(0,1) # TxBxC -> BxTxC - q, k, v = self.q_proj(x), self.k_proj(xa or x), self.v_proj(xa or x) - q, k, v = [x.reshape(*q.shape[:2], self.n_head, -1) for x in (q, k, v)] - wv = Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), None).transpose(1, 2).reshape(*x.shape[:2], -1) - ret = self.out_proj(wv).transpose(0,1) # BxTxC -> TxBxC - return ret + def __init__(self, n_state, n_head): + self.n_state, self.n_head = n_state, n_head + self.q_proj, self.k_proj, self.v_proj, self.out_proj = [ + nn.Linear(n_state, n_state) for _ in range(4) + ] + + def __call__( + self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None + ): + x = x.transpose(0, 1) # TxBxC -> BxTxC + q, k, v = self.q_proj(x), self.k_proj(xa or x), self.v_proj(xa or x) + q, k, v = [x.reshape(*q.shape[:2], self.n_head, -1) for x in (q, k, v)] + wv = ( + Tensor.scaled_dot_product_attention( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), None + ) + .transpose(1, 2) + .reshape(*x.shape[:2], -1) + ) + ret = self.out_proj(wv).transpose(0, 1) # BxTxC -> TxBxC + return ret + class ConvFeatureExtractionModel: - def __init__(self, conv_layers, dropout=.0, mode="default", conv_bias=False): - assert mode in {"default", "group_norm_masked", "layer_norm"} - def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False): - def make_conv(): - conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) - conv.weight = Tensor.kaiming_normal(*conv.weight.shape) - return conv - assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive" - if is_layer_norm: - return [make_conv(), partial(Tensor.dropout, p=dropout),[partial(Tensor.transpose, ax1=-2, ax2=-1), nn.LayerNorm(dim, elementwise_affine=True), partial(Tensor.transpose, ax1=-2, ax2=-1)], Tensor.gelu] - elif is_group_norm and mode == "default": - return [make_conv(), partial(Tensor.dropout, p=dropout), nn.GroupNorm(dim, dim, affine=True), Tensor.gelu] - elif is_group_norm and mode == "group_norm_masked": - return [make_conv(), partial(Tensor.dropout, p=dropout), GroupNormMasked(dim, dim, affine=True), Tensor.gelu] - else: - return [make_conv(), partial(Tensor.dropout, p=dropout), Tensor.gelu] - in_d, self.conv_layers, self.mode = 1, [], mode - for i, cl in enumerate(conv_layers): - assert len(cl) == 3, "invalid conv definition: " + str(cl) - (dim, k, stride) = cl - if i == 0: self.cl = cl - self.conv_layers.append(block(in_d, dim, k, stride, is_layer_norm=(mode == "layer_norm"), is_group_norm=((mode == "default" or mode == "group_norm_masked") and i == 0), conv_bias=conv_bias)) - in_d = dim - def __call__(self, x:Tensor, padding_mask:Tensor): - x = x.unsqueeze(1) # BxT -> BxCxT - if self.mode == "group_norm_masked": - if padding_mask is not None: - _, k, stride = self.cl - lengths_org = tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1) # ensure padding_mask is bool for tilde - lengths = (((lengths_org - k) / stride) + 1).floor().cast(dtypes.int64) - padding_mask = tilde(lengths_to_padding_mask(lengths)).cast(dtypes.int64) # lengths_to_padding_mask returns bool tensor - x = self.conv_layers[0][0](x) # padding_mask is numeric - x = self.conv_layers[0][1](x) - x = self.conv_layers[0][2](x, padding_mask) - x = self.conv_layers[0][3](x) - else: - x = x.sequential(self.conv_layers[0]) # default - for _, conv in enumerate(self.conv_layers[1:], start=1): - conv = reduce(lambda a,b: operator.iconcat(a,b if isinstance(b, list) else [b]), conv, []) # flatten - x = x.sequential(conv) - return x + def __init__(self, conv_layers, dropout=0.0, mode="default", conv_bias=False): + assert mode in {"default", "group_norm_masked", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + conv.weight = Tensor.kaiming_normal(*conv.weight.shape) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + if is_layer_norm: + return [ + make_conv(), + partial(Tensor.dropout, p=dropout), + [ + partial(Tensor.transpose, ax1=-2, ax2=-1), + nn.LayerNorm(dim, elementwise_affine=True), + partial(Tensor.transpose, ax1=-2, ax2=-1), + ], + Tensor.gelu, + ] + elif is_group_norm and mode == "default": + return [ + make_conv(), + partial(Tensor.dropout, p=dropout), + nn.GroupNorm(dim, dim, affine=True), + Tensor.gelu, + ] + elif is_group_norm and mode == "group_norm_masked": + return [ + make_conv(), + partial(Tensor.dropout, p=dropout), + GroupNormMasked(dim, dim, affine=True), + Tensor.gelu, + ] + else: + return [make_conv(), partial(Tensor.dropout, p=dropout), Tensor.gelu] + + in_d, self.conv_layers, self.mode = 1, [], mode + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + if i == 0: + self.cl = cl + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=(mode == "layer_norm"), + is_group_norm=( + (mode == "default" or mode == "group_norm_masked") and i == 0 + ), + conv_bias=conv_bias, + ) + ) + in_d = dim + + def __call__(self, x: Tensor, padding_mask: Tensor): + x = x.unsqueeze(1) # BxT -> BxCxT + if self.mode == "group_norm_masked": + if padding_mask is not None: + _, k, stride = self.cl + lengths_org = ( + tilde(padding_mask.cast(dtypes.bool)).cast(dtypes.int64).sum(1) + ) # ensure padding_mask is bool for tilde + lengths = (((lengths_org - k) / stride) + 1).floor().cast(dtypes.int64) + padding_mask = tilde(lengths_to_padding_mask(lengths)).cast( + dtypes.int64 + ) # lengths_to_padding_mask returns bool tensor + x = self.conv_layers[0][0](x) # padding_mask is numeric + x = self.conv_layers[0][1](x) + x = self.conv_layers[0][2](x, padding_mask) + x = self.conv_layers[0][3](x) + else: + x = x.sequential(self.conv_layers[0]) # default + for _, conv in enumerate(self.conv_layers[1:], start=1): + conv = reduce( + lambda a, b: operator.iconcat(a, b if isinstance(b, list) else [b]), + conv, + [], + ) # flatten + x = x.sequential(conv) + return x + class CondLayerNorm: # https://github.com/auspicious3000/contentvec/blob/main/contentvec/modules/cond_layer_norm.py#L10 - def __init__(self, dim_last, eps=1e-5, dim_spk=256, elementwise_affine=True): - self.dim_last, self.eps, self.dim_spk, self.elementwise_affine = dim_last, eps, dim_spk, elementwise_affine - if self.elementwise_affine: - self.weight_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False) - self.bias_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False) - self.weight_ln.weight, self.bias_ln.weight = (Tensor.ones(*self.weight_ln.weight.shape)), (Tensor.zeros(*self.bias_ln.weight.shape)) - def __call__(self, x: Tensor, spk_emb: Tensor): - axis = tuple(-1-i for i in range(len(x.shape[1:]))) - x = x.layernorm(axis=axis, eps=self.eps) - if not self.elementwise_affine: return x - weights, bias = self.weight_ln(spk_emb), self.bias_ln(spk_emb) - return weights * x + bias + def __init__(self, dim_last, eps=1e-5, dim_spk=256, elementwise_affine=True): + self.dim_last, self.eps, self.dim_spk, self.elementwise_affine = ( + dim_last, + eps, + dim_spk, + elementwise_affine, + ) + if self.elementwise_affine: + self.weight_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False) + self.bias_ln = nn.Linear(self.dim_spk, self.dim_last, bias=False) + self.weight_ln.weight, self.bias_ln.weight = ( + Tensor.ones(*self.weight_ln.weight.shape) + ), (Tensor.zeros(*self.bias_ln.weight.shape)) + + def __call__(self, x: Tensor, spk_emb: Tensor): + axis = tuple(-1 - i for i in range(len(x.shape[1:]))) + x = x.layernorm(axis=axis, eps=self.eps) + if not self.elementwise_affine: + return x + weights, bias = self.weight_ln(spk_emb), self.bias_ln(spk_emb) + return weights * x + bias + class GroupNormMasked: # https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/modules/fp32_group_norm.py#L16 - def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): - self.num_groups, self.num_channels, self.eps, self.affine = num_groups, num_channels, eps, affine - self.weight, self.bias = (Tensor.ones(num_channels)), (Tensor.zeros(num_channels)) if self.affine else (None, None) - def __call__(self, x:Tensor, mask:Tensor): - bsz, n_c, length = x.shape - assert n_c % self.num_groups == 0 - x = x.reshape(bsz, self.num_groups, n_c // self.num_groups, length) - if mask is None: mask = Tensor.ones_like(x) - else: mask = mask.reshape(bsz, 1, 1, length) - x = x * mask - lengths = mask.sum(axis=3, keepdim=True) - assert x.shape[2] == 1 - mean_ = x.mean(dim=3, keepdim=True) - mean = mean_ * length / lengths - var = (((x.std(axis=3, keepdim=True) ** 2) + mean_**2) * length / lengths - mean**2) + self.eps - return x.add(-mean).div(var.sqrt()).reshape(bsz, n_c, length).mul(self.weight.reshape(1,-1,1)).add(self.bias.reshape(1,-1,1)) + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + self.num_groups, self.num_channels, self.eps, self.affine = ( + num_groups, + num_channels, + eps, + affine, + ) + self.weight, self.bias = (Tensor.ones(num_channels)), ( + Tensor.zeros(num_channels) + ) if self.affine else (None, None) + + def __call__(self, x: Tensor, mask: Tensor): + bsz, n_c, length = x.shape + assert n_c % self.num_groups == 0 + x = x.reshape(bsz, self.num_groups, n_c // self.num_groups, length) + if mask is None: + mask = Tensor.ones_like(x) + else: + mask = mask.reshape(bsz, 1, 1, length) + x = x * mask + lengths = mask.sum(axis=3, keepdim=True) + assert x.shape[2] == 1 + mean_ = x.mean(dim=3, keepdim=True) + mean = mean_ * length / lengths + var = ( + ((x.std(axis=3, keepdim=True) ** 2) + mean_**2) * length / lengths + - mean**2 + ) + self.eps + return ( + x.add(-mean) + .div(var.sqrt()) + .reshape(bsz, n_c, length) + .mul(self.weight.reshape(1, -1, 1)) + .add(self.bias.reshape(1, -1, 1)) + ) + class Synthesizer: - def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, ssl_dim, n_speakers, sampling_rate=44100, vol_embedding=False, n_flow_layer=4, **kwargs): - self.spec_channels, self.inter_channels, self.hidden_channels, self.filter_channels, self.n_heads, self.n_layers, self.kernel_size, self.p_dropout, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.segment_size, self.n_speakers, self.gin_channels, self.vol_embedding = spec_channels, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, segment_size, n_speakers, gin_channels, vol_embedding - self.emb_g = nn.Embedding(n_speakers, gin_channels) - if vol_embedding: self.emb_vol = nn.Linear(1, hidden_channels) - self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2) - self.enc_p = TextEncoder(inter_channels, hidden_channels, kernel_size, n_layers, filter_channels=filter_channels, n_heads=n_heads, p_dropout=p_dropout) - self.dec = Generator(sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels) - self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels) - self.emb_uv = nn.Embedding(vocab_size=2, embed_size=hidden_channels) - def infer(self, c:Tensor, f0:Tensor, uv:Tensor, g:Tensor=None, noise_scale=0.35, seed=52468, vol=None) -> Tuple[Tensor, Tensor]: - Tensor.manual_seed(getenv('SEED', seed)) - c_lengths = (Tensor.ones([c.shape[0]]) * c.shape[-1]).to(c.device) - if len(g.shape) == 1: g = g.unsqueeze(0) - g = self.emb_g(g).transpose(1, 2) - x_mask = sequence_mask(c_lengths, c.shape[2]).unsqueeze(1).cast(c.dtype) - vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0 - x = self.pre(c) * x_mask + self.emb_uv(uv.cast(dtypes.int64)).transpose(1, 2) + vol - z_p, _, _, c_mask = self.enc_p.forward(x, x_mask, f0=self._f0_to_coarse(f0), noise_scale=noise_scale) - z = self.flow.forward(z_p, c_mask, g=g, reverse=True) - o = self.dec.forward(z * c_mask, g=g, f0=f0) - return o,f0 - def _f0_to_coarse(self, f0 : Tensor): - f0_mel = 1127 * (1 + f0 / 700).log() - a = (F0_BIN - 2) / (F0_MEL_MAX - F0_MEL_MIN) - b = F0_MEL_MIN * a - 1. - f0_mel = (f0_mel > 0).where(f0_mel * a - b, f0_mel) - f0_coarse = f0_mel.ceil().cast(dtype=dtypes.int64) - f0_coarse = f0_coarse * (f0_coarse > 0) - f0_coarse = f0_coarse + ((f0_coarse < 1) * 1) - f0_coarse = f0_coarse * (f0_coarse < F0_BIN) - f0_coarse = f0_coarse + ((f0_coarse >= F0_BIN) * (F0_BIN - 1)) - return f0_coarse - @classmethod - def load_from_pretrained(cls, config_path:str, config_url:str, weights_path:str, weights_url:str) -> Synthesizer: - download_if_not_present(config_path, config_url) - hps = get_hparams_from_file(config_path) - download_if_not_present(weights_path, weights_url) - net_g = cls(hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, **hps.model) - _ = load_checkpoint(weights_path, net_g, None, skip_list=["f0_decoder"]) - logging.debug(f"{cls.__name__}:Loaded model with hps: {hps}") - return net_g, hps + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + ssl_dim, + n_speakers, + sampling_rate=44100, + vol_embedding=False, + n_flow_layer=4, + **kwargs, + ): + ( + self.spec_channels, + self.inter_channels, + self.hidden_channels, + self.filter_channels, + self.n_heads, + self.n_layers, + self.kernel_size, + self.p_dropout, + self.resblock, + self.resblock_kernel_sizes, + self.resblock_dilation_sizes, + self.upsample_rates, + self.upsample_initial_channel, + self.upsample_kernel_sizes, + self.segment_size, + self.n_speakers, + self.gin_channels, + self.vol_embedding, + ) = ( + spec_channels, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + segment_size, + n_speakers, + gin_channels, + vol_embedding, + ) + self.emb_g = nn.Embedding(n_speakers, gin_channels) + if vol_embedding: + self.emb_vol = nn.Linear(1, hidden_channels) + self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2) + self.enc_p = TextEncoder( + inter_channels, + hidden_channels, + kernel_size, + n_layers, + filter_channels=filter_channels, + n_heads=n_heads, + p_dropout=p_dropout, + ) + self.dec = Generator( + sampling_rate, + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, + hidden_channels, + 5, + 1, + n_flow_layer, + gin_channels=gin_channels, + ) + self.emb_uv = nn.Embedding(vocab_size=2, embed_size=hidden_channels) + + def infer( + self, + c: Tensor, + f0: Tensor, + uv: Tensor, + g: Tensor = None, + noise_scale=0.35, + seed=52468, + vol=None, + ) -> Tuple[Tensor, Tensor]: + Tensor.manual_seed(getenv("SEED", seed)) + c_lengths = (Tensor.ones([c.shape[0]]) * c.shape[-1]).to(c.device) + if len(g.shape) == 1: + g = g.unsqueeze(0) + g = self.emb_g(g).transpose(1, 2) + x_mask = sequence_mask(c_lengths, c.shape[2]).unsqueeze(1).cast(c.dtype) + vol = ( + self.emb_vol(vol[:, :, None]).transpose(1, 2) + if vol is not None and self.vol_embedding + else 0 + ) + x = ( + self.pre(c) * x_mask + + self.emb_uv(uv.cast(dtypes.int64)).transpose(1, 2) + + vol + ) + z_p, _, _, c_mask = self.enc_p.forward( + x, x_mask, f0=self._f0_to_coarse(f0), noise_scale=noise_scale + ) + z = self.flow.forward(z_p, c_mask, g=g, reverse=True) + o = self.dec.forward(z * c_mask, g=g, f0=f0) + return o, f0 + + def _f0_to_coarse(self, f0: Tensor): + f0_mel = 1127 * (1 + f0 / 700).log() + a = (F0_BIN - 2) / (F0_MEL_MAX - F0_MEL_MIN) + b = F0_MEL_MIN * a - 1.0 + f0_mel = (f0_mel > 0).where(f0_mel * a - b, f0_mel) + f0_coarse = f0_mel.ceil().cast(dtype=dtypes.int64) + f0_coarse = f0_coarse * (f0_coarse > 0) + f0_coarse = f0_coarse + ((f0_coarse < 1) * 1) + f0_coarse = f0_coarse * (f0_coarse < F0_BIN) + f0_coarse = f0_coarse + ((f0_coarse >= F0_BIN) * (F0_BIN - 1)) + return f0_coarse + + @classmethod + def load_from_pretrained( + cls, config_path: str, config_url: str, weights_path: str, weights_url: str + ) -> Synthesizer: + download_if_not_present(config_path, config_url) + hps = get_hparams_from_file(config_path) + download_if_not_present(weights_path, weights_url) + net_g = cls( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + **hps.model, + ) + _ = load_checkpoint(weights_path, net_g, None, skip_list=["f0_decoder"]) + logging.debug(f"{cls.__name__}:Loaded model with hps: {hps}") + return net_g, hps + class TextEncoder: - def __init__(self, out_channels, hidden_channels, kernel_size, n_layers, gin_channels=0, filter_channels=None, n_heads=None, p_dropout=None): - self.out_channels, self.hidden_channels, self.kernel_size, self.n_layers, self.gin_channels = out_channels, hidden_channels, kernel_size, n_layers, gin_channels - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - self.f0_emb = nn.Embedding(256, hidden_channels) # n_vocab = 256 - self.enc_ = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout) - def forward(self, x, x_mask, f0=None, noise_scale=1): - x = x + self.f0_emb(f0).transpose(1, 2) - x = self.enc_.forward(x * x_mask, x_mask) - stats = self.proj(x) * x_mask - m, logs = split(stats, self.out_channels, dim=1) - z = (m + randn_like(m) * logs.exp() * noise_scale) * x_mask - return z, m, logs, x_mask + def __init__( + self, + out_channels, + hidden_channels, + kernel_size, + n_layers, + gin_channels=0, + filter_channels=None, + n_heads=None, + p_dropout=None, + ): + ( + self.out_channels, + self.hidden_channels, + self.kernel_size, + self.n_layers, + self.gin_channels, + ) = (out_channels, hidden_channels, kernel_size, n_layers, gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + self.f0_emb = nn.Embedding(256, hidden_channels) # n_vocab = 256 + self.enc_ = Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + + def forward(self, x, x_mask, f0=None, noise_scale=1): + x = x + self.f0_emb(f0).transpose(1, 2) + x = self.enc_.forward(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + m, logs = split(stats, self.out_channels, dim=1) + z = (m + randn_like(m) * logs.exp() * noise_scale) * x_mask + return z, m, logs, x_mask + class Upsample: - def __init__(self, scale_factor): - assert scale_factor % 1 == 0, "Only integer scale factor allowed." - self.scale = int(scale_factor) - def forward(self, x:Tensor): - repeats = tuple([1] * len(x.shape) + [self.scale]) - new_shape = (*x.shape[:-1], x.shape[-1] * self.scale) - return x.unsqueeze(-1).repeat(repeats).reshape(new_shape) + def __init__(self, scale_factor): + assert scale_factor % 1 == 0, "Only integer scale factor allowed." + self.scale = int(scale_factor) + + def forward(self, x: Tensor): + repeats = tuple([1] * len(x.shape) + [self.scale]) + new_shape = (*x.shape[:-1], x.shape[-1] * self.scale) + return x.unsqueeze(-1).repeat(repeats).reshape(new_shape) + class SineGen: - def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voice_threshold=0, flag_for_pulse=False): - self.sine_amp, self.noise_std, self.harmonic_num, self.sampling_rate, self.voiced_threshold, self.flag_for_pulse = sine_amp, noise_std, harmonic_num, samp_rate, voice_threshold, flag_for_pulse - self.dim = self.harmonic_num + 1 - def _f02uv(self, f0): return (f0 > self.voiced_threshold).float() #generate uv signal - def _f02sine(self, f0_values): - def padDiff(x : Tensor): return (x.pad2d((0,0,-1,1)) - x).pad2d((0,0,0,-1)) - def mod(x: Tensor, n: int) -> Tensor: return x - n * x.div(n).floor() # this is what the % operator does in pytorch. - rad_values = mod((f0_values / self.sampling_rate) , 1) # convert to F0 in rad - rand_ini = Tensor.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device) # initial phase noise + def __init__( + self, + samp_rate, + harmonic_num=0, + sine_amp=0.1, + noise_std=0.003, + voice_threshold=0, + flag_for_pulse=False, + ): + ( + self.sine_amp, + self.noise_std, + self.harmonic_num, + self.sampling_rate, + self.voiced_threshold, + self.flag_for_pulse, + ) = ( + sine_amp, + noise_std, + harmonic_num, + samp_rate, + voice_threshold, + flag_for_pulse, + ) + self.dim = self.harmonic_num + 1 - #rand_ini[:, 0] = 0 - m = Tensor.ones(f0_values.shape[0]).unsqueeze(1).pad2d((0,f0_values.shape[2]-1,0,0)).cast(dtypes.bool) - m = tilde(m) - rand_ini = m.where(rand_ini, 0) + def _f02uv(self, f0): + return (f0 > self.voiced_threshold).float() # generate uv signal - #rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini - tmp = rad_values[:, 0, :] + rand_ini - m = Tensor.ones(tmp.shape).pad2d((0,0,0,rad_values.shape[1]-1,0)).cast(dtypes.bool) - m = tilde(m) - tmp = tmp.unsqueeze(1).pad2d((0,0,0,rad_values.shape[1]-1,0)) - rad_values = m.where(rad_values, tmp) + def _f02sine(self, f0_values): + def padDiff(x: Tensor): + return (x.pad2d((0, 0, -1, 1)) - x).pad2d((0, 0, 0, -1)) - tmp_over_one = mod(rad_values.cumsum(1), 1) - tmp_over_one_idx = padDiff(tmp_over_one) < 0 - cumsum_shift = Tensor.zeros_like(rad_values) + def mod(x: Tensor, n: int) -> Tensor: + return ( + x - n * x.div(n).floor() + ) # this is what the % operator does in pytorch. - #cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 - tmp_over_one_idx = (tmp_over_one_idx * -1.0).pad2d((0,0,1,0)) - cumsum_shift = tmp_over_one_idx + rad_values = mod((f0_values / self.sampling_rate), 1) # convert to F0 in rad + rand_ini = Tensor.rand( + f0_values.shape[0], f0_values.shape[2], device=f0_values.device + ) # initial phase noise + + # rand_ini[:, 0] = 0 + m = ( + Tensor.ones(f0_values.shape[0]) + .unsqueeze(1) + .pad2d((0, f0_values.shape[2] - 1, 0, 0)) + .cast(dtypes.bool) + ) + m = tilde(m) + rand_ini = m.where(rand_ini, 0) + + # rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + tmp = rad_values[:, 0, :] + rand_ini + m = ( + Tensor.ones(tmp.shape) + .pad2d((0, 0, 0, rad_values.shape[1] - 1, 0)) + .cast(dtypes.bool) + ) + m = tilde(m) + tmp = tmp.unsqueeze(1).pad2d((0, 0, 0, rad_values.shape[1] - 1, 0)) + rad_values = m.where(rad_values, tmp) + + tmp_over_one = mod(rad_values.cumsum(1), 1) + tmp_over_one_idx = padDiff(tmp_over_one) < 0 + cumsum_shift = Tensor.zeros_like(rad_values) + + # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + tmp_over_one_idx = (tmp_over_one_idx * -1.0).pad2d((0, 0, 1, 0)) + cumsum_shift = tmp_over_one_idx + + sines = ((rad_values + cumsum_shift).cumsum(1) * 2 * np.pi).sin() + return sines + + def forward(self, f0, upp=None): + fn = f0.mul( + Tensor([[range(1, self.harmonic_num + 2)]], dtype=dtypes.float32).to( + f0.device + ) + ) + sine_waves = self._f02sine(fn) * self.sine_amp # generate sine waveforms + uv = self._f02uv(f0) # generate uv signal + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise - sines = ((rad_values + cumsum_shift).cumsum(1) * 2 * np.pi).sin() - return sines - def forward(self, f0, upp=None): - fn = f0.mul(Tensor([[range(1, self.harmonic_num + 2)]], dtype=dtypes.float32).to(f0.device)) - sine_waves = self._f02sine(fn) * self.sine_amp #generate sine waveforms - uv = self._f02uv(f0) # generate uv signal - noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 - noise = noise_amp * randn_like(sine_waves) - sine_waves = sine_waves * uv + noise - return sine_waves, uv, noise class SourceHnNSF: - def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshold=0): - self.sine_amp, self.noise_std = sine_amp, add_noise_std - self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold) - self.l_linear = nn.Linear(harmonic_num + 1, 1) - def forward(self, x, upp=None): - sine_waves, uv, _ = self.l_sin_gen.forward(x, upp) - sine_merge = self.l_linear(sine_waves.cast(self.l_linear.weight.dtype)).tanh() - noise = randn_like(uv) * self.sine_amp / 3 - return sine_merge, noise, uv + def __init__( + self, + sampling_rate, + harmonic_num=0, + sine_amp=0.1, + add_noise_std=0.003, + voiced_threshold=0, + ): + self.sine_amp, self.noise_std = sine_amp, add_noise_std + self.l_sin_gen = SineGen( + sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold + ) + self.l_linear = nn.Linear(harmonic_num + 1, 1) + + def forward(self, x, upp=None): + sine_waves, uv, _ = self.l_sin_gen.forward(x, upp) + sine_merge = self.l_linear(sine_waves.cast(self.l_linear.weight.dtype)).tanh() + noise = randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + # most of the hifigan in standard vits is reused here, but need to upsample and construct harmonic source from f0 class Generator: - def __init__(self, sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels): - self.sampling_rate, self.inter_channels, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.gin_channels = sampling_rate, inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels - self.num_kernels, self.num_upsamples = len(resblock_kernel_sizes), len(upsample_rates) - self.conv_pre = nn.Conv1d(inter_channels, upsample_initial_channel, 7, 1, padding=3) - self.f0_upsamp = Upsample(scale_factor=np.prod(upsample_rates)) - self.m_source = SourceHnNSF(sampling_rate, harmonic_num=8) - resblock = ResBlock1 if resblock == '1' else ResBlock2 - self.ups, self.noise_convs, self.resblocks = [], [], [] - for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - c_cur = upsample_initial_channel//(2**(i+1)) - self.ups.append(nn.ConvTranspose1d(upsample_initial_channel//(2**i), c_cur, k, u, padding=(k-u)//2)) - stride_f0 = int(np.prod(upsample_rates[i + 1:])) - self.noise_convs.append(nn.Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2) if (i + 1 < len(upsample_rates)) else nn.Conv1d(1, c_cur, kernel_size=1)) - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock(ch, k, d)) - self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3) - if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - self.upp = np.prod(upsample_rates) - def forward(self, x, f0, g=None): - f0 = self.f0_upsamp.forward(f0[:, None]).transpose(1, 2) # bs,n,t - har_source, _, _ = self.m_source.forward(f0, self.upp) - har_source = har_source.transpose(1, 2) - x = self.conv_pre(x) - if g is not None: x = x + self.cond(g) - for i in range(self.num_upsamples): - x, xs = self.ups[i](x.leakyrelu(LRELU_SLOPE)), None - x_source = self.noise_convs[i](har_source) - x = x + x_source - for j in range(self.num_kernels): - if xs is None: xs = self.resblocks[i * self.num_kernels + j].forward(x) - else: xs += self.resblocks[i * self.num_kernels + j].forward(x) - x = xs / self.num_kernels - return self.conv_post(x.leakyrelu()).tanh() + def __init__( + self, + sampling_rate, + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + ): + ( + self.sampling_rate, + self.inter_channels, + self.resblock, + self.resblock_kernel_sizes, + self.resblock_dilation_sizes, + self.upsample_rates, + self.upsample_initial_channel, + self.upsample_kernel_sizes, + self.gin_channels, + ) = ( + sampling_rate, + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + ) + self.num_kernels, self.num_upsamples = len(resblock_kernel_sizes), len( + upsample_rates + ) + self.conv_pre = nn.Conv1d( + inter_channels, upsample_initial_channel, 7, 1, padding=3 + ) + self.f0_upsamp = Upsample(scale_factor=np.prod(upsample_rates)) + self.m_source = SourceHnNSF(sampling_rate, harmonic_num=8) + resblock = ResBlock1 if resblock == "1" else ResBlock2 + self.ups, self.noise_convs, self.resblocks = [], [], [] + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + c_cur = upsample_initial_channel // (2 ** (i + 1)) + self.ups.append( + nn.ConvTranspose1d( + upsample_initial_channel // (2**i), + c_cur, + k, + u, + padding=(k - u) // 2, + ) + ) + stride_f0 = int(np.prod(upsample_rates[i + 1 :])) + self.noise_convs.append( + nn.Conv1d( + 1, + c_cur, + kernel_size=stride_f0 * 2, + stride=stride_f0, + padding=(stride_f0 + 1) // 2, + ) + if (i + 1 < len(upsample_rates)) + else nn.Conv1d(1, c_cur, kernel_size=1) + ) + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + self.upp = np.prod(upsample_rates) + + def forward(self, x, f0, g=None): + f0 = self.f0_upsamp.forward(f0[:, None]).transpose(1, 2) # bs,n,t + har_source, _, _ = self.m_source.forward(f0, self.upp) + har_source = har_source.transpose(1, 2) + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + for i in range(self.num_upsamples): + x, xs = self.ups[i](x.leakyrelu(LRELU_SLOPE)), None + x_source = self.noise_convs[i](har_source) + x = x + x_source + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j].forward(x) + else: + xs += self.resblocks[i * self.num_kernels + j].forward(x) + x = xs / self.num_kernels + return self.conv_post(x.leakyrelu()).tanh() + # **** helpers **** -def randn_like(x:Tensor) -> Tensor: return Tensor.randn(*x.shape, dtype=x.dtype).to(device=x.device) + +def randn_like(x: Tensor) -> Tensor: + return Tensor.randn(*x.shape, dtype=x.dtype).to(device=x.device) + def tilde(x: Tensor) -> Tensor: - if x.dtype == dtypes.bool: return (1 - x).cast(dtypes.bool) - return (x + 1) * -1 # this seems to be what the ~ operator does in pytorch for non bool + if x.dtype == dtypes.bool: + return (1 - x).cast(dtypes.bool) + return ( + x + 1 + ) * -1 # this seems to be what the ~ operator does in pytorch for non bool -def lengths_to_padding_mask(lens:Tensor) -> Tensor: - bsz, max_lens = lens.shape[0], lens.max().numpy().item() - mask = Tensor.arange(max_lens).to(lens.device).reshape(1, max_lens) - mask = mask.expand(bsz, -1) >= lens.reshape(bsz, 1).expand(-1, max_lens) - return mask.cast(dtypes.bool) -def repeat_expand_2d_left(content, target_len): # content : [h, t] - src_len = content.shape[-1] - temp = np.arange(src_len+1) * target_len / src_len - current_pos, cols = 0, [] - for i in range(target_len): - if i >= temp[current_pos+1]: - current_pos += 1 - cols.append(content[:, current_pos]) - return Tensor.stack(cols).transpose(0, 1) +def lengths_to_padding_mask(lens: Tensor) -> Tensor: + bsz, max_lens = lens.shape[0], lens.max().numpy().item() + mask = Tensor.arange(max_lens).to(lens.device).reshape(1, max_lens) + mask = mask.expand(bsz, -1) >= lens.reshape(bsz, 1).expand(-1, max_lens) + return mask.cast(dtypes.bool) + + +def repeat_expand_2d_left(content, target_len): # content : [h, t] + src_len = content.shape[-1] + temp = np.arange(src_len + 1) * target_len / src_len + current_pos, cols = 0, [] + for i in range(target_len): + if i >= temp[current_pos + 1]: + current_pos += 1 + cols.append(content[:, current_pos]) + return Tensor.stack(cols).transpose(0, 1) + def load_fairseq_cfg(checkpoint_path): - assert Path(checkpoint_path).is_file() - state = torch_load(checkpoint_path) - cfg = state["cfg"] if ("cfg" in state and state["cfg"] is not None) else None - if cfg is None: raise RuntimeError(f"No cfg exist in state keys = {state.keys()}") - return HParams(**cfg) + assert Path(checkpoint_path).is_file() + state = torch_load(checkpoint_path) + cfg = state["cfg"] if ("cfg" in state and state["cfg"] is not None) else None + if cfg is None: + raise RuntimeError(f"No cfg exist in state keys = {state.keys()}") + return HParams(**cfg) + + +def load_checkpoint_enc( + checkpoint_path, model: ContentVec, optimizer=None, skip_list=[] +): + assert Path(checkpoint_path).is_file() + start_time = time.time() + checkpoint_dict = torch_load(checkpoint_path) + saved_state_dict = checkpoint_dict["model"] + weight_g, weight_v, parent = None, None, None + for key, v in saved_state_dict.items(): + if any(layer in key for layer in skip_list): + continue + try: + obj, skip = model, False + for k in key.split("."): + if k.isnumeric(): + obj = obj[int(k)] + elif isinstance(obj, dict): + obj = obj[k] + else: + if k in ["weight_g", "weight_v"]: + parent, skip = obj, True + if k == "weight_g": + weight_g = v + else: + weight_v = v + if not skip: + parent = obj + obj = getattr(obj, k) + if weight_g and weight_v: + setattr(obj, "weight_g", weight_g.numpy()) + setattr(obj, "weight_v", weight_v.numpy()) + obj, v = getattr(parent, "weight"), weight_norm(weight_v, weight_g, 0) + weight_g, weight_v, parent, skip = None, None, None, False + if not skip and obj.shape == v.shape: + if "feature_extractor" in key and ( + isinstance(parent, nn.GroupNorm) or isinstance(parent, nn.LayerNorm) + ): # cast + obj.assign(v.to(obj.device).float()) + else: + obj.assign(v.to(obj.device)) + elif not skip: + logging.error(f"MISMATCH SHAPE IN {key}, {obj.shape} {v.shape}") + except Exception as e: + raise e + logging.info( + f"Loaded checkpoint '{checkpoint_path}' in {time.time() - start_time:.4f}s" + ) + return model, optimizer -def load_checkpoint_enc(checkpoint_path, model: ContentVec, optimizer=None, skip_list=[]): - assert Path(checkpoint_path).is_file() - start_time = time.time() - checkpoint_dict = torch_load(checkpoint_path) - saved_state_dict = checkpoint_dict['model'] - weight_g, weight_v, parent = None, None, None - for key, v in saved_state_dict.items(): - if any(layer in key for layer in skip_list): continue - try: - obj, skip = model, False - for k in key.split('.'): - if k.isnumeric(): obj = obj[int(k)] - elif isinstance(obj, dict): obj = obj[k] - else: - if k in ["weight_g", "weight_v"]: - parent, skip = obj, True - if k == "weight_g": weight_g = v - else: weight_v = v - if not skip: - parent = obj - obj = getattr(obj, k) - if weight_g and weight_v: - setattr(obj, "weight_g", weight_g.numpy()) - setattr(obj, "weight_v", weight_v.numpy()) - obj, v = getattr(parent, "weight"), weight_norm(weight_v, weight_g, 0) - weight_g, weight_v, parent, skip = None, None, None, False - if not skip and obj.shape == v.shape: - if "feature_extractor" in key and (isinstance(parent, nn.GroupNorm) or isinstance(parent, nn.LayerNorm)): # cast - obj.assign(v.to(obj.device).float()) - else: - obj.assign(v.to(obj.device)) - elif not skip: logging.error(f"MISMATCH SHAPE IN {key}, {obj.shape} {v.shape}") - except Exception as e: raise e - logging.info(f"Loaded checkpoint '{checkpoint_path}' in {time.time() - start_time:.4f}s") - return model, optimizer def pad_array(arr, target_length): - current_length = arr.shape[0] - if current_length >= target_length: return arr - pad_width = target_length - current_length - pad_left = pad_width // 2 - pad_right = pad_width - pad_left - padded_arr = np.pad(arr, (pad_left, pad_right), 'constant', constant_values=(0, 0)) - return padded_arr + current_length = arr.shape[0] + if current_length >= target_length: + return arr + pad_width = target_length - current_length + pad_left = pad_width // 2 + pad_right = pad_width - pad_left + padded_arr = np.pad(arr, (pad_left, pad_right), "constant", constant_values=(0, 0)) + return padded_arr + def split_list_by_n(list_collection, n, pre=0): - for i in range(0, len(list_collection), n): - yield list_collection[i-pre if i-pre>=0 else i: i + n] + for i in range(0, len(list_collection), n): + yield list_collection[i - pre if i - pre >= 0 else i : i + n] + + +def get_sid(spk2id: HParams, speaker: str) -> Tensor: + speaker_id = spk2id[speaker] + if not speaker_id and type(speaker) is int: + if len(spk2id.__dict__) >= speaker: + speaker_id = speaker + if speaker_id is None: + raise RuntimeError(f"speaker={speaker} not in the speaker list") + return Tensor([int(speaker_id)], dtype=dtypes.int64).unsqueeze(0) -def get_sid(spk2id:HParams, speaker:str) -> Tensor: - speaker_id = spk2id[speaker] - if not speaker_id and type(speaker) is int: - if len(spk2id.__dict__) >= speaker: speaker_id = speaker - if speaker_id is None: raise RuntimeError(f"speaker={speaker} not in the speaker list") - return Tensor([int(speaker_id)], dtype=dtypes.int64).unsqueeze(0) def get_encoder(ssl_dim) -> Type[SpeechEncoder]: - if ssl_dim == 256: return ContentVec256L9 - if ssl_dim == 768: return ContentVec768L12 + if ssl_dim == 256: + return ContentVec256L9 + if ssl_dim == 768: + return ContentVec768L12 + ######################################################################################### # CODE: https://github.com/svc-develop-team/so-vits-svc @@ -551,120 +1155,228 @@ def get_encoder(ssl_dim) -> Type[SpeechEncoder]: # python3 examples/so_vits_svc.py --model saul_goodman ######################################################################################### SO_VITS_SVC_PATH = Path(__file__).parents[1] / "weights/So-VITS-SVC" -VITS_MODELS = { # config_path, weights_path, config_url, weights_url - "saul_goodman" : (SO_VITS_SVC_PATH / "config_saul_gman.json", SO_VITS_SVC_PATH / "pretrained_saul_gman.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/G_80000.pth"), - "drake" : (SO_VITS_SVC_PATH / "config_drake.json", SO_VITS_SVC_PATH / "pretrained_drake.pth", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/config_aubrey.json", "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/pretrained_aubrey.pth"), - "cartman" : (SO_VITS_SVC_PATH / "config_cartman.json", SO_VITS_SVC_PATH / "pretrained_cartman.pth", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/config.json", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/G_10200.pth"), - "tf2spy" : (SO_VITS_SVC_PATH / "config_tf2spy.json", SO_VITS_SVC_PATH / "pretrained_tf2spy.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/G_60000.pth"), - "tf2heavy" : (SO_VITS_SVC_PATH / "config_tf2heavy.json", SO_VITS_SVC_PATH / "pretrained_tf2heavy.pth", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/config.json", "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/G_100000.pth"), - "lady_gaga" : (SO_VITS_SVC_PATH / "config_gaga.json", SO_VITS_SVC_PATH / "pretrained_gaga.pth", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/config.json", "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/G_14400.pth") +VITS_MODELS = { # config_path, weights_path, config_url, weights_url + "saul_goodman": ( + SO_VITS_SVC_PATH / "config_saul_gman.json", + SO_VITS_SVC_PATH / "pretrained_saul_gman.pth", + "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/config.json", + "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/Saul_Goodman_80000/G_80000.pth", + ), + "drake": ( + SO_VITS_SVC_PATH / "config_drake.json", + SO_VITS_SVC_PATH / "pretrained_drake.pth", + "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/config_aubrey.json", + "https://huggingface.co/jaspa/so-vits-svc/resolve/main/aubrey/pretrained_aubrey.pth", + ), + "cartman": ( + SO_VITS_SVC_PATH / "config_cartman.json", + SO_VITS_SVC_PATH / "pretrained_cartman.pth", + "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/config.json", + "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/EricCartman/G_10200.pth", + ), + "tf2spy": ( + SO_VITS_SVC_PATH / "config_tf2spy.json", + SO_VITS_SVC_PATH / "pretrained_tf2spy.pth", + "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/config.json", + "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_spy_60k/G_60000.pth", + ), + "tf2heavy": ( + SO_VITS_SVC_PATH / "config_tf2heavy.json", + SO_VITS_SVC_PATH / "pretrained_tf2heavy.pth", + "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/config.json", + "https://huggingface.co/Amo/so-vits-svc-4.0_GA/resolve/main/ModelsFolder/TF2_heavy_100k/G_100000.pth", + ), + "lady_gaga": ( + SO_VITS_SVC_PATH / "config_gaga.json", + SO_VITS_SVC_PATH / "pretrained_gaga.pth", + "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/config.json", + "https://huggingface.co/marcoc2/so-vits-svc-4.0-models/resolve/main/LadyGaga/G_14400.pth", + ), } -ENCODER_MODELS = { # weights_path, weights_url - "contentvec": (SO_VITS_SVC_PATH / "contentvec_checkpoint.pt", "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt") +ENCODER_MODELS = { # weights_path, weights_url + "contentvec": ( + SO_VITS_SVC_PATH / "contentvec_checkpoint.pt", + "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt", + ) } ENCODER_MODEL = "contentvec" -DEMO_PATH, DEMO_URL = Path(__file__).parents[1] / "temp/LJ037-0171.wav", "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav" -if __name__=="__main__": - logging.basicConfig(stream=sys.stdout, level=(logging.INFO if DEBUG < 1 else logging.DEBUG)) - parser = argparse.ArgumentParser() - parser.add_argument("-m", "--model", default=None, help=f"Specify the model to use. All supported models: {VITS_MODELS.keys()}", required=True) - parser.add_argument("-f", "--file", default=DEMO_PATH, help=f"Specify the path of the input file") - parser.add_argument("--out_dir", default=str(Path(__file__).parents[1] / "temp"), help="Specify the output path.") - parser.add_argument("--out_path", default=None, help="Specify the full output path. Overrides the --out_dir and --name parameter.") - parser.add_argument("--base_name", default="test", help="Specify the base of the output file name. Default is 'test'.") - parser.add_argument("--speaker", default=None, help="If not specified, the first available speaker is chosen. Usually there is only one speaker per model.") - parser.add_argument("--noise_scale", default=0.4) - parser.add_argument("--tran", default=0.0, help="Pitch shift, supports positive and negative (semitone) values. Default 0.0") - parser.add_argument("--pad_seconds", default=0.5) - parser.add_argument("--lg_num", default=0.0) - parser.add_argument("--clip_seconds", default=0.0) - parser.add_argument("--slice_db", default=-40) - args = parser.parse_args() +DEMO_PATH, DEMO_URL = ( + Path(__file__).parents[1] / "temp/LJ037-0171.wav", + "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav", +) +if __name__ == "__main__": + logging.basicConfig( + stream=sys.stdout, level=(logging.INFO if DEBUG < 1 else logging.DEBUG) + ) + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + default=None, + help=f"Specify the model to use. All supported models: {VITS_MODELS.keys()}", + required=True, + ) + parser.add_argument( + "-f", "--file", default=DEMO_PATH, help=f"Specify the path of the input file" + ) + parser.add_argument( + "--out_dir", + default=str(Path(__file__).parents[1] / "temp"), + help="Specify the output path.", + ) + parser.add_argument( + "--out_path", + default=None, + help="Specify the full output path. Overrides the --out_dir and --name parameter.", + ) + parser.add_argument( + "--base_name", + default="test", + help="Specify the base of the output file name. Default is 'test'.", + ) + parser.add_argument( + "--speaker", + default=None, + help="If not specified, the first available speaker is chosen. Usually there is only one speaker per model.", + ) + parser.add_argument("--noise_scale", default=0.4) + parser.add_argument( + "--tran", + default=0.0, + help="Pitch shift, supports positive and negative (semitone) values. Default 0.0", + ) + parser.add_argument("--pad_seconds", default=0.5) + parser.add_argument("--lg_num", default=0.0) + parser.add_argument("--clip_seconds", default=0.0) + parser.add_argument("--slice_db", default=-40) + args = parser.parse_args() - vits_model = args.model - encoder_location, vits_location = ENCODER_MODELS[ENCODER_MODEL], VITS_MODELS[vits_model] + vits_model = args.model + encoder_location, vits_location = ( + ENCODER_MODELS[ENCODER_MODEL], + VITS_MODELS[vits_model], + ) - Tensor.no_grad, Tensor.training = True, False - # Get Synthesizer and ContentVec - net_g, hps = Synthesizer.load_from_pretrained(vits_location[0], vits_location[2], vits_location[1], vits_location[3]) - Encoder = get_encoder(hps.model.ssl_dim) - encoder = Encoder.load_from_pretrained(encoder_location[0], encoder_location[1]) + Tensor.no_grad, Tensor.training = True, False + # Get Synthesizer and ContentVec + net_g, hps = Synthesizer.load_from_pretrained( + vits_location[0], vits_location[2], vits_location[1], vits_location[3] + ) + Encoder = get_encoder(hps.model.ssl_dim) + encoder = Encoder.load_from_pretrained(encoder_location[0], encoder_location[1]) - # model config args - target_sample, spk2id, hop_length, target_sample = hps.data.sampling_rate, hps.spk, hps.data.hop_length, hps.data.sampling_rate - vol_embedding = hps.model.vol_embedding if hasattr(hps.data, "vol_embedding") and hps.model.vol_embedding is not None else False + # model config args + target_sample, spk2id, hop_length, target_sample = ( + hps.data.sampling_rate, + hps.spk, + hps.data.hop_length, + hps.data.sampling_rate, + ) + vol_embedding = ( + hps.model.vol_embedding + if hasattr(hps.data, "vol_embedding") and hps.model.vol_embedding is not None + else False + ) - # args - slice_db, clip_seconds, lg_num, pad_seconds, tran, noise_scale, audio_path = args.slice_db, args.clip_seconds, args.lg_num, args.pad_seconds, args.tran, args.noise_scale, args.file - speaker = args.speaker if args.speaker is not None else list(hps.spk.__dict__.keys())[0] + # args + slice_db, clip_seconds, lg_num, pad_seconds, tran, noise_scale, audio_path = ( + args.slice_db, + args.clip_seconds, + args.lg_num, + args.pad_seconds, + args.tran, + args.noise_scale, + args.file, + ) + speaker = ( + args.speaker if args.speaker is not None else list(hps.spk.__dict__.keys())[0] + ) - ### Loading audio and slicing ### - if audio_path == DEMO_PATH: download_if_not_present(DEMO_PATH, DEMO_URL) - assert Path(audio_path).is_file() and Path(audio_path).suffix == ".wav" - chunks = preprocess.cut(audio_path, db_thresh=slice_db) - audio_data, audio_sr = preprocess.chunks2audio(audio_path, chunks) + ### Loading audio and slicing ### + if audio_path == DEMO_PATH: + download_if_not_present(DEMO_PATH, DEMO_URL) + assert Path(audio_path).is_file() and Path(audio_path).suffix == ".wav" + chunks = preprocess.cut(audio_path, db_thresh=slice_db) + audio_data, audio_sr = preprocess.chunks2audio(audio_path, chunks) - per_size = int(clip_seconds * audio_sr) - lg_size = int(lg_num * audio_sr) + per_size = int(clip_seconds * audio_sr) + lg_size = int(lg_num * audio_sr) - ### Infer per slice ### - global_frame = 0 - audio = [] - for (slice_tag, data) in audio_data: - print(f"\n====segment start, {round(len(data) / audio_sr, 3)}s====") - length = int(np.ceil(len(data) / audio_sr * target_sample)) + ### Infer per slice ### + global_frame = 0 + audio = [] + for slice_tag, data in audio_data: + print(f"\n====segment start, {round(len(data) / audio_sr, 3)}s====") + length = int(np.ceil(len(data) / audio_sr * target_sample)) - if slice_tag: - print("empty segment") - _audio = np.zeros(length) - audio.extend(list(pad_array(_audio, length))) - global_frame += length // hop_length - continue + if slice_tag: + print("empty segment") + _audio = np.zeros(length) + audio.extend(list(pad_array(_audio, length))) + global_frame += length // hop_length + continue - datas = [data] if per_size == 0 else split_list_by_n(data, per_size, lg_size) + datas = [data] if per_size == 0 else split_list_by_n(data, per_size, lg_size) - for k, dat in enumerate(datas): - per_length = int(np.ceil(len(dat) / audio_sr * target_sample)) if clip_seconds!=0 else length - pad_len = int(audio_sr * pad_seconds) - dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])]) - raw_path = io.BytesIO() - soundfile.write(raw_path, dat, audio_sr, format="wav") - raw_path.seek(0) + for k, dat in enumerate(datas): + per_length = ( + int(np.ceil(len(dat) / audio_sr * target_sample)) + if clip_seconds != 0 + else length + ) + pad_len = int(audio_sr * pad_seconds) + dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])]) + raw_path = io.BytesIO() + soundfile.write(raw_path, dat, audio_sr, format="wav") + raw_path.seek(0) - ### Infer START ### - wav, sr = preprocess.load_audiofile(raw_path) - wav = preprocess.sinc_interp_resample(wav, sr, target_sample)[0] - wav16k, f0, uv = preprocess.get_unit_f0(wav, tran, hop_length, target_sample) - sid = get_sid(spk2id, speaker) - n_frames = f0.shape[1] + ### Infer START ### + wav, sr = preprocess.load_audiofile(raw_path) + wav = preprocess.sinc_interp_resample(wav, sr, target_sample)[0] + wav16k, f0, uv = preprocess.get_unit_f0( + wav, tran, hop_length, target_sample + ) + sid = get_sid(spk2id, speaker) + n_frames = f0.shape[1] - # ContentVec infer - start = time.time() - c = encoder.encode(wav16k) - c = repeat_expand_2d_left(c.squeeze(0).realize(), f0.shape[1]) # interpolate speech encoding to match f0 - c = c.unsqueeze(0).realize() - enc_time = time.time() - start + # ContentVec infer + start = time.time() + c = encoder.encode(wav16k) + c = repeat_expand_2d_left( + c.squeeze(0).realize(), f0.shape[1] + ) # interpolate speech encoding to match f0 + c = c.unsqueeze(0).realize() + enc_time = time.time() - start - # VITS infer - vits_start = time.time() - out_audio, f0 = net_g.infer(c, f0=f0, uv=uv, g=sid, noise_scale=noise_scale, vol=None) - out_audio = out_audio[0,0].float().realize() - vits_time = time.time() - vits_start + # VITS infer + vits_start = time.time() + out_audio, f0 = net_g.infer( + c, f0=f0, uv=uv, g=sid, noise_scale=noise_scale, vol=None + ) + out_audio = out_audio[0, 0].float().realize() + vits_time = time.time() - vits_start - infer_time = time.time() - start - logging.info("total infer time:{:.2f}s, speech_enc time:{:.2f}s, vits time:{:.2f}s".format(infer_time, enc_time, vits_time)) - ### Infer END ### + infer_time = time.time() - start + logging.info( + "total infer time:{:.2f}s, speech_enc time:{:.2f}s, vits time:{:.2f}s".format( + infer_time, enc_time, vits_time + ) + ) + ### Infer END ### - out_sr, out_frame = out_audio.shape[-1], n_frames - global_frame += out_frame - _audio = out_audio.numpy() - pad_len = int(target_sample * pad_seconds) - _audio = _audio[pad_len:-pad_len] - _audio = pad_array(_audio, per_length) - audio.extend(list(_audio)) + out_sr, out_frame = out_audio.shape[-1], n_frames + global_frame += out_frame + _audio = out_audio.numpy() + pad_len = int(target_sample * pad_seconds) + _audio = _audio[pad_len:-pad_len] + _audio = pad_array(_audio, per_length) + audio.extend(list(_audio)) - audio = np.array(audio) - out_path = Path(args.out_path or Path(args.out_dir)/f"{args.model}{f'_spk_{speaker}'}_{args.base_name}.wav") - out_path.parent.mkdir(parents=True, exist_ok=True) - soundfile.write(out_path, audio, target_sample, format="flac") - logging.info(f"Saved audio output to {out_path}") + audio = np.array(audio) + out_path = Path( + args.out_path + or Path(args.out_dir) / f"{args.model}{f'_spk_{speaker}'}_{args.base_name}.wav" + ) + out_path.parent.mkdir(parents=True, exist_ok=True) + soundfile.write(out_path, audio, target_sample, format="flac") + logging.info(f"Saved audio output to {out_path}") diff --git a/examples/sovits_helpers/preprocess.py b/examples/sovits_helpers/preprocess.py index 99bb0a182..f4f0f4416 100644 --- a/examples/sovits_helpers/preprocess.py +++ b/examples/sovits_helpers/preprocess.py @@ -7,199 +7,369 @@ import soundfile import numpy as np import parselmouth + class PMF0Predictor: # from https://github.com/svc-develop-team/so-vits-svc/ - def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): - self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = hop_length, f0_min, f0_max, sampling_rate, "pm" - def interpolate_f0(self,f0): - vuv_vector = np.zeros_like(f0, dtype=np.float32) - vuv_vector[f0 > 0.0] = 1.0 - vuv_vector[f0 <= 0.0] = 0.0 - nzindex = np.nonzero(f0)[0] - data = f0[nzindex] - nzindex = nzindex.astype(np.float32) - time_org = self.hop_length / self.sampling_rate * nzindex - time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate - if data.shape[0] <= 0: return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector - if data.shape[0] == 1: return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector - f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) - return f0,vuv_vector - def compute_f0(self,wav,p_len=None): - x = wav - if p_len is None: p_len = x.shape[0]//self.hop_length - else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" - time_step = self.hop_length / self.sampling_rate * 1000 - f0 = parselmouth.Sound(x, self.sampling_rate) \ - .to_pitch_ac(time_step=time_step / 1000, voicing_threshold=0.6,pitch_floor=self.f0_min, pitch_ceiling=self.f0_max) \ - .selected_array['frequency'] - pad_size=(p_len - len(f0) + 1) // 2 - if(pad_size>0 or p_len - len(f0) - pad_size>0): - f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') - f0,uv = self.interpolate_f0(f0) - return f0 - def compute_f0_uv(self,wav,p_len=None): - x = wav - if p_len is None: p_len = x.shape[0]//self.hop_length - else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" - time_step = self.hop_length / self.sampling_rate * 1000 - f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac( - time_step=time_step / 1000, voicing_threshold=0.6, - pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency'] - pad_size=(p_len - len(f0) + 1) // 2 - if(pad_size>0 or p_len - len(f0) - pad_size>0): - f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') - f0,uv = self.interpolate_f0(f0) - return f0,uv + def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100): + self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = ( + hop_length, + f0_min, + f0_max, + sampling_rate, + "pm", + ) + + def interpolate_f0(self, f0): + vuv_vector = np.zeros_like(f0, dtype=np.float32) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 + nzindex = np.nonzero(f0)[0] + data = f0[nzindex] + nzindex = nzindex.astype(np.float32) + time_org = self.hop_length / self.sampling_rate * nzindex + time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate + if data.shape[0] <= 0: + return np.zeros(f0.shape[0], dtype=np.float32), vuv_vector + if data.shape[0] == 1: + return np.ones(f0.shape[0], dtype=np.float32) * f0[0], vuv_vector + f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) + return f0, vuv_vector + + def compute_f0(self, wav, p_len=None): + x = wav + if p_len is None: + p_len = x.shape[0] // self.hop_length + else: + assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error" + time_step = self.hop_length / self.sampling_rate * 1000 + f0 = ( + parselmouth.Sound(x, self.sampling_rate) + .to_pitch_ac( + time_step=time_step / 1000, + voicing_threshold=0.6, + pitch_floor=self.f0_min, + pitch_ceiling=self.f0_max, + ) + .selected_array["frequency"] + ) + pad_size = (p_len - len(f0) + 1) // 2 + if pad_size > 0 or p_len - len(f0) - pad_size > 0: + f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant") + f0, uv = self.interpolate_f0(f0) + return f0 + + def compute_f0_uv(self, wav, p_len=None): + x = wav + if p_len is None: + p_len = x.shape[0] // self.hop_length + else: + assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error" + time_step = self.hop_length / self.sampling_rate * 1000 + f0 = ( + parselmouth.Sound(x, self.sampling_rate) + .to_pitch_ac( + time_step=time_step / 1000, + voicing_threshold=0.6, + pitch_floor=self.f0_min, + pitch_ceiling=self.f0_max, + ) + .selected_array["frequency"] + ) + pad_size = (p_len - len(f0) + 1) // 2 + if pad_size > 0 or p_len - len(f0) - pad_size > 0: + f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant") + f0, uv = self.interpolate_f0(f0) + return f0, uv + class Slicer: # from https://github.com/svc-develop-team/so-vits-svc/ - def __init__(self, sr: int, threshold: float = -40., min_length: int = 5000, min_interval: int = 300, hop_size: int = 20, max_sil_kept: int = 5000): - if not min_length >= min_interval >= hop_size: - raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size') - if not max_sil_kept >= hop_size: - raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size') - min_interval = sr * min_interval / 1000 - self.threshold = 10 ** (threshold / 20.) - self.hop_size = round(sr * hop_size / 1000) - self.win_size = min(round(min_interval), 4 * self.hop_size) - self.min_length = round(sr * min_length / 1000 / self.hop_size) - self.min_interval = round(min_interval / self.hop_size) - self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) - def _apply_slice(self, waveform, begin, end): - if len(waveform.shape) > 1: return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)] - else: return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)] - def slice(self, waveform): - samples = librosa.to_mono(waveform) if len(waveform.shape) > 1 else waveform - if samples.shape[0] <= self.min_length: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} - rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0) - sil_tags, silence_start, clip_start = [], None, 0 - for i, rms in enumerate(rms_list): - if rms < self.threshold: # Keep looping while frame is silent. - if silence_start is None: # Record start of silent frames. - silence_start = i - continue - if silence_start is None: continue # Keep looping while frame is not silent and silence start has not been recorded. - # Clear recorded silence start if interval is not enough or clip is too short - is_leading_silence = silence_start == 0 and i > self.max_sil_kept - need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length - if not is_leading_silence and not need_slice_middle: - silence_start = None - continue - if i - silence_start <= self.max_sil_kept: # Need slicing. Record the range of silent frames to be removed. - pos = rms_list[silence_start: i + 1].argmin() + silence_start - sil_tags.append((0, pos) if silence_start == 0 else (pos, pos)) - clip_start = pos - elif i - silence_start <= self.max_sil_kept * 2: - pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin() - pos += i - self.max_sil_kept - pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start - pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept - if silence_start == 0: - sil_tags.append((0, pos_r)) - clip_start = pos_r + def __init__( + self, + sr: int, + threshold: float = -40.0, + min_length: int = 5000, + min_interval: int = 300, + hop_size: int = 20, + max_sil_kept: int = 5000, + ): + if not min_length >= min_interval >= hop_size: + raise ValueError( + "The following condition must be satisfied: min_length >= min_interval >= hop_size" + ) + if not max_sil_kept >= hop_size: + raise ValueError( + "The following condition must be satisfied: max_sil_kept >= hop_size" + ) + min_interval = sr * min_interval / 1000 + self.threshold = 10 ** (threshold / 20.0) + self.hop_size = round(sr * hop_size / 1000) + self.win_size = min(round(min_interval), 4 * self.hop_size) + self.min_length = round(sr * min_length / 1000 / self.hop_size) + self.min_interval = round(min_interval / self.hop_size) + self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) + + def _apply_slice(self, waveform, begin, end): + if len(waveform.shape) > 1: + return waveform[ + :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size) + ] else: - sil_tags.append((min(pos_l, pos), max(pos_r, pos))) - clip_start = max(pos_r, pos) - else: - pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start - pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept - sil_tags.append((0, pos_r) if silence_start == 0 else (pos_l, pos_r)) - clip_start = pos_r - silence_start = None - total_frames = rms_list.shape[0] - if silence_start is not None and total_frames - silence_start >= self.min_interval: # Deal with trailing silence. - silence_end = min(total_frames, silence_start + self.max_sil_kept) - pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start - sil_tags.append((pos, total_frames + 1)) - if len(sil_tags) == 0: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} # Apply and return slices. - chunks = [] - if sil_tags[0][0]: - chunks.append({"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"}) - for i in range(0, len(sil_tags)): - if i: chunks.append({"slice": False, "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"}) - chunks.append({"slice": True, "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"}) - if sil_tags[-1][1] * self.hop_size < len(waveform): - chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"}) - chunk_dict = {} - for i in range(len(chunks)): chunk_dict[str(i)] = chunks[i] - return chunk_dict + return waveform[ + begin * self.hop_size : min(waveform.shape[0], end * self.hop_size) + ] + + def slice(self, waveform): + samples = librosa.to_mono(waveform) if len(waveform.shape) > 1 else waveform + if samples.shape[0] <= self.min_length: + return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} + rms_list = librosa.feature.rms( + y=samples, frame_length=self.win_size, hop_length=self.hop_size + ).squeeze(0) + sil_tags, silence_start, clip_start = [], None, 0 + for i, rms in enumerate(rms_list): + if rms < self.threshold: # Keep looping while frame is silent. + if silence_start is None: # Record start of silent frames. + silence_start = i + continue + if silence_start is None: + continue # Keep looping while frame is not silent and silence start has not been recorded. + # Clear recorded silence start if interval is not enough or clip is too short + is_leading_silence = silence_start == 0 and i > self.max_sil_kept + need_slice_middle = ( + i - silence_start >= self.min_interval + and i - clip_start >= self.min_length + ) + if not is_leading_silence and not need_slice_middle: + silence_start = None + continue + if ( + i - silence_start <= self.max_sil_kept + ): # Need slicing. Record the range of silent frames to be removed. + pos = rms_list[silence_start : i + 1].argmin() + silence_start + sil_tags.append((0, pos) if silence_start == 0 else (pos, pos)) + clip_start = pos + elif i - silence_start <= self.max_sil_kept * 2: + pos = rms_list[ + i - self.max_sil_kept : silence_start + self.max_sil_kept + 1 + ].argmin() + pos += i - self.max_sil_kept + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + clip_start = pos_r + else: + sil_tags.append((min(pos_l, pos), max(pos_r, pos))) + clip_start = max(pos_r, pos) + else: + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + sil_tags.append((0, pos_r) if silence_start == 0 else (pos_l, pos_r)) + clip_start = pos_r + silence_start = None + total_frames = rms_list.shape[0] + if ( + silence_start is not None + and total_frames - silence_start >= self.min_interval + ): # Deal with trailing silence. + silence_end = min(total_frames, silence_start + self.max_sil_kept) + pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start + sil_tags.append((pos, total_frames + 1)) + if len(sil_tags) == 0: + return { + "0": {"slice": False, "split_time": f"0,{len(waveform)}"} + } # Apply and return slices. + chunks = [] + if sil_tags[0][0]: + chunks.append( + { + "slice": False, + "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}", + } + ) + for i in range(0, len(sil_tags)): + if i: + chunks.append( + { + "slice": False, + "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}", + } + ) + chunks.append( + { + "slice": True, + "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}", + } + ) + if sil_tags[-1][1] * self.hop_size < len(waveform): + chunks.append( + { + "slice": False, + "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}", + } + ) + chunk_dict = {} + for i in range(len(chunks)): + chunk_dict[str(i)] = chunks[i] + return chunk_dict + # sinc_interp_hann audio resampling class Resample: - def __init__(self, orig_freq:int=16000, new_freq:int=16000, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None, dtype:Optional[dtypes]=None): - self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff, self.beta = orig_freq, new_freq, lowpass_filter_width, rolloff, beta - self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq)) - self.kernel, self.width = self._get_sinc_resample_kernel(dtype) if self.orig_freq != self.new_freq else (None, None) - def __call__(self, waveform:Tensor) -> Tensor: - if self.orig_freq == self.new_freq: return waveform - return self._apply_sinc_resample_kernel(waveform) - def _apply_sinc_resample_kernel(self, waveform:Tensor): - if not waveform.is_floating_point(): raise TypeError(f"Waveform tensor expected to be of type float, but received {waveform.dtype}.") - orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd) - shape = waveform.shape - waveform = waveform.reshape(-1, shape[-1]) # pack batch - num_wavs, length = waveform.shape - target_length = int(math.ceil(new_freq * length / orig_freq)) - waveform = waveform.pad2d((self.width, self.width + orig_freq)) - resampled = waveform[:, None].conv2d(self.kernel, stride=orig_freq) - resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) - resampled = resampled[..., :target_length] - resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:]) # unpack batch - return resampled - def _get_sinc_resample_kernel(self, dtype=None): - orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd) - if self.lowpass_filter_width <= 0: raise ValueError("Low pass filter width should be positive.") - base_freq = min(orig_freq, new_freq) - base_freq *= self.rolloff - width = math.ceil(self.lowpass_filter_width * orig_freq / base_freq) - idx = Tensor.arange(-width, width + orig_freq, dtype=(dtype if dtype is not None else dtypes.float32))[None, None] / orig_freq - t = Tensor.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx - t *= base_freq - t = t.clip(-self.lowpass_filter_width, self.lowpass_filter_width) - window = (t * math.pi / self.lowpass_filter_width / 2).cos() ** 2 - t *= math.pi - scale = base_freq / orig_freq - kernels = Tensor.where(t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t) - kernels *= window * scale - if dtype is None: kernels = kernels.cast(dtype=dtypes.float32) - return kernels, width + def __init__( + self, + orig_freq: int = 16000, + new_freq: int = 16000, + lowpass_filter_width: int = 6, + rolloff: float = 0.99, + beta: Optional[float] = None, + dtype: Optional[dtypes] = None, + ): + ( + self.orig_freq, + self.new_freq, + self.lowpass_filter_width, + self.rolloff, + self.beta, + ) = (orig_freq, new_freq, lowpass_filter_width, rolloff, beta) + self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq)) + self.kernel, self.width = ( + self._get_sinc_resample_kernel(dtype) + if self.orig_freq != self.new_freq + else (None, None) + ) + + def __call__(self, waveform: Tensor) -> Tensor: + if self.orig_freq == self.new_freq: + return waveform + return self._apply_sinc_resample_kernel(waveform) + + def _apply_sinc_resample_kernel(self, waveform: Tensor): + if not waveform.is_floating_point(): + raise TypeError( + f"Waveform tensor expected to be of type float, but received {waveform.dtype}." + ) + orig_freq, new_freq = (int(self.orig_freq) // self.gcd), ( + int(self.new_freq) // self.gcd + ) + shape = waveform.shape + waveform = waveform.reshape(-1, shape[-1]) # pack batch + num_wavs, length = waveform.shape + target_length = int(math.ceil(new_freq * length / orig_freq)) + waveform = waveform.pad2d((self.width, self.width + orig_freq)) + resampled = waveform[:, None].conv2d(self.kernel, stride=orig_freq) + resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) + resampled = resampled[..., :target_length] + resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:]) # unpack batch + return resampled + + def _get_sinc_resample_kernel(self, dtype=None): + orig_freq, new_freq = (int(self.orig_freq) // self.gcd), ( + int(self.new_freq) // self.gcd + ) + if self.lowpass_filter_width <= 0: + raise ValueError("Low pass filter width should be positive.") + base_freq = min(orig_freq, new_freq) + base_freq *= self.rolloff + width = math.ceil(self.lowpass_filter_width * orig_freq / base_freq) + idx = ( + Tensor.arange( + -width, + width + orig_freq, + dtype=(dtype if dtype is not None else dtypes.float32), + )[None, None] + / orig_freq + ) + t = Tensor.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx + t *= base_freq + t = t.clip(-self.lowpass_filter_width, self.lowpass_filter_width) + window = (t * math.pi / self.lowpass_filter_width / 2).cos() ** 2 + t *= math.pi + scale = base_freq / orig_freq + kernels = Tensor.where( + t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t + ) + kernels *= window * scale + if dtype is None: + kernels = kernels.cast(dtype=dtypes.float32) + return kernels, width + + +def sinc_interp_resample( + x: Tensor, + orig_freq: int = 16000, + new_freq: int = 1600, + lowpass_filter_width: int = 6, + rolloff: float = 0.99, + beta: Optional[float] = None, +): + resamp = Resample(orig_freq, new_freq, lowpass_filter_width, rolloff, beta, x.dtype) + return resamp(x) -def sinc_interp_resample(x:Tensor, orig_freq:int=16000, new_freq:int=1600, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None): - resamp = Resample(orig_freq, new_freq, lowpass_filter_width, rolloff, beta, x.dtype) - return resamp(x) def cut(audio_path, db_thresh=-30, min_len=5000): - audio, sr = librosa.load(audio_path, sr=None) - slicer = Slicer(sr=sr, threshold=db_thresh, min_length=min_len) - chunks = slicer.slice(audio) - return chunks + audio, sr = librosa.load(audio_path, sr=None) + slicer = Slicer(sr=sr, threshold=db_thresh, min_length=min_len) + chunks = slicer.slice(audio) + return chunks + def chunks2audio(audio_path, chunks): - chunks = dict(chunks) - audio, sr = load_audiofile(audio_path) - if len(audio.shape) == 2 and audio.shape[1] >= 2: - audio = audio.mean(0).unsqueeze(0) - audio = audio.numpy()[0] - result = [] - for k, v in chunks.items(): - tag = v["split_time"].split(",") - if tag[0] != tag[1]: - result.append((v["slice"], audio[int(tag[0]):int(tag[1])])) - return result, sr + chunks = dict(chunks) + audio, sr = load_audiofile(audio_path) + if len(audio.shape) == 2 and audio.shape[1] >= 2: + audio = audio.mean(0).unsqueeze(0) + audio = audio.numpy()[0] + result = [] + for k, v in chunks.items(): + tag = v["split_time"].split(",") + if tag[0] != tag[1]: + result.append((v["slice"], audio[int(tag[0]) : int(tag[1])])) + return result, sr -def load_audiofile(filepath:str, frame_offset:int=0, num_frames:int=-1, channels_first:bool=True): - with soundfile.SoundFile(filepath, "r") as file_: - frames = file_._prepare_read(frame_offset, None, num_frames) - waveform = file_.read(frames, "float32", always_2d=True) - sample_rate = file_.samplerate - waveform = Tensor(waveform) - if channels_first: waveform = waveform.transpose(0, 1) - return waveform, sample_rate -def get_unit_f0(wav:Tensor, tran, hop_length, target_sample, f0_filter=False) -> Tuple[Tensor,Tensor,Tensor]: - f0_predictor = PMF0Predictor(hop_length, sampling_rate=target_sample) - f0, uv = f0_predictor.compute_f0_uv(wav.numpy()) - if f0_filter and sum(f0) == 0: raise RuntimeError("No voice detected") - f0 = Tensor(f0.astype(np.float32)).float() - f0 = (f0 * 2 ** (tran / 12)).unsqueeze(0) - uv = Tensor(uv.astype(np.float32)).float().unsqueeze(0) - wav16k = sinc_interp_resample(wav[None,:], target_sample, 16000)[0] - return wav16k.realize(), f0.realize(), uv.realize() +def load_audiofile( + filepath: str, + frame_offset: int = 0, + num_frames: int = -1, + channels_first: bool = True, +): + with soundfile.SoundFile(filepath, "r") as file_: + frames = file_._prepare_read(frame_offset, None, num_frames) + waveform = file_.read(frames, "float32", always_2d=True) + sample_rate = file_.samplerate + waveform = Tensor(waveform) + if channels_first: + waveform = waveform.transpose(0, 1) + return waveform, sample_rate + + +def get_unit_f0( + wav: Tensor, tran, hop_length, target_sample, f0_filter=False +) -> Tuple[Tensor, Tensor, Tensor]: + f0_predictor = PMF0Predictor(hop_length, sampling_rate=target_sample) + f0, uv = f0_predictor.compute_f0_uv(wav.numpy()) + if f0_filter and sum(f0) == 0: + raise RuntimeError("No voice detected") + f0 = Tensor(f0.astype(np.float32)).float() + f0 = (f0 * 2 ** (tran / 12)).unsqueeze(0) + uv = Tensor(uv.astype(np.float32)).float().unsqueeze(0) + wav16k = sinc_interp_resample(wav[None, :], target_sample, 16000)[0] + return wav16k.realize(), f0.realize(), uv.realize() diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 5e2dd2bbb..417b7cbc7 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -14,547 +14,676 @@ from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict from tinygrad.jit import TinyJit + class AttnBlock: - def __init__(self, in_channels): - self.norm = GroupNorm(32, in_channels) - self.q = Conv2d(in_channels, in_channels, 1) - self.k = Conv2d(in_channels, in_channels, 1) - self.v = Conv2d(in_channels, in_channels, 1) - self.proj_out = Conv2d(in_channels, in_channels, 1) + def __init__(self, in_channels): + self.norm = GroupNorm(32, in_channels) + self.q = Conv2d(in_channels, in_channels, 1) + self.k = Conv2d(in_channels, in_channels, 1) + self.v = Conv2d(in_channels, in_channels, 1) + self.proj_out = Conv2d(in_channels, in_channels, 1) - # copied from AttnBlock in ldm repo - def __call__(self, x): - h_ = self.norm(x) - q,k,v = self.q(h_), self.k(h_), self.v(h_) + # copied from AttnBlock in ldm repo + def __call__(self, x): + h_ = self.norm(x) + q, k, v = self.q(h_), self.k(h_), self.v(h_) + + # compute attention + b, c, h, w = q.shape + q, k, v = [x.reshape(b, c, h * w).transpose(1, 2) for x in (q, k, v)] + h_ = ( + Tensor.scaled_dot_product_attention(q, k, v) + .transpose(1, 2) + .reshape(b, c, h, w) + ) + return x + self.proj_out(h_) - # compute attention - b,c,h,w = q.shape - q,k,v = [x.reshape(b,c,h*w).transpose(1,2) for x in (q,k,v)] - h_ = Tensor.scaled_dot_product_attention(q,k,v).transpose(1,2).reshape(b,c,h,w) - return x + self.proj_out(h_) class ResnetBlock: - def __init__(self, in_channels, out_channels=None): - self.norm1 = GroupNorm(32, in_channels) - self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1) - self.norm2 = GroupNorm(32, out_channels) - self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1) - self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x + def __init__(self, in_channels, out_channels=None): + self.norm1 = GroupNorm(32, in_channels) + self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1) + self.norm2 = GroupNorm(32, out_channels) + self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1) + self.nin_shortcut = ( + Conv2d(in_channels, out_channels, 1) + if in_channels != out_channels + else lambda x: x + ) + + def __call__(self, x): + h = self.conv1(self.norm1(x).swish()) + h = self.conv2(self.norm2(h).swish()) + return self.nin_shortcut(x) + h - def __call__(self, x): - h = self.conv1(self.norm1(x).swish()) - h = self.conv2(self.norm2(h).swish()) - return self.nin_shortcut(x) + h class Mid: - def __init__(self, block_in): - self.block_1 = ResnetBlock(block_in, block_in) - self.attn_1 = AttnBlock(block_in) - self.block_2 = ResnetBlock(block_in, block_in) + def __init__(self, block_in): + self.block_1 = ResnetBlock(block_in, block_in) + self.attn_1 = AttnBlock(block_in) + self.block_2 = ResnetBlock(block_in, block_in) + + def __call__(self, x): + return x.sequential([self.block_1, self.attn_1, self.block_2]) - def __call__(self, x): - return x.sequential([self.block_1, self.attn_1, self.block_2]) class Decoder: - def __init__(self): - sz = [(128, 256), (256, 512), (512, 512), (512, 512)] - self.conv_in = Conv2d(4,512,3, padding=1) - self.mid = Mid(512) + def __init__(self): + sz = [(128, 256), (256, 512), (512, 512), (512, 512)] + self.conv_in = Conv2d(4, 512, 3, padding=1) + self.mid = Mid(512) - arr = [] - for i,s in enumerate(sz): - arr.append({"block": - [ResnetBlock(s[1], s[0]), - ResnetBlock(s[0], s[0]), - ResnetBlock(s[0], s[0])]}) - if i != 0: arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)} - self.up = arr + arr = [] + for i, s in enumerate(sz): + arr.append( + { + "block": [ + ResnetBlock(s[1], s[0]), + ResnetBlock(s[0], s[0]), + ResnetBlock(s[0], s[0]), + ] + } + ) + if i != 0: + arr[-1]["upsample"] = {"conv": Conv2d(s[0], s[0], 3, padding=1)} + self.up = arr - self.norm_out = GroupNorm(32, 128) - self.conv_out = Conv2d(128, 3, 3, padding=1) + self.norm_out = GroupNorm(32, 128) + self.conv_out = Conv2d(128, 3, 3, padding=1) - def __call__(self, x): - x = self.conv_in(x) - x = self.mid(x) + def __call__(self, x): + x = self.conv_in(x) + x = self.mid(x) - for l in self.up[::-1]: - print("decode", x.shape) - for b in l['block']: x = b(x) - if 'upsample' in l: - # https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ? - bs,c,py,px = x.shape - x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2) - x = l['upsample']['conv'](x) - x.realize() + for l in self.up[::-1]: + print("decode", x.shape) + for b in l["block"]: + x = b(x) + if "upsample" in l: + # https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ? + bs, c, py, px = x.shape + x = ( + x.reshape(bs, c, py, 1, px, 1) + .expand(bs, c, py, 2, px, 2) + .reshape(bs, c, py * 2, px * 2) + ) + x = l["upsample"]["conv"](x) + x.realize() + + return self.conv_out(self.norm_out(x).swish()) - return self.conv_out(self.norm_out(x).swish()) class Encoder: - def __init__(self): - sz = [(128, 128), (128, 256), (256, 512), (512, 512)] - self.conv_in = Conv2d(3,128,3, padding=1) + def __init__(self): + sz = [(128, 128), (128, 256), (256, 512), (512, 512)] + self.conv_in = Conv2d(3, 128, 3, padding=1) - arr = [] - for i,s in enumerate(sz): - arr.append({"block": - [ResnetBlock(s[0], s[1]), - ResnetBlock(s[1], s[1])]}) - if i != 3: arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0,1,0,1))} - self.down = arr + arr = [] + for i, s in enumerate(sz): + arr.append({"block": [ResnetBlock(s[0], s[1]), ResnetBlock(s[1], s[1])]}) + if i != 3: + arr[-1]["downsample"] = { + "conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0, 1, 0, 1)) + } + self.down = arr - self.mid = Mid(512) - self.norm_out = GroupNorm(32, 512) - self.conv_out = Conv2d(512, 8, 3, padding=1) + self.mid = Mid(512) + self.norm_out = GroupNorm(32, 512) + self.conv_out = Conv2d(512, 8, 3, padding=1) - def __call__(self, x): - x = self.conv_in(x) + def __call__(self, x): + x = self.conv_in(x) - for l in self.down: - print("encode", x.shape) - for b in l['block']: x = b(x) - if 'downsample' in l: x = l['downsample']['conv'](x) + for l in self.down: + print("encode", x.shape) + for b in l["block"]: + x = b(x) + if "downsample" in l: + x = l["downsample"]["conv"](x) + + x = self.mid(x) + return self.conv_out(self.norm_out(x).swish()) - x = self.mid(x) - return self.conv_out(self.norm_out(x).swish()) class AutoencoderKL: - def __init__(self): - self.encoder = Encoder() - self.decoder = Decoder() - self.quant_conv = Conv2d(8, 8, 1) - self.post_quant_conv = Conv2d(4, 4, 1) + def __init__(self): + self.encoder = Encoder() + self.decoder = Decoder() + self.quant_conv = Conv2d(8, 8, 1) + self.post_quant_conv = Conv2d(4, 4, 1) + + def __call__(self, x): + latent = self.encoder(x) + latent = self.quant_conv(latent) + latent = latent[:, 0:4] # only the means + print("latent", latent.shape) + latent = self.post_quant_conv(latent) + return self.decoder(latent) - def __call__(self, x): - latent = self.encoder(x) - latent = self.quant_conv(latent) - latent = latent[:, 0:4] # only the means - print("latent", latent.shape) - latent = self.post_quant_conv(latent) - return self.decoder(latent) # not to be confused with ResnetBlock class ResBlock: - def __init__(self, channels, emb_channels, out_channels): - self.in_layers = [ - GroupNorm(32, channels), - Tensor.silu, - Conv2d(channels, out_channels, 3, padding=1) - ] - self.emb_layers = [ - Tensor.silu, - Linear(emb_channels, out_channels) - ] - self.out_layers = [ - GroupNorm(32, out_channels), - Tensor.silu, - lambda x: x, # needed for weights loading code to work - Conv2d(out_channels, out_channels, 3, padding=1) - ] - self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x + def __init__(self, channels, emb_channels, out_channels): + self.in_layers = [ + GroupNorm(32, channels), + Tensor.silu, + Conv2d(channels, out_channels, 3, padding=1), + ] + self.emb_layers = [Tensor.silu, Linear(emb_channels, out_channels)] + self.out_layers = [ + GroupNorm(32, out_channels), + Tensor.silu, + lambda x: x, # needed for weights loading code to work + Conv2d(out_channels, out_channels, 3, padding=1), + ] + self.skip_connection = ( + Conv2d(channels, out_channels, 1) + if channels != out_channels + else lambda x: x + ) + + def __call__(self, x, emb): + h = x.sequential(self.in_layers) + emb_out = emb.sequential(self.emb_layers) + h = h + emb_out.reshape(*emb_out.shape, 1, 1) + h = h.sequential(self.out_layers) + ret = self.skip_connection(x) + h + return ret - def __call__(self, x, emb): - h = x.sequential(self.in_layers) - emb_out = emb.sequential(self.emb_layers) - h = h + emb_out.reshape(*emb_out.shape, 1, 1) - h = h.sequential(self.out_layers) - ret = self.skip_connection(x) + h - return ret class CrossAttention: - def __init__(self, query_dim, context_dim, n_heads, d_head): - self.to_q = Linear(query_dim, n_heads*d_head, bias=False) - self.to_k = Linear(context_dim, n_heads*d_head, bias=False) - self.to_v = Linear(context_dim, n_heads*d_head, bias=False) - self.num_heads = n_heads - self.head_size = d_head - self.to_out = [Linear(n_heads*d_head, query_dim)] + def __init__(self, query_dim, context_dim, n_heads, d_head): + self.to_q = Linear(query_dim, n_heads * d_head, bias=False) + self.to_k = Linear(context_dim, n_heads * d_head, bias=False) + self.to_v = Linear(context_dim, n_heads * d_head, bias=False) + self.num_heads = n_heads + self.head_size = d_head + self.to_out = [Linear(n_heads * d_head, query_dim)] + + def __call__(self, x, context=None): + context = x if context is None else context + q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) + q, k, v = [ + y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1, 2) + for y in (q, k, v) + ] + attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1, 2) + h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size) + return h_.sequential(self.to_out) - def __call__(self, x, context=None): - context = x if context is None else context - q,k,v = self.to_q(x), self.to_k(context), self.to_v(context) - q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)] - attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2) - h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size) - return h_.sequential(self.to_out) class GEGLU: - def __init__(self, dim_in, dim_out): - self.proj = Linear(dim_in, dim_out * 2) - self.dim_out = dim_out + def __init__(self, dim_in, dim_out): + self.proj = Linear(dim_in, dim_out * 2) + self.dim_out = dim_out + + def __call__(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * gate.gelu() - def __call__(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * gate.gelu() class FeedForward: - def __init__(self, dim, mult=4): - self.net = [ - GEGLU(dim, dim*mult), - lambda x: x, # needed for weights loading code to work - Linear(dim*mult, dim) - ] + def __init__(self, dim, mult=4): + self.net = [ + GEGLU(dim, dim * mult), + lambda x: x, # needed for weights loading code to work + Linear(dim * mult, dim), + ] + + def __call__(self, x): + return x.sequential(self.net) - def __call__(self, x): - return x.sequential(self.net) class BasicTransformerBlock: - def __init__(self, dim, context_dim, n_heads, d_head): - self.attn1 = CrossAttention(dim, dim, n_heads, d_head) - self.ff = FeedForward(dim) - self.attn2 = CrossAttention(dim, context_dim, n_heads, d_head) - self.norm1 = LayerNorm(dim) - self.norm2 = LayerNorm(dim) - self.norm3 = LayerNorm(dim) + def __init__(self, dim, context_dim, n_heads, d_head): + self.attn1 = CrossAttention(dim, dim, n_heads, d_head) + self.ff = FeedForward(dim) + self.attn2 = CrossAttention(dim, context_dim, n_heads, d_head) + self.norm1 = LayerNorm(dim) + self.norm2 = LayerNorm(dim) + self.norm3 = LayerNorm(dim) + + def __call__(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x - def __call__(self, x, context=None): - x = self.attn1(self.norm1(x)) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x class SpatialTransformer: - def __init__(self, channels, context_dim, n_heads, d_head): - self.norm = GroupNorm(32, channels) - assert channels == n_heads * d_head - self.proj_in = Conv2d(channels, n_heads * d_head, 1) - self.transformer_blocks = [BasicTransformerBlock(channels, context_dim, n_heads, d_head)] - self.proj_out = Conv2d(n_heads * d_head, channels, 1) + def __init__(self, channels, context_dim, n_heads, d_head): + self.norm = GroupNorm(32, channels) + assert channels == n_heads * d_head + self.proj_in = Conv2d(channels, n_heads * d_head, 1) + self.transformer_blocks = [ + BasicTransformerBlock(channels, context_dim, n_heads, d_head) + ] + self.proj_out = Conv2d(n_heads * d_head, channels, 1) + + def __call__(self, x, context=None): + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = x.reshape(b, c, h * w).permute(0, 2, 1) + for block in self.transformer_blocks: + x = block(x, context=context) + x = x.permute(0, 2, 1).reshape(b, c, h, w) + ret = self.proj_out(x) + x_in + return ret - def __call__(self, x, context=None): - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - x = self.proj_in(x) - x = x.reshape(b, c, h*w).permute(0,2,1) - for block in self.transformer_blocks: - x = block(x, context=context) - x = x.permute(0,2,1).reshape(b, c, h, w) - ret = self.proj_out(x) + x_in - return ret class Downsample: - def __init__(self, channels): - self.op = Conv2d(channels, channels, 3, stride=2, padding=1) + def __init__(self, channels): + self.op = Conv2d(channels, channels, 3, stride=2, padding=1) + + def __call__(self, x): + return self.op(x) - def __call__(self, x): - return self.op(x) class Upsample: - def __init__(self, channels): - self.conv = Conv2d(channels, channels, 3, padding=1) + def __init__(self, channels): + self.conv = Conv2d(channels, channels, 3, padding=1) + + def __call__(self, x): + bs, c, py, px = x.shape + x = ( + x.reshape(bs, c, py, 1, px, 1) + .expand(bs, c, py, 2, px, 2) + .reshape(bs, c, py * 2, px * 2) + ) + return self.conv(x) - def __call__(self, x): - bs,c,py,px = x.shape - x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2) - return self.conv(x) def timestep_embedding(timesteps, dim, max_period=10000): - half = dim // 2 - freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp() - args = timesteps * freqs - return Tensor.cat(args.cos(), args.sin()).reshape(1, -1) + half = dim // 2 + freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp() + args = timesteps * freqs + return Tensor.cat(args.cos(), args.sin()).reshape(1, -1) + class UNetModel: - def __init__(self): - self.time_embed = [ - Linear(320, 1280), - Tensor.silu, - Linear(1280, 1280), - ] - self.input_blocks = [ - [Conv2d(4, 320, kernel_size=3, padding=1)], - [ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - [ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - [Downsample(320)], - [ResBlock(320, 1280, 640), SpatialTransformer(640, 768, 8, 80)], - [ResBlock(640, 1280, 640), SpatialTransformer(640, 768, 8, 80)], - [Downsample(640)], - [ResBlock(640, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], - [ResBlock(1280, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], - [Downsample(1280)], - [ResBlock(1280, 1280, 1280)], - [ResBlock(1280, 1280, 1280)] - ] - self.middle_block = [ - ResBlock(1280, 1280, 1280), - SpatialTransformer(1280, 768, 8, 160), - ResBlock(1280, 1280, 1280) - ] - self.output_blocks = [ - [ResBlock(2560, 1280, 1280)], - [ResBlock(2560, 1280, 1280)], - [ResBlock(2560, 1280, 1280), Upsample(1280)], - [ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], - [ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], - [ResBlock(1920, 1280, 1280), SpatialTransformer(1280, 768, 8, 160), Upsample(1280)], - [ResBlock(1920, 1280, 640), SpatialTransformer(640, 768, 8, 80)], # 6 - [ResBlock(1280, 1280, 640), SpatialTransformer(640, 768, 8, 80)], - [ResBlock(960, 1280, 640), SpatialTransformer(640, 768, 8, 80), Upsample(640)], - [ResBlock(960, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - [ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - [ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)], - ] - self.out = [ - GroupNorm(32, 320), - Tensor.silu, - Conv2d(320, 4, kernel_size=3, padding=1) - ] + def __init__(self): + self.time_embed = [ + Linear(320, 1280), + Tensor.silu, + Linear(1280, 1280), + ] + self.input_blocks = [ + [Conv2d(4, 320, kernel_size=3, padding=1)], + [ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)], + [ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)], + [Downsample(320)], + [ResBlock(320, 1280, 640), SpatialTransformer(640, 768, 8, 80)], + [ResBlock(640, 1280, 640), SpatialTransformer(640, 768, 8, 80)], + [Downsample(640)], + [ResBlock(640, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], + [ResBlock(1280, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], + [Downsample(1280)], + [ResBlock(1280, 1280, 1280)], + [ResBlock(1280, 1280, 1280)], + ] + self.middle_block = [ + ResBlock(1280, 1280, 1280), + SpatialTransformer(1280, 768, 8, 160), + ResBlock(1280, 1280, 1280), + ] + self.output_blocks = [ + [ResBlock(2560, 1280, 1280)], + [ResBlock(2560, 1280, 1280)], + [ResBlock(2560, 1280, 1280), Upsample(1280)], + [ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], + [ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)], + [ + ResBlock(1920, 1280, 1280), + SpatialTransformer(1280, 768, 8, 160), + Upsample(1280), + ], + [ResBlock(1920, 1280, 640), SpatialTransformer(640, 768, 8, 80)], # 6 + [ResBlock(1280, 1280, 640), SpatialTransformer(640, 768, 8, 80)], + [ + ResBlock(960, 1280, 640), + SpatialTransformer(640, 768, 8, 80), + Upsample(640), + ], + [ResBlock(960, 1280, 320), SpatialTransformer(320, 768, 8, 40)], + [ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)], + [ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)], + ] + self.out = [ + GroupNorm(32, 320), + Tensor.silu, + Conv2d(320, 4, kernel_size=3, padding=1), + ] - def __call__(self, x, timesteps=None, context=None): - # TODO: real time embedding - t_emb = timestep_embedding(timesteps, 320) - emb = t_emb.sequential(self.time_embed) + def __call__(self, x, timesteps=None, context=None): + # TODO: real time embedding + t_emb = timestep_embedding(timesteps, 320) + emb = t_emb.sequential(self.time_embed) - def run(x, bb): - if isinstance(bb, ResBlock): x = bb(x, emb) - elif isinstance(bb, SpatialTransformer): x = bb(x, context) - else: x = bb(x) - return x + def run(x, bb): + if isinstance(bb, ResBlock): + x = bb(x, emb) + elif isinstance(bb, SpatialTransformer): + x = bb(x, context) + else: + x = bb(x) + return x + + saved_inputs = [] + for i, b in enumerate(self.input_blocks): + # print("input block", i) + for bb in b: + x = run(x, bb) + saved_inputs.append(x) + for bb in self.middle_block: + x = run(x, bb) + for i, b in enumerate(self.output_blocks): + # print("output block", i) + x = x.cat(saved_inputs.pop(), dim=1) + for bb in b: + x = run(x, bb) + return x.sequential(self.out) - saved_inputs = [] - for i,b in enumerate(self.input_blocks): - #print("input block", i) - for bb in b: - x = run(x, bb) - saved_inputs.append(x) - for bb in self.middle_block: - x = run(x, bb) - for i,b in enumerate(self.output_blocks): - #print("output block", i) - x = x.cat(saved_inputs.pop(), dim=1) - for bb in b: - x = run(x, bb) - return x.sequential(self.out) class CLIPMLP: - def __init__(self): - self.fc1 = Linear(768, 3072) - self.fc2 = Linear(3072, 768) + def __init__(self): + self.fc1 = Linear(768, 3072) + self.fc2 = Linear(3072, 768) + + def __call__(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = hidden_states.quick_gelu() + hidden_states = self.fc2(hidden_states) + return hidden_states - def __call__(self, hidden_states): - hidden_states = self.fc1(hidden_states) - hidden_states = hidden_states.quick_gelu() - hidden_states = self.fc2(hidden_states) - return hidden_states class CLIPAttention: - def __init__(self): - self.embed_dim = 768 - self.num_heads = 12 - self.head_dim = self.embed_dim // self.num_heads - self.k_proj = Linear(self.embed_dim, self.embed_dim) - self.v_proj = Linear(self.embed_dim, self.embed_dim) - self.q_proj = Linear(self.embed_dim, self.embed_dim) - self.out_proj = Linear(self.embed_dim, self.embed_dim) + def __init__(self): + self.embed_dim = 768 + self.num_heads = 12 + self.head_dim = self.embed_dim // self.num_heads + self.k_proj = Linear(self.embed_dim, self.embed_dim) + self.v_proj = Linear(self.embed_dim, self.embed_dim) + self.q_proj = Linear(self.embed_dim, self.embed_dim) + self.out_proj = Linear(self.embed_dim, self.embed_dim) + + def __call__(self, hidden_states, causal_attention_mask): + bsz, tgt_len, embed_dim = hidden_states.shape + q, k, v = ( + self.q_proj(hidden_states), + self.k_proj(hidden_states), + self.v_proj(hidden_states), + ) + q, k, v = [ + x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + for x in (q, k, v) + ] + attn_output = Tensor.scaled_dot_product_attention( + q, k, v, attn_mask=causal_attention_mask + ) + return self.out_proj( + attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim) + ) - def __call__(self, hidden_states, causal_attention_mask): - bsz, tgt_len, embed_dim = hidden_states.shape - q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) - q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)] - attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask) - return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)) class CLIPEncoderLayer: - def __init__(self): - self.self_attn = CLIPAttention() - self.layer_norm1 = LayerNorm(768) - self.mlp = CLIPMLP() - self.layer_norm2 = LayerNorm(768) + def __init__(self): + self.self_attn = CLIPAttention() + self.layer_norm1 = LayerNorm(768) + self.mlp = CLIPMLP() + self.layer_norm2 = LayerNorm(768) - def __call__(self, hidden_states, causal_attention_mask): - residual = hidden_states - hidden_states = self.layer_norm1(hidden_states) - hidden_states = self.self_attn(hidden_states, causal_attention_mask) - hidden_states = residual + hidden_states + def __call__(self, hidden_states, causal_attention_mask): + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states, causal_attention_mask) + hidden_states = residual + hidden_states - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states - return hidden_states class CLIPEncoder: - def __init__(self): - self.layers = [CLIPEncoderLayer() for i in range(12)] + def __init__(self): + self.layers = [CLIPEncoderLayer() for i in range(12)] + + def __call__(self, hidden_states, causal_attention_mask): + for l in self.layers: + hidden_states = l(hidden_states, causal_attention_mask) + return hidden_states - def __call__(self, hidden_states, causal_attention_mask): - for l in self.layers: - hidden_states = l(hidden_states, causal_attention_mask) - return hidden_states class CLIPTextEmbeddings: - def __init__(self): - self.token_embedding = Embedding(49408, 768) - self.position_embedding = Embedding(77, 768) + def __init__(self): + self.token_embedding = Embedding(49408, 768) + self.position_embedding = Embedding(77, 768) + + def __call__(self, input_ids, position_ids): + return self.token_embedding(input_ids) + self.position_embedding(position_ids) - def __call__(self, input_ids, position_ids): - return self.token_embedding(input_ids) + self.position_embedding(position_ids) class CLIPTextTransformer: - def __init__(self): - self.embeddings = CLIPTextEmbeddings() - self.encoder = CLIPEncoder() - self.final_layer_norm = LayerNorm(768) + def __init__(self): + self.embeddings = CLIPTextEmbeddings() + self.encoder = CLIPEncoder() + self.final_layer_norm = LayerNorm(768) + + def __call__(self, input_ids): + x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1)) + x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1)) + return self.final_layer_norm(x) - def __call__(self, input_ids): - x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1)) - x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1)) - return self.final_layer_norm(x) # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license) @lru_cache() -def default_bpe(): return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz") +def default_bpe(): + return fetch( + "https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", + "bpe_simple_vocab_16e6.txt.gz", + ) + def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings). - """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) - text = text.strip() - return text + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a significant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8+n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + class ClipTokenizer: - def __init__(self, bpe_path: str = default_bpe()): - self.byte_encoder = bytes_to_unicode() - merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] - for merge in merges: - vocab.append(''.join(merge)) - vocab.extend(['<|startoftext|>', '<|endoftext|>']) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE) + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + vocab.extend(["<|startoftext|>", "<|endoftext|>"]) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + "<|startoftext|>": "<|startoftext|>", + "<|endoftext|>": "<|endoftext|>", + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", + re.IGNORECASE, + ) - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + ( token[-1] + '',) - pairs = get_pairs(word) + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) - if not pairs: - return token+'' + if not pairs: + return token + "" - while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except Exception: - new_word.extend(word[i:]) - break + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - pairs = get_pairs(word) - word = ' '.join(word) - self.cache[token] = word - return word + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(text.strip()).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + # Truncation, keeping two slots for start and end tokens. + if len(bpe_tokens) > 75: + bpe_tokens = bpe_tokens[:75] + return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1) - def encode(self, text): - bpe_tokens = [] - text = whitespace_clean(text.strip()).lower() - for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) - # Truncation, keeping two slots for start and end tokens. - if len(bpe_tokens) > 75: - bpe_tokens = bpe_tokens[:75] - return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1) class StableDiffusion: - def __init__(self): - self.alphas_cumprod = Tensor.empty(1000) - self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel()) - self.first_stage_model = AutoencoderKL() - self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer())) + def __init__(self): + self.alphas_cumprod = Tensor.empty(1000) + self.model = namedtuple("DiffusionModel", ["diffusion_model"])( + diffusion_model=UNetModel() + ) + self.first_stage_model = AutoencoderKL() + self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])( + transformer=namedtuple("Transformer", ["text_model"])( + text_model=CLIPTextTransformer() + ) + ) - def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev): - temperature = 1 - sigma_t = 0 - sqrt_one_minus_at = (1-a_t).sqrt() - #print(a_t, a_prev, sigma_t, sqrt_one_minus_at) + def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev): + temperature = 1 + sigma_t = 0 + sqrt_one_minus_at = (1 - a_t).sqrt() + # print(a_t, a_prev, sigma_t, sqrt_one_minus_at) - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t - x_prev = a_prev.sqrt() * pred_x0 + dir_xt - return x_prev, pred_x0 + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + return x_prev, pred_x0 - def get_model_output(self, unconditional_context, context, latent, timestep, unconditional_guidance_scale): - # put into diffuser - latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0)) - unconditional_latent, latent = latents[0:1], latents[1:2] + def get_model_output( + self, + unconditional_context, + context, + latent, + timestep, + unconditional_guidance_scale, + ): + # put into diffuser + latents = self.model.diffusion_model( + latent.expand(2, *latent.shape[1:]), + timestep, + unconditional_context.cat(context, dim=0), + ) + unconditional_latent, latent = latents[0:1], latents[1:2] - e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent) - return e_t + e_t = unconditional_latent + unconditional_guidance_scale * ( + latent - unconditional_latent + ) + return e_t - def decode(self, x): - x = self.first_stage_model.post_quant_conv(1/0.18215 * x) - x = self.first_stage_model.decoder(x) + def decode(self, x): + x = self.first_stage_model.post_quant_conv(1 / 0.18215 * x) + x = self.first_stage_model.decoder(x) - # make image correct size and scale - x = (x + 1.0) / 2.0 - x = x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255 - return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x + # make image correct size and scale + x = (x + 1.0) / 2.0 + x = x.reshape(3, 512, 512).permute(1, 2, 0).clip(0, 1) * 255 + return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x + + def __call__( + self, + unconditional_context, + context, + latent, + timestep, + alphas, + alphas_prev, + guidance, + ): + e_t = self.get_model_output( + unconditional_context, context, latent, timestep, guidance + ) + x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev) + # e_t_next = get_model_output(x_prev) + # e_t_prime = (e_t + e_t_next) / 2 + # x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index) + return x_prev.realize() - def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance): - e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance) - x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev) - #e_t_next = get_model_output(x_prev) - #e_t_prime = (e_t + e_t_next) / 2 - #x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index) - return x_prev.realize() # ** ldm.models.autoencoder.AutoencoderKL (done!) # 3x512x512 <--> 4x64x64 (16384) @@ -573,72 +702,118 @@ class StableDiffusion: # cond_stage_model.transformer.text_model if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion") - parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Phrase to render") - parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename") - parser.add_argument('--noshow', action='store_true', help="Don't show the image") - parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16") - parser.add_argument('--timing', action='store_true', help="Print timing per step") - parser.add_argument('--seed', type=int, help="Set the random latent seed") - parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength") - args = parser.parse_args() + parser = argparse.ArgumentParser( + description="Run Stable Diffusion", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--steps", type=int, default=5, help="Number of steps in diffusion" + ) + parser.add_argument( + "--prompt", + type=str, + default="a horse sized cat eating a bagel", + help="Phrase to render", + ) + parser.add_argument( + "--out", + type=str, + default=Path(tempfile.gettempdir()) / "rendered.png", + help="Output filename", + ) + parser.add_argument("--noshow", action="store_true", help="Don't show the image") + parser.add_argument( + "--fp16", action="store_true", help="Cast the weights to float16" + ) + parser.add_argument("--timing", action="store_true", help="Print timing per step") + parser.add_argument("--seed", type=int, help="Set the random latent seed") + parser.add_argument("--guidance", type=float, default=7.5, help="Prompt strength") + args = parser.parse_args() - Tensor.no_grad = True - model = StableDiffusion() + Tensor.no_grad = True + model = StableDiffusion() - # load in weights - load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False) + # load in weights + load_state_dict( + model, + torch_load( + fetch( + "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt", + "sd-v1-4.ckpt", + ) + )["state_dict"], + strict=False, + ) - if args.fp16: - for l in get_state_dict(model).values(): - l.assign(l.cast(dtypes.float16).realize()) + if args.fp16: + for l in get_state_dict(model).values(): + l.assign(l.cast(dtypes.float16).realize()) - # run through CLIP to get context - tokenizer = ClipTokenizer() - prompt = Tensor([tokenizer.encode(args.prompt)]) - context = model.cond_stage_model.transformer.text_model(prompt).realize() - print("got CLIP context", context.shape) + # run through CLIP to get context + tokenizer = ClipTokenizer() + prompt = Tensor([tokenizer.encode(args.prompt)]) + context = model.cond_stage_model.transformer.text_model(prompt).realize() + print("got CLIP context", context.shape) - prompt = Tensor([tokenizer.encode("")]) - unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize() - print("got unconditional CLIP context", unconditional_context.shape) + prompt = Tensor([tokenizer.encode("")]) + unconditional_context = model.cond_stage_model.transformer.text_model( + prompt + ).realize() + print("got unconditional CLIP context", unconditional_context.shape) - # done with clip model - del model.cond_stage_model + # done with clip model + del model.cond_stage_model - timesteps = list(range(1, 1000, 1000//args.steps)) - print(f"running for {timesteps} timesteps") - alphas = model.alphas_cumprod[Tensor(timesteps)] - alphas_prev = Tensor([1.0]).cat(alphas[:-1]) + timesteps = list(range(1, 1000, 1000 // args.steps)) + print(f"running for {timesteps} timesteps") + alphas = model.alphas_cumprod[Tensor(timesteps)] + alphas_prev = Tensor([1.0]).cat(alphas[:-1]) - # start with random noise - if args.seed is not None: Tensor._seed = args.seed - latent = Tensor.randn(1,4,64,64) + # start with random noise + if args.seed is not None: + Tensor._seed = args.seed + latent = Tensor.randn(1, 4, 64, 64) - @TinyJit - def run(model, *x): return model(*x).realize() + @TinyJit + def run(model, *x): + return model(*x).realize() - # this is diffusion - with Context(BEAM=getenv("LATEBEAM")): - for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])): - GlobalCounters.reset() - t.set_description("%3d %3d" % (index, timestep)) - with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"): - tid = Tensor([index]) - latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance])) - if args.timing: Device[Device.DEFAULT].synchronize() - del run + # this is diffusion + with Context(BEAM=getenv("LATEBEAM")): + for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])): + GlobalCounters.reset() + t.set_description("%3d %3d" % (index, timestep)) + with Timing( + "step in ", + enabled=args.timing, + on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB", + ): + tid = Tensor([index]) + latent = run( + model, + unconditional_context, + context, + latent, + Tensor([timestep]), + alphas[tid], + alphas_prev[tid], + Tensor([args.guidance]), + ) + if args.timing: + Device[Device.DEFAULT].synchronize() + del run - # upsample latent space to image with autoencoder - x = model.decode(latent) - print(x.shape) + # upsample latent space to image with autoencoder + x = model.decode(latent) + print(x.shape) - # save image - from PIL import Image - import numpy as np - im = Image.fromarray(x.numpy().astype(np.uint8, copy=False)) - print(f"saving {args.out}") - im.save(args.out) - # Open image. - if not args.noshow: im.show() + # save image + from PIL import Image + import numpy as np + + im = Image.fromarray(x.numpy().astype(np.uint8, copy=False)) + print(f"saving {args.out}") + im.save(args.out) + # Open image. + if not args.noshow: + im.show() diff --git a/examples/train_efficientnet.py b/examples/train_efficientnet.py index e0bd2240b..497121b55 100644 --- a/examples/train_efficientnet.py +++ b/examples/train_efficientnet.py @@ -10,96 +10,108 @@ from tinygrad.tensor import Tensor from extra.datasets import fetch_cifar from extra.models.efficientnet import EfficientNet -class TinyConvNet: - def __init__(self, classes=10): - conv = 3 - inter_chan, out_chan = 8, 16 # for speed - self.c1 = Tensor.uniform(inter_chan,3,conv,conv) - self.c2 = Tensor.uniform(out_chan,inter_chan,conv,conv) - self.l1 = Tensor.uniform(out_chan*6*6, classes) - def forward(self, x): - x = x.conv2d(self.c1).relu().max_pool2d() - x = x.conv2d(self.c2).relu().max_pool2d() - x = x.reshape(shape=[x.shape[0], -1]) - return x.dot(self.l1) +class TinyConvNet: + def __init__(self, classes=10): + conv = 3 + inter_chan, out_chan = 8, 16 # for speed + self.c1 = Tensor.uniform(inter_chan, 3, conv, conv) + self.c2 = Tensor.uniform(out_chan, inter_chan, conv, conv) + self.l1 = Tensor.uniform(out_chan * 6 * 6, classes) + + def forward(self, x): + x = x.conv2d(self.c1).relu().max_pool2d() + x = x.conv2d(self.c2).relu().max_pool2d() + x = x.reshape(shape=[x.shape[0], -1]) + return x.dot(self.l1) + if __name__ == "__main__": - IMAGENET = getenv("IMAGENET") - classes = 1000 if IMAGENET else 10 + IMAGENET = getenv("IMAGENET") + classes = 1000 if IMAGENET else 10 - TINY = getenv("TINY") - TRANSFER = getenv("TRANSFER") - if TINY: - model = TinyConvNet(classes) - elif TRANSFER: - model = EfficientNet(getenv("NUM", 0), classes, has_se=True) - model.load_from_pretrained() - else: - model = EfficientNet(getenv("NUM", 0), classes, has_se=False) + TINY = getenv("TINY") + TRANSFER = getenv("TRANSFER") + if TINY: + model = TinyConvNet(classes) + elif TRANSFER: + model = EfficientNet(getenv("NUM", 0), classes, has_se=True) + model.load_from_pretrained() + else: + model = EfficientNet(getenv("NUM", 0), classes, has_se=False) - parameters = get_parameters(model) - print("parameter count", len(parameters)) - optimizer = optim.Adam(parameters, lr=0.001) + parameters = get_parameters(model) + print("parameter count", len(parameters)) + optimizer = optim.Adam(parameters, lr=0.001) - BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048) - print(f"training with batch size {BS} for {steps} steps") + BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048) + print(f"training with batch size {BS} for {steps} steps") - if IMAGENET: - from extra.datasets.imagenet import fetch_batch - def loader(q): - while 1: - try: - q.put(fetch_batch(BS)) - except Exception: - traceback.print_exc() - q = Queue(16) - for i in range(2): - p = Process(target=loader, args=(q,)) - p.daemon = True - p.start() - else: - X_train, Y_train, _, _ = fetch_cifar() - X_train = X_train.reshape((-1, 3, 32, 32)) - Y_train = Y_train.reshape((-1,)) + if IMAGENET: + from extra.datasets.imagenet import fetch_batch - with Tensor.train(): - for i in (t := trange(steps)): - if IMAGENET: - X, Y = q.get(True) - else: - samp = np.random.randint(0, X_train.shape[0], size=(BS)) - X, Y = X_train.numpy()[samp], Y_train.numpy()[samp] + def loader(q): + while 1: + try: + q.put(fetch_batch(BS)) + except Exception: + traceback.print_exc() - st = time.time() - out = model.forward(Tensor(X.astype(np.float32), requires_grad=False)) - fp_time = (time.time()-st)*1000.0 + q = Queue(16) + for i in range(2): + p = Process(target=loader, args=(q,)) + p.daemon = True + p.start() + else: + X_train, Y_train, _, _ = fetch_cifar() + X_train = X_train.reshape((-1, 3, 32, 32)) + Y_train = Y_train.reshape((-1,)) - y = np.zeros((BS,classes), np.float32) - y[range(y.shape[0]),Y] = -classes - y = Tensor(y, requires_grad=False) - loss = out.log_softmax().mul(y).mean() + with Tensor.train(): + for i in (t := trange(steps)): + if IMAGENET: + X, Y = q.get(True) + else: + samp = np.random.randint(0, X_train.shape[0], size=(BS)) + X, Y = X_train.numpy()[samp], Y_train.numpy()[samp] - optimizer.zero_grad() + st = time.time() + out = model.forward(Tensor(X.astype(np.float32), requires_grad=False)) + fp_time = (time.time() - st) * 1000.0 - st = time.time() - loss.backward() - bp_time = (time.time()-st)*1000.0 + y = np.zeros((BS, classes), np.float32) + y[range(y.shape[0]), Y] = -classes + y = Tensor(y, requires_grad=False) + loss = out.log_softmax().mul(y).mean() - st = time.time() - optimizer.step() - opt_time = (time.time()-st)*1000.0 + optimizer.zero_grad() - st = time.time() - loss = loss.numpy() - cat = out.argmax(axis=1).numpy() - accuracy = (cat == Y).mean() - finish_time = (time.time()-st)*1000.0 + st = time.time() + loss.backward() + bp_time = (time.time() - st) * 1000.0 - # printing - t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" % - (loss, accuracy, - fp_time, bp_time, opt_time, finish_time, - fp_time + bp_time + opt_time + finish_time)) + st = time.time() + optimizer.step() + opt_time = (time.time() - st) * 1000.0 - del out, y, loss + st = time.time() + loss = loss.numpy() + cat = out.argmax(axis=1).numpy() + accuracy = (cat == Y).mean() + finish_time = (time.time() - st) * 1000.0 + + # printing + t.set_description( + "loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" + % ( + loss, + accuracy, + fp_time, + bp_time, + opt_time, + finish_time, + fp_time + bp_time + opt_time + finish_time, + ) + ) + + del out, y, loss diff --git a/examples/train_resnet.py b/examples/train_resnet.py index 8feee8082..53d26d5ed 100755 --- a/examples/train_resnet.py +++ b/examples/train_resnet.py @@ -11,35 +11,38 @@ from extra.datasets import fetch_mnist class ComposeTransforms: - def __init__(self, trans): - self.trans = trans + def __init__(self, trans): + self.trans = trans + + def __call__(self, x): + for t in self.trans: + x = t(x) + return x - def __call__(self, x): - for t in self.trans: - x = t(x) - return x if __name__ == "__main__": - X_train, Y_train, X_test, Y_test = fetch_mnist() - X_train = X_train.reshape(-1, 28, 28).astype(np.uint8) - X_test = X_test.reshape(-1, 28, 28).astype(np.uint8) - classes = 10 + X_train, Y_train, X_test, Y_test = fetch_mnist() + X_train = X_train.reshape(-1, 28, 28).astype(np.uint8) + X_test = X_test.reshape(-1, 28, 28).astype(np.uint8) + classes = 10 - TRANSFER = getenv('TRANSFER') - model = ResNet(getenv('NUM', 18), num_classes=classes) - if TRANSFER: - model.load_from_pretrained() + TRANSFER = getenv("TRANSFER") + model = ResNet(getenv("NUM", 18), num_classes=classes) + if TRANSFER: + model.load_from_pretrained() - lr = 5e-3 - transform = ComposeTransforms([ - lambda x: [Image.fromarray(xx, mode='L').resize((64, 64)) for xx in x], - lambda x: np.stack([np.asarray(xx) for xx in x], 0), - lambda x: x / 255.0, - lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32), - ]) - for _ in range(5): - optimizer = optim.SGD(get_parameters(model), lr=lr, momentum=0.9) - train(model, X_train, Y_train, optimizer, 100, BS=32, transform=transform) - evaluate(model, X_test, Y_test, num_classes=classes, transform=transform) - lr /= 1.2 - print(f'reducing lr to {lr:.7f}') + lr = 5e-3 + transform = ComposeTransforms( + [ + lambda x: [Image.fromarray(xx, mode="L").resize((64, 64)) for xx in x], + lambda x: np.stack([np.asarray(xx) for xx in x], 0), + lambda x: x / 255.0, + lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32), + ] + ) + for _ in range(5): + optimizer = optim.SGD(get_parameters(model), lr=lr, momentum=0.9) + train(model, X_train, Y_train, optimizer, 100, BS=32, transform=transform) + evaluate(model, X_test, Y_test, num_classes=classes, transform=transform) + lr /= 1.2 + print(f"reducing lr to {lr:.7f}") diff --git a/examples/transformer.py b/examples/transformer.py index af20a4490..3edd7e713 100755 --- a/examples/transformer.py +++ b/examples/transformer.py @@ -7,36 +7,49 @@ from tinygrad.nn.optim import Adam from extra.training import train, evaluate from extra.models.transformer import Transformer + # dataset idea from https://github.com/karpathy/minGPT/blob/master/projects/adder/adder.py def make_dataset(): - ds = [] - for i in range(100): - for j in range(100): - s = i+j - ds.append([i//10, i%10, j//10, j%10, s//100, (s//10)%10, s%10]) - random.shuffle(ds) - ds = np.array(ds).astype(np.float32) - ds_X = ds[:, 0:6] - ds_Y = np.copy(ds[:, 1:]) - ds_X_train, ds_X_test = ds_X[0:8000], ds_X[8000:] - ds_Y_train, ds_Y_test = ds_Y[0:8000], ds_Y[8000:] - return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test + ds = [] + for i in range(100): + for j in range(100): + s = i + j + ds.append( + [i // 10, i % 10, j // 10, j % 10, s // 100, (s // 10) % 10, s % 10] + ) + random.shuffle(ds) + ds = np.array(ds).astype(np.float32) + ds_X = ds[:, 0:6] + ds_Y = np.copy(ds[:, 1:]) + ds_X_train, ds_X_test = ds_X[0:8000], ds_X[8000:] + ds_Y_train, ds_Y_test = ds_Y[0:8000], ds_Y[8000:] + return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test + if __name__ == "__main__": - model = Transformer(10, 6, 2, 128, 4, 32) - X_train, Y_train, X_test, Y_test = make_dataset() - lr = 0.003 - for i in range(10): - optim = Adam(get_parameters(model), lr=lr) - train(model, X_train, Y_train, optim, 50, BS=64) - acc, Y_test_preds = evaluate(model, X_test, Y_test, num_classes=10, return_predict=True) - lr /= 1.2 - print(f'reducing lr to {lr:.4f}') - if acc > 0.998: - wrong=0 - for k in range(len(Y_test_preds)): - if (Y_test_preds[k] != Y_test[k]).any(): - wrong+=1 - a,b,c,x = X_test[k,:2], X_test[k,2:4], Y_test[k,-3:], Y_test_preds[k,-3:] - print(f'{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})') - print(f'Wrong predictions: {wrong}, acc = {acc:.4f}') + model = Transformer(10, 6, 2, 128, 4, 32) + X_train, Y_train, X_test, Y_test = make_dataset() + lr = 0.003 + for i in range(10): + optim = Adam(get_parameters(model), lr=lr) + train(model, X_train, Y_train, optim, 50, BS=64) + acc, Y_test_preds = evaluate( + model, X_test, Y_test, num_classes=10, return_predict=True + ) + lr /= 1.2 + print(f"reducing lr to {lr:.4f}") + if acc > 0.998: + wrong = 0 + for k in range(len(Y_test_preds)): + if (Y_test_preds[k] != Y_test[k]).any(): + wrong += 1 + a, b, c, x = ( + X_test[k, :2], + X_test[k, 2:4], + Y_test[k, -3:], + Y_test_preds[k, -3:], + ) + print( + f"{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})" + ) + print(f"Wrong predictions: {wrong}, acc = {acc:.4f}") diff --git a/examples/vgg7.py b/examples/vgg7.py index a4a5835e5..7e8d31fc6 100644 --- a/examples/vgg7.py +++ b/examples/vgg7.py @@ -12,251 +12,276 @@ from examples.vgg7_helpers.waifu2x import image_load, image_save, Vgg7 # amount of context erased by model CONTEXT = 7 + def get_sample_count(samples_dir): - try: - samples_dir_count_file = open(samples_dir + "/sample_count.txt", "r") - v = samples_dir_count_file.readline() - samples_dir_count_file.close() - return int(v) - except: - return 0 + try: + samples_dir_count_file = open(samples_dir + "/sample_count.txt", "r") + v = samples_dir_count_file.readline() + samples_dir_count_file.close() + return int(v) + except: + return 0 + def set_sample_count(samples_dir, sc): - with open(samples_dir + "/sample_count.txt", "w") as file: - file.write(str(sc) + "\n") + with open(samples_dir + "/sample_count.txt", "w") as file: + file.write(str(sc) + "\n") + if len(sys.argv) < 2: - print("python3 -m examples.vgg7 import MODELJSON MODEL") - print(" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json") - print(" into a safetensors file") - print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)") - print(" *this format is used by most other commands in this program*") - print("python3 -m examples.vgg7 import_kinne MODEL_KINNE MODEL_SAFETENSORS") - print(" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors") - print("python3 -m examples.vgg7 execute MODEL IMG_IN IMG_OUT") - print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it") - print(" output image has 7 pixels removed on all edges") - print(" do not run on large images, will have *hilarious* RAM use") - print("python3 -m examples.vgg7 execute_full MODEL IMG_IN IMG_OUT") - print(" does the 'whole thing' (padding, tiling)") - print(" safe for large images, etc.") - print("python3 -m examples.vgg7 new MODEL") - print(" creates a new model (experimental)") - print("python3 -m examples.vgg7 train MODEL SAMPLES_DIR ROUNDS ROUNDS_SAVE") - print(" trains a model (experimental)") - print(" (how experimental? well, every time I tried it, it flooded w/ NaNs)") - print(" note: ROUNDS < 0 means 'forever'. ROUNDS_SAVE <= 0 is not a good idea.") - print(" expects roughly execute's input as SAMPLES_DIR/IDXa.png") - print(" expects roughly execute's output as SAMPLES_DIR/IDXb.png") - print(" (i.e. my_samples/0a.png is the first pre-nearest-scaled image,") - print(" my_samples/0b.png is the first original image)") - print(" in addition, SAMPLES_DIR/samples_count.txt indicates sample count") - print(" won't pad or tile, so keep image sizes sane") - print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE") - print(" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training") - print(" maintains/creates samples_count.txt automatically") - print(" unlike training, IMG_A must be exactly half the size of IMG_B") - sys.exit(1) + print("python3 -m examples.vgg7 import MODELJSON MODEL") + print( + " imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json" + ) + print(" into a safetensors file") + print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)") + print(" *this format is used by most other commands in this program*") + print("python3 -m examples.vgg7 import_kinne MODEL_KINNE MODEL_SAFETENSORS") + print( + " imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors" + ) + print("python3 -m examples.vgg7 execute MODEL IMG_IN IMG_OUT") + print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it") + print(" output image has 7 pixels removed on all edges") + print(" do not run on large images, will have *hilarious* RAM use") + print("python3 -m examples.vgg7 execute_full MODEL IMG_IN IMG_OUT") + print(" does the 'whole thing' (padding, tiling)") + print(" safe for large images, etc.") + print("python3 -m examples.vgg7 new MODEL") + print(" creates a new model (experimental)") + print("python3 -m examples.vgg7 train MODEL SAMPLES_DIR ROUNDS ROUNDS_SAVE") + print(" trains a model (experimental)") + print(" (how experimental? well, every time I tried it, it flooded w/ NaNs)") + print(" note: ROUNDS < 0 means 'forever'. ROUNDS_SAVE <= 0 is not a good idea.") + print(" expects roughly execute's input as SAMPLES_DIR/IDXa.png") + print(" expects roughly execute's output as SAMPLES_DIR/IDXb.png") + print(" (i.e. my_samples/0a.png is the first pre-nearest-scaled image,") + print(" my_samples/0b.png is the first original image)") + print(" in addition, SAMPLES_DIR/samples_count.txt indicates sample count") + print(" won't pad or tile, so keep image sizes sane") + print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE") + print( + " creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training" + ) + print(" maintains/creates samples_count.txt automatically") + print(" unlike training, IMG_A must be exactly half the size of IMG_B") + sys.exit(1) cmd = sys.argv[1] vgg7 = Vgg7() + def nansbane(p): - if numpy.isnan(numpy.min(p.numpy())): - raise Exception("A NaN in the model has been detected. This model will not be interacted with to prevent further damage.") + if numpy.isnan(numpy.min(p.numpy())): + raise Exception( + "A NaN in the model has been detected. This model will not be interacted with to prevent further damage." + ) + def load_and_save(path, save): - if save: - for v in vgg7.get_parameters(): - nansbane(v) - st = get_state_dict(vgg7) - safe_save(st, path) - else: - st = safe_load(path) - load_state_dict(vgg7, st) - for v in vgg7.get_parameters(): - nansbane(v) + if save: + for v in vgg7.get_parameters(): + nansbane(v) + st = get_state_dict(vgg7) + safe_save(st, path) + else: + st = safe_load(path) + load_state_dict(vgg7, st) + for v in vgg7.get_parameters(): + nansbane(v) + if cmd == "import": - src = sys.argv[2] - model = sys.argv[3] + src = sys.argv[2] + model = sys.argv[3] - vgg7.load_waifu2x_json(json.load(open(src, "rb"))) + vgg7.load_waifu2x_json(json.load(open(src, "rb"))) - load_and_save(model, True) -elif cmd == "import_kinne": - # tinygrad wasn't doing safetensors when this example was written - # it's possible someone might have a model around using the resulting interim format - src = sys.argv[2] - model = sys.argv[3] - - index = 0 - for t in vgg7.get_parameters(): - fn = src + "/snoop_bin_" + str(index) + ".bin" - t.assign(Tensor(numpy.fromfile(fn, " numpy.ndarray: - """ - Loads an image in the shape expected by other functions in this module. - Doesn't Tensor it, in case you need to do further work with it. - """ - # file - na = numpy.array(Image.open(path)) - if na.shape[2] == 4: - # RGBA -> RGB (covers opaque images with alpha channels) - na = na[:,:,0:3] - # fix shape - na = numpy.moveaxis(na, [2,0,1], [0,1,2]) - # shape is now (3,h,w), add 1 - na = na.reshape(1,3,na.shape[1],na.shape[2]) - # change type - na = na.astype("float32") / 255.0 - return na + """ + Loads an image in the shape expected by other functions in this module. + Doesn't Tensor it, in case you need to do further work with it. + """ + # file + na = numpy.array(Image.open(path)) + if na.shape[2] == 4: + # RGBA -> RGB (covers opaque images with alpha channels) + na = na[:, :, 0:3] + # fix shape + na = numpy.moveaxis(na, [2, 0, 1], [0, 1, 2]) + # shape is now (3,h,w), add 1 + na = na.reshape(1, 3, na.shape[1], na.shape[2]) + # change type + na = na.astype("float32") / 255.0 + return na + def image_save(path, na: numpy.ndarray): - """ - Saves an image of the shape expected by other functions in this module. - However, note this expects a numpy array. - """ - # change type - na = numpy.fmax(numpy.fmin(na * 255.0, 255), 0).astype("uint8") - # shape is now (1,3,h,w), remove 1 - na = na.reshape(3,na.shape[2],na.shape[3]) - # fix shape - na = numpy.moveaxis(na, [0,1,2], [2,0,1]) - # shape is now (h,w,3) - # file - Image.fromarray(na).save(path) + """ + Saves an image of the shape expected by other functions in this module. + However, note this expects a numpy array. + """ + # change type + na = numpy.fmax(numpy.fmin(na * 255.0, 255), 0).astype("uint8") + # shape is now (1,3,h,w), remove 1 + na = na.reshape(3, na.shape[2], na.shape[3]) + # fix shape + na = numpy.moveaxis(na, [0, 1, 2], [2, 0, 1]) + # shape is now (h,w,3) + # file + Image.fromarray(na).save(path) + # The Model + class Conv3x3Biased: - """ - A 3x3 convolution layer with some utility functions. - """ - def __init__(self, inC, outC, last = False): - # The properties must be named as "W" and "b". - # This is in an attempt to try and be roughly compatible with https://github.com/FHPythonUtils/Waifu2x - # though this cannot necessarily account for transposition and other such things. + """ + A 3x3 convolution layer with some utility functions. + """ - # Massively overstate the weights to get them to be focused on, - # since otherwise the biases overrule everything - self.W = Tensor.uniform(outC, inC, 3, 3) * 16.0 - # Layout-wise, blatant cheat, but serious_mnist does it. I'd guess channels either have to have a size of 1 or whatever the target is? - # Values-wise, entirely different blatant cheat. - # In most cases, use uniform bias, but tiny. - # For the last layer, use just 0.5, constant. - if last: - self.b = Tensor.zeros(1, outC, 1, 1) + 0.5 - else: - self.b = Tensor.uniform(1, outC, 1, 1) + def __init__(self, inC, outC, last=False): + # The properties must be named as "W" and "b". + # This is in an attempt to try and be roughly compatible with https://github.com/FHPythonUtils/Waifu2x + # though this cannot necessarily account for transposition and other such things. - def forward(self, x): - # You might be thinking, "but what about padding?" - # Answer: Tiling is used to stitch everything back together, though you could pad the image before providing it. - return x.conv2d(self.W).add(self.b) + # Massively overstate the weights to get them to be focused on, + # since otherwise the biases overrule everything + self.W = Tensor.uniform(outC, inC, 3, 3) * 16.0 + # Layout-wise, blatant cheat, but serious_mnist does it. I'd guess channels either have to have a size of 1 or whatever the target is? + # Values-wise, entirely different blatant cheat. + # In most cases, use uniform bias, but tiny. + # For the last layer, use just 0.5, constant. + if last: + self.b = Tensor.zeros(1, outC, 1, 1) + 0.5 + else: + self.b = Tensor.uniform(1, outC, 1, 1) - def get_parameters(self) -> list: - return [self.W, self.b] + def forward(self, x): + # You might be thinking, "but what about padding?" + # Answer: Tiling is used to stitch everything back together, though you could pad the image before providing it. + return x.conv2d(self.W).add(self.b) + + def get_parameters(self) -> list: + return [self.W, self.b] + + def load_waifu2x_json(self, layer: dict): + # Weights in this file are outChannel,inChannel,X,Y. + # Not outChannel,inChannel,Y,X. + # Therefore, transpose it before assignment. + # I have long since forgotten how I worked this out. + self.W.assign( + Tensor(layer["weight"]).reshape(shape=self.W.shape).transpose(2, 3) + ) + self.b.assign(Tensor(layer["bias"]).reshape(shape=self.b.shape)) - def load_waifu2x_json(self, layer: dict): - # Weights in this file are outChannel,inChannel,X,Y. - # Not outChannel,inChannel,Y,X. - # Therefore, transpose it before assignment. - # I have long since forgotten how I worked this out. - self.W.assign(Tensor(layer["weight"]).reshape(shape=self.W.shape).transpose(2, 3)) - self.b.assign(Tensor(layer["bias"]).reshape(shape=self.b.shape)) class Vgg7: - """ - The 'vgg7' waifu2x network. - Lower quality and slower than even upconv7 (nevermind cunet), but is very easy to implement and test. - """ - - def __init__(self): - self.conv1 = Conv3x3Biased(3, 32) - self.conv2 = Conv3x3Biased(32, 32) - self.conv3 = Conv3x3Biased(32, 64) - self.conv4 = Conv3x3Biased(64, 64) - self.conv5 = Conv3x3Biased(64, 128) - self.conv6 = Conv3x3Biased(128, 128) - self.conv7 = Conv3x3Biased(128, 3, True) - - def forward(self, x): """ - Forward pass: Actually runs the network. - Input format: (1, 3, Y, X) - Output format: (1, 3, Y - 14, X - 14) - (the - 14 represents the 7-pixel context border that is lost) + The 'vgg7' waifu2x network. + Lower quality and slower than even upconv7 (nevermind cunet), but is very easy to implement and test. """ - x = self.conv1.forward(x).leakyrelu(0.1) - x = self.conv2.forward(x).leakyrelu(0.1) - x = self.conv3.forward(x).leakyrelu(0.1) - x = self.conv4.forward(x).leakyrelu(0.1) - x = self.conv5.forward(x).leakyrelu(0.1) - x = self.conv6.forward(x).leakyrelu(0.1) - x = self.conv7.forward(x) - return x - def get_parameters(self) -> list: - return self.conv1.get_parameters() + self.conv2.get_parameters() + self.conv3.get_parameters() + self.conv4.get_parameters() + self.conv5.get_parameters() + self.conv6.get_parameters() + self.conv7.get_parameters() + def __init__(self): + self.conv1 = Conv3x3Biased(3, 32) + self.conv2 = Conv3x3Biased(32, 32) + self.conv3 = Conv3x3Biased(32, 64) + self.conv4 = Conv3x3Biased(64, 64) + self.conv5 = Conv3x3Biased(64, 128) + self.conv6 = Conv3x3Biased(128, 128) + self.conv7 = Conv3x3Biased(128, 3, True) - def load_from_pretrained(self, intent = "art", subtype = "scale2.0x"): - """ - Downloads a nagadomi/waifu2x JSON weight file and loads it. - """ - import json - data = json.loads(fetch("https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/" + intent + "/" + subtype + "_model.json").read_bytes()) - self.load_waifu2x_json(data) + def forward(self, x): + """ + Forward pass: Actually runs the network. + Input format: (1, 3, Y, X) + Output format: (1, 3, Y - 14, X - 14) + (the - 14 represents the 7-pixel context border that is lost) + """ + x = self.conv1.forward(x).leakyrelu(0.1) + x = self.conv2.forward(x).leakyrelu(0.1) + x = self.conv3.forward(x).leakyrelu(0.1) + x = self.conv4.forward(x).leakyrelu(0.1) + x = self.conv5.forward(x).leakyrelu(0.1) + x = self.conv6.forward(x).leakyrelu(0.1) + x = self.conv7.forward(x) + return x - def load_waifu2x_json(self, data: list): - """ - Loads weights from one of the waifu2x JSON files, i.e. waifu2x/models/vgg_7/art/noise0_model.json - data (passed in) is assumed to be the output of json.load or some similar on such a file - """ - self.conv1.load_waifu2x_json(data[0]) - self.conv2.load_waifu2x_json(data[1]) - self.conv3.load_waifu2x_json(data[2]) - self.conv4.load_waifu2x_json(data[3]) - self.conv5.load_waifu2x_json(data[4]) - self.conv6.load_waifu2x_json(data[5]) - self.conv7.load_waifu2x_json(data[6]) + def get_parameters(self) -> list: + return ( + self.conv1.get_parameters() + + self.conv2.get_parameters() + + self.conv3.get_parameters() + + self.conv4.get_parameters() + + self.conv5.get_parameters() + + self.conv6.get_parameters() + + self.conv7.get_parameters() + ) - def forward_tiled(self, image: numpy.ndarray, tile_size: int) -> numpy.ndarray: - """ - Given an ndarray image as loaded by image_load (NOT a tensor), scales it, pads it, splits it up, forwards the pieces, and reconstitutes it. - Note that you really shouldn't try to run anything not (1, 3, *, *) through this. - """ - # Constant that only really gets repeated a ton here. - context = 7 - context2 = context + context + def load_from_pretrained(self, intent="art", subtype="scale2.0x"): + """ + Downloads a nagadomi/waifu2x JSON weight file and loads it. + """ + import json - # Notably, numpy is used here because it makes this fine manipulation a lot simpler. - # Scaling first - repeat on axis 2 and axis 3 (Y & X) - image = image.repeat(2, 2).repeat(2, 3) + data = json.loads( + fetch( + "https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/" + + intent + + "/" + + subtype + + "_model.json" + ).read_bytes() + ) + self.load_waifu2x_json(data) - # Resulting image buffer. This is made before the input is padded, - # since the input has the padded shape right now. - image_out = numpy.zeros(image.shape) + def load_waifu2x_json(self, data: list): + """ + Loads weights from one of the waifu2x JSON files, i.e. waifu2x/models/vgg_7/art/noise0_model.json + data (passed in) is assumed to be the output of json.load or some similar on such a file + """ + self.conv1.load_waifu2x_json(data[0]) + self.conv2.load_waifu2x_json(data[1]) + self.conv3.load_waifu2x_json(data[2]) + self.conv4.load_waifu2x_json(data[3]) + self.conv5.load_waifu2x_json(data[4]) + self.conv6.load_waifu2x_json(data[5]) + self.conv7.load_waifu2x_json(data[6]) - # Padding next. Note that this padding is done on the whole image. - # Padding the tiles would lose critical context, cause seams, etc. - image = numpy.pad(image, [[0, 0], [0, 0], [context, context], [context, context]], mode = "edge") + def forward_tiled(self, image: numpy.ndarray, tile_size: int) -> numpy.ndarray: + """ + Given an ndarray image as loaded by image_load (NOT a tensor), scales it, pads it, splits it up, forwards the pieces, and reconstitutes it. + Note that you really shouldn't try to run anything not (1, 3, *, *) through this. + """ + # Constant that only really gets repeated a ton here. + context = 7 + context2 = context + context - # Now for tiling. - # The output tile size is the usable output from an input tile (tile_size). - # As such, the tiles overlap. - out_tile_size = tile_size - context2 - for out_y in range(0, image_out.shape[2], out_tile_size): - for out_x in range(0, image_out.shape[3], out_tile_size): - # Input is sourced from the same coordinates, but some stuff ought to be - # noted here for future reference: - # + out_x/y's equivalent position w/ the padding is out_x + context. - # + The output, however, is without context. Input needs context. - # + Therefore, the input rectangle is expanded on all sides by context. - # + Therefore, the input position has the context subtracted again. - # + Therefore: - in_y = out_y - in_x = out_x - # not shown: in_w/in_h = tile_size (as opposed to out_tile_size) - # Extract tile. - # Note that numpy will auto-crop this at the bottom-right. - # This will never be a problem, as tiles are specifically chosen within the padded section. - tile = image[:, :, in_y:in_y + tile_size, in_x:in_x + tile_size] - # Extracted tile dimensions -> output dimensions - # This is important because of said cropping, otherwise it'd be interior tile size. - out_h = tile.shape[2] - context2 - out_w = tile.shape[3] - context2 - # Process tile. - tile_t = Tensor(tile) - tile_fwd_t = self.forward(tile_t) - # Replace tile. - image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.numpy() + # Notably, numpy is used here because it makes this fine manipulation a lot simpler. + # Scaling first - repeat on axis 2 and axis 3 (Y & X) + image = image.repeat(2, 2).repeat(2, 3) - return image_out + # Resulting image buffer. This is made before the input is padded, + # since the input has the padded shape right now. + image_out = numpy.zeros(image.shape) + # Padding next. Note that this padding is done on the whole image. + # Padding the tiles would lose critical context, cause seams, etc. + image = numpy.pad( + image, [[0, 0], [0, 0], [context, context], [context, context]], mode="edge" + ) + + # Now for tiling. + # The output tile size is the usable output from an input tile (tile_size). + # As such, the tiles overlap. + out_tile_size = tile_size - context2 + for out_y in range(0, image_out.shape[2], out_tile_size): + for out_x in range(0, image_out.shape[3], out_tile_size): + # Input is sourced from the same coordinates, but some stuff ought to be + # noted here for future reference: + # + out_x/y's equivalent position w/ the padding is out_x + context. + # + The output, however, is without context. Input needs context. + # + Therefore, the input rectangle is expanded on all sides by context. + # + Therefore, the input position has the context subtracted again. + # + Therefore: + in_y = out_y + in_x = out_x + # not shown: in_w/in_h = tile_size (as opposed to out_tile_size) + # Extract tile. + # Note that numpy will auto-crop this at the bottom-right. + # This will never be a problem, as tiles are specifically chosen within the padded section. + tile = image[:, :, in_y : in_y + tile_size, in_x : in_x + tile_size] + # Extracted tile dimensions -> output dimensions + # This is important because of said cropping, otherwise it'd be interior tile size. + out_h = tile.shape[2] - context2 + out_w = tile.shape[3] - context2 + # Process tile. + tile_t = Tensor(tile) + tile_fwd_t = self.forward(tile_t) + # Replace tile. + image_out[ + :, :, out_y : out_y + out_h, out_x : out_x + out_w + ] = tile_fwd_t.numpy() + + return image_out diff --git a/examples/vit.py b/examples/vit.py index bf9a8f5d3..1e146af15 100644 --- a/examples/vit.py +++ b/examples/vit.py @@ -4,6 +4,7 @@ from PIL import Image from tinygrad.tensor import Tensor from tinygrad.helpers import getenv, fetch from extra.models.vit import ViT + """ fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz" import tensorflow as tf @@ -15,27 +16,33 @@ with tf.io.gfile.GFile(fn, "rb") as f: Tensor.training = False if getenv("LARGE", 0) == 1: - m = ViT(embed_dim=768, num_heads=12) + m = ViT(embed_dim=768, num_heads=12) else: - # tiny - m = ViT(embed_dim=192, num_heads=3) + # tiny + m = ViT(embed_dim=192, num_heads=3) m.load_from_pretrained() # category labels -lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text()) +lbls = ast.literal_eval( + fetch( + "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt" + ).read_text() +) -#url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg" +# url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg" url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0" # junk img = Image.open(fetch(url)) aspect_ratio = img.size[0] / img.size[1] -img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0)))) +img = img.resize( + (int(224 * max(aspect_ratio, 1.0)), int(224 * max(1.0 / aspect_ratio, 1.0))) +) img = np.array(img) -y0,x0=(np.asarray(img.shape)[:2]-224)//2 -img = img[y0:y0+224, x0:x0+224] -img = np.moveaxis(img, [2,0,1], [0,1,2]) -img = img.astype(np.float32)[:3].reshape(1,3,224,224) +y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2 +img = img[y0 : y0 + 224, x0 : x0 + 224] +img = np.moveaxis(img, [2, 0, 1], [0, 1, 2]) +img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224) img /= 255.0 img -= 0.5 img /= 0.5 diff --git a/examples/vits.py b/examples/vits.py index 3ae13ddc5..95adab95e 100644 --- a/examples/vits.py +++ b/examples/vits.py @@ -14,632 +14,1776 @@ from unidecode import unidecode LRELU_SLOPE = 0.1 + class Synthesizer: - def __init__(self, n_vocab, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, n_speakers=0, gin_channels=0, use_sdp=True, emotion_embedding=False, **kwargs): - self.n_vocab, self.spec_channels, self.inter_channels, self.hidden_channels, self.filter_channels, self.n_heads, self.n_layers, self.kernel_size, self.p_dropout, self.resblock, self.resblock_kernel_sizes, self.resblock_dilation_sizes, self.upsample_rates, self.upsample_initial_channel, self.upsample_kernel_sizes, self.segment_size, self.n_speakers, self.gin_channels, self.use_sdp = n_vocab, spec_channels, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, segment_size, n_speakers, gin_channels, use_sdp - self.enc_p = TextEncoder(n_vocab, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, emotion_embedding) - self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) - self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) - self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) if use_sdp else DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) - if n_speakers > 1: self.emb_g = nn.Embedding(n_speakers, gin_channels) - def infer(self, x, x_lengths, sid=None, noise_scale=1.0, length_scale=1, noise_scale_w=1., max_len=None, emotion_embedding=None, max_y_length_estimate_scale=None, batch_size=500): - x, m_p, logs_p, x_mask = self.enc_p.forward(x.realize(), x_lengths.realize(), emotion_embedding.realize() if emotion_embedding is not None else emotion_embedding) - g = self.emb_g(sid.reshape(1, 1)).squeeze(1).unsqueeze(-1) if self.n_speakers > 0 else None - logw = self.dp.forward(x, x_mask.realize(), g=g.realize(), reverse=self.use_sdp, noise_scale=noise_scale_w if self.use_sdp else 1.0) - w_ceil = Tensor.ceil(logw.exp() * x_mask * length_scale) - y_lengths = Tensor.maximum(w_ceil.sum([1, 2]), 1).cast(dtypes.int64) - return self.generate(g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, batch_size) - def generate(self, g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, batch_size): - max_y_length = y_lengths.max().numpy() if max_y_length_estimate_scale is None else max(15, x.shape[-1]) * max_y_length_estimate_scale - y_mask = sequence_mask(y_lengths, max_y_length).unsqueeze(1).cast(x_mask.dtype) - attn_mask = x_mask.unsqueeze(2) * y_mask.unsqueeze(-1) - attn = generate_path(w_ceil, attn_mask) - m_p_2 = attn.squeeze(1).matmul(m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] - logs_p_2 = attn.squeeze(1).matmul(logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] - z_p = m_p_2 + Tensor.randn(*m_p_2.shape, dtype=m_p_2.dtype) * logs_p_2.exp() * noise_scale - # Pad flow forward inputs to enable JIT - row_len = y_mask.shape[2] - assert batch_size > row_len, "batch size is too small" - y_mask = y_mask.pad(((0, 0), (0, 0), (0, batch_size - row_len)), 0).cast(z_p.dtype) - # New y_mask tensor to remove sts mask - y_mask = Tensor(y_mask.numpy(), device=y_mask.device, dtype=y_mask.dtype, requires_grad=y_mask.requires_grad) - z_p = z_p.squeeze(0).pad(((0, 0), (0, batch_size - z_p.shape[2])), 1).unsqueeze(0) - z = self.flow.forward(z_p.realize(), y_mask.realize(), g=g.realize(), reverse=True) - result_length = reduce(lambda x, y: x * y, self.dec.upsample_rates, row_len) - o = self.dec.forward((z * y_mask)[:, :, :max_len], g=g)[:, :, :result_length] - if max_y_length_estimate_scale is not None: - length_scaler = o.shape[-1] / max_y_length - o.realize() - real_max_y_length = y_lengths.max().numpy() - if real_max_y_length > max_y_length: - logging.warning(f"Underestimated max length by {(((real_max_y_length / max_y_length) * 100) - 100):.2f}%, recomputing inference without estimate...") - return self.generate(g, logs_p, m_p, max_len, None, noise_scale, w_ceil, x, x_mask, y_lengths) - if real_max_y_length < max_y_length: - overestimation = ((max_y_length / real_max_y_length) * 100) - 100 - logging.info(f"Overestimated max length by {overestimation:.2f}%") - if overestimation > 10: logging.warning("Warning: max length overestimated by more than 10%") - o = o[:, :, :(real_max_y_length * length_scaler).astype(np.int32)] - return o + def __init__( + self, + n_vocab, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + emotion_embedding=False, + **kwargs, + ): + ( + self.n_vocab, + self.spec_channels, + self.inter_channels, + self.hidden_channels, + self.filter_channels, + self.n_heads, + self.n_layers, + self.kernel_size, + self.p_dropout, + self.resblock, + self.resblock_kernel_sizes, + self.resblock_dilation_sizes, + self.upsample_rates, + self.upsample_initial_channel, + self.upsample_kernel_sizes, + self.segment_size, + self.n_speakers, + self.gin_channels, + self.use_sdp, + ) = ( + n_vocab, + spec_channels, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + segment_size, + n_speakers, + gin_channels, + use_sdp, + ) + self.enc_p = TextEncoder( + n_vocab, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + emotion_embedding, + ) + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels + ) + self.dp = ( + StochasticDurationPredictor( + hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels + ) + if use_sdp + else DurationPredictor( + hidden_channels, 256, 3, 0.5, gin_channels=gin_channels + ) + ) + if n_speakers > 1: + self.emb_g = nn.Embedding(n_speakers, gin_channels) + + def infer( + self, + x, + x_lengths, + sid=None, + noise_scale=1.0, + length_scale=1, + noise_scale_w=1.0, + max_len=None, + emotion_embedding=None, + max_y_length_estimate_scale=None, + batch_size=500, + ): + x, m_p, logs_p, x_mask = self.enc_p.forward( + x.realize(), + x_lengths.realize(), + emotion_embedding.realize() + if emotion_embedding is not None + else emotion_embedding, + ) + g = ( + self.emb_g(sid.reshape(1, 1)).squeeze(1).unsqueeze(-1) + if self.n_speakers > 0 + else None + ) + logw = self.dp.forward( + x, + x_mask.realize(), + g=g.realize(), + reverse=self.use_sdp, + noise_scale=noise_scale_w if self.use_sdp else 1.0, + ) + w_ceil = Tensor.ceil(logw.exp() * x_mask * length_scale) + y_lengths = Tensor.maximum(w_ceil.sum([1, 2]), 1).cast(dtypes.int64) + return self.generate( + g, + logs_p, + m_p, + max_len, + max_y_length_estimate_scale, + noise_scale, + w_ceil, + x, + x_mask, + y_lengths, + batch_size, + ) + + def generate( + self, + g, + logs_p, + m_p, + max_len, + max_y_length_estimate_scale, + noise_scale, + w_ceil, + x, + x_mask, + y_lengths, + batch_size, + ): + max_y_length = ( + y_lengths.max().numpy() + if max_y_length_estimate_scale is None + else max(15, x.shape[-1]) * max_y_length_estimate_scale + ) + y_mask = sequence_mask(y_lengths, max_y_length).unsqueeze(1).cast(x_mask.dtype) + attn_mask = x_mask.unsqueeze(2) * y_mask.unsqueeze(-1) + attn = generate_path(w_ceil, attn_mask) + m_p_2 = ( + attn.squeeze(1).matmul(m_p.transpose(1, 2)).transpose(1, 2) + ) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p_2 = ( + attn.squeeze(1).matmul(logs_p.transpose(1, 2)).transpose(1, 2) + ) # [b, t', t], [b, t, d] -> [b, d, t'] + z_p = ( + m_p_2 + + Tensor.randn(*m_p_2.shape, dtype=m_p_2.dtype) + * logs_p_2.exp() + * noise_scale + ) + # Pad flow forward inputs to enable JIT + row_len = y_mask.shape[2] + assert batch_size > row_len, "batch size is too small" + y_mask = y_mask.pad(((0, 0), (0, 0), (0, batch_size - row_len)), 0).cast( + z_p.dtype + ) + # New y_mask tensor to remove sts mask + y_mask = Tensor( + y_mask.numpy(), + device=y_mask.device, + dtype=y_mask.dtype, + requires_grad=y_mask.requires_grad, + ) + z_p = ( + z_p.squeeze(0).pad(((0, 0), (0, batch_size - z_p.shape[2])), 1).unsqueeze(0) + ) + z = self.flow.forward( + z_p.realize(), y_mask.realize(), g=g.realize(), reverse=True + ) + result_length = reduce(lambda x, y: x * y, self.dec.upsample_rates, row_len) + o = self.dec.forward((z * y_mask)[:, :, :max_len], g=g)[:, :, :result_length] + if max_y_length_estimate_scale is not None: + length_scaler = o.shape[-1] / max_y_length + o.realize() + real_max_y_length = y_lengths.max().numpy() + if real_max_y_length > max_y_length: + logging.warning( + f"Underestimated max length by {(((real_max_y_length / max_y_length) * 100) - 100):.2f}%, recomputing inference without estimate..." + ) + return self.generate( + g, + logs_p, + m_p, + max_len, + None, + noise_scale, + w_ceil, + x, + x_mask, + y_lengths, + ) + if real_max_y_length < max_y_length: + overestimation = ((max_y_length / real_max_y_length) * 100) - 100 + logging.info(f"Overestimated max length by {overestimation:.2f}%") + if overestimation > 10: + logging.warning( + "Warning: max length overestimated by more than 10%" + ) + o = o[:, :, : (real_max_y_length * length_scaler).astype(np.int32)] + return o + class StochasticDurationPredictor: - def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): - filter_channels = in_channels # it needs to be removed from future version. - self.in_channels, self.filter_channels, self.kernel_size, self.p_dropout, self.n_flows, self.gin_channels = in_channels, filter_channels, kernel_size, p_dropout, n_flows, gin_channels - self.log_flow, self.flows = Log(), [ElementwiseAffine(2)] - for _ in range(n_flows): - self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3)) - self.flows.append(Flip()) - self.post_pre, self.post_proj = nn.Conv1d(1, filter_channels, 1), nn.Conv1d(filter_channels, filter_channels, 1) - self.post_convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) - self.post_flows = [ElementwiseAffine(2)] - for _ in range(4): - self.post_flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3)) - self.post_flows.append(Flip()) - self.pre, self.proj = nn.Conv1d(in_channels, filter_channels, 1), nn.Conv1d(filter_channels, filter_channels, 1) - self.convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) - if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, filter_channels, 1) - @TinyJit - def forward(self, x: Tensor, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): - x = self.pre(x.detach()) - if g is not None: x = x + self.cond(g.detach()) - x = self.convs.forward(x, x_mask) - x = self.proj(x) * x_mask - if not reverse: - flows = self.flows - assert w is not None - log_det_tot_q = 0 - h_w = self.post_proj(self.post_convs.forward(self.post_pre(w), x_mask)) * x_mask - e_q = Tensor.randn(w.size(0), 2, w.size(2), dtype=x.dtype).to(device=x.device) * x_mask - z_q = e_q - for flow in self.post_flows: - z_q, log_det_q = flow.forward(z_q, x_mask, g=(x + h_w)) - log_det_tot_q += log_det_q - z_u, z1 = z_q.split([1, 1], 1) - u = z_u.sigmoid() * x_mask - z0 = (w - u) * x_mask - log_det_tot_q += Tensor.sum((z_u.logsigmoid() + (-z_u).logsigmoid()) * x_mask, [1,2]) - log_q = Tensor.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - log_det_tot_q - log_det_tot = 0 - z0, log_det = self.log_flow.forward(z0, x_mask) - log_det_tot += log_det - z = z0.cat(z1, 1) - for flow in flows: - z, log_det = flow.forward(z, x_mask, g=x, reverse=reverse) - log_det_tot = log_det_tot + log_det - nll = Tensor.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - log_det_tot - return (nll + log_q).realize() # [b] - flows = list(reversed(self.flows)) - flows = flows[:-2] + [flows[-1]] # remove a useless vflow - z = Tensor.randn(x.shape[0], 2, x.shape[2], dtype=x.dtype).to(device=x.device) * noise_scale - for flow in flows: z = flow.forward(z, x_mask, g=x, reverse=reverse) - z0, z1 = split(z, [1, 1], 1) - return z0.realize() + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + ): + filter_channels = in_channels # it needs to be removed from future version. + ( + self.in_channels, + self.filter_channels, + self.kernel_size, + self.p_dropout, + self.n_flows, + self.gin_channels, + ) = ( + in_channels, + filter_channels, + kernel_size, + p_dropout, + n_flows, + gin_channels, + ) + self.log_flow, self.flows = Log(), [ElementwiseAffine(2)] + for _ in range(n_flows): + self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(Flip()) + self.post_pre, self.post_proj = nn.Conv1d(1, filter_channels, 1), nn.Conv1d( + filter_channels, filter_channels, 1 + ) + self.post_convs = DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) + self.post_flows = [ElementwiseAffine(2)] + for _ in range(4): + self.post_flows.append( + ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) + self.post_flows.append(Flip()) + self.pre, self.proj = nn.Conv1d(in_channels, filter_channels, 1), nn.Conv1d( + filter_channels, filter_channels, 1 + ) + self.convs = DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + @TinyJit + def forward( + self, x: Tensor, x_mask, w=None, g=None, reverse=False, noise_scale=1.0 + ): + x = self.pre(x.detach()) + if g is not None: + x = x + self.cond(g.detach()) + x = self.convs.forward(x, x_mask) + x = self.proj(x) * x_mask + if not reverse: + flows = self.flows + assert w is not None + log_det_tot_q = 0 + h_w = ( + self.post_proj(self.post_convs.forward(self.post_pre(w), x_mask)) + * x_mask + ) + e_q = ( + Tensor.randn(w.size(0), 2, w.size(2), dtype=x.dtype).to(device=x.device) + * x_mask + ) + z_q = e_q + for flow in self.post_flows: + z_q, log_det_q = flow.forward(z_q, x_mask, g=(x + h_w)) + log_det_tot_q += log_det_q + z_u, z1 = z_q.split([1, 1], 1) + u = z_u.sigmoid() * x_mask + z0 = (w - u) * x_mask + log_det_tot_q += Tensor.sum( + (z_u.logsigmoid() + (-z_u).logsigmoid()) * x_mask, [1, 2] + ) + log_q = ( + Tensor.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) + - log_det_tot_q + ) + log_det_tot = 0 + z0, log_det = self.log_flow.forward(z0, x_mask) + log_det_tot += log_det + z = z0.cat(z1, 1) + for flow in flows: + z, log_det = flow.forward(z, x_mask, g=x, reverse=reverse) + log_det_tot = log_det_tot + log_det + nll = ( + Tensor.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) + - log_det_tot + ) + return (nll + log_q).realize() # [b] + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = ( + Tensor.randn(x.shape[0], 2, x.shape[2], dtype=x.dtype).to(device=x.device) + * noise_scale + ) + for flow in flows: + z = flow.forward(z, x_mask, g=x, reverse=reverse) + z0, z1 = split(z, [1, 1], 1) + return z0.realize() + class DurationPredictor: - def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): - self.in_channels, self.filter_channels, self.kernel_size, self.p_dropout, self.gin_channels = in_channels, filter_channels, kernel_size, p_dropout, gin_channels - self.conv_1, self.norm_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2), LayerNorm(filter_channels) - self.conv_2, self.norm_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2), LayerNorm(filter_channels) - self.proj = nn.Conv1d(filter_channels, 1, 1) - if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, in_channels, 1) - def forward(self, x: Tensor, x_mask, g=None): - x = x.detach() - if g is not None: x = x + self.cond(g.detach()) - x = self.conv_1(x * x_mask).relu() - x = self.norm_1(x).dropout(self.p_dropout) - x = self.conv_2(x * x_mask).relu(x) - x = self.norm_2(x).dropout(self.p_dropout) - return self.proj(x * x_mask) * x_mask + def __init__( + self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 + ): + ( + self.in_channels, + self.filter_channels, + self.kernel_size, + self.p_dropout, + self.gin_channels, + ) = (in_channels, filter_channels, kernel_size, p_dropout, gin_channels) + self.conv_1, self.norm_1 = nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ), LayerNorm(filter_channels) + self.conv_2, self.norm_2 = nn.Conv1d( + filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ), LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x: Tensor, x_mask, g=None): + x = x.detach() + if g is not None: + x = x + self.cond(g.detach()) + x = self.conv_1(x * x_mask).relu() + x = self.norm_1(x).dropout(self.p_dropout) + x = self.conv_2(x * x_mask).relu(x) + x = self.norm_2(x).dropout(self.p_dropout) + return self.proj(x * x_mask) * x_mask + class TextEncoder: - def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, emotion_embedding): - self.n_vocab, self.out_channels, self.hidden_channels, self.filter_channels, self.n_heads, self.n_layers, self.kernel_size, self.p_dropout = n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout - if n_vocab!=0:self.emb = nn.Embedding(n_vocab, hidden_channels) - if emotion_embedding: self.emo_proj = nn.Linear(1024, hidden_channels) - self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - @TinyJit - def forward(self, x: Tensor, x_lengths: Tensor, emotion_embedding=None): - if self.n_vocab!=0: x = (self.emb(x) * math.sqrt(self.hidden_channels)) - if emotion_embedding: x = x + self.emo_proj(emotion_embedding).unsqueeze(1) - x = x.transpose(1, -1) # [b, t, h] -transpose-> [b, h, t] - x_mask = sequence_mask(x_lengths, x.shape[2]).unsqueeze(1).cast(x.dtype) - x = self.encoder.forward(x * x_mask, x_mask) - m, logs = split(self.proj(x) * x_mask, self.out_channels, dim=1) - return x.realize(), m.realize(), logs.realize(), x_mask.realize() + def __init__( + self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + emotion_embedding, + ): + ( + self.n_vocab, + self.out_channels, + self.hidden_channels, + self.filter_channels, + self.n_heads, + self.n_layers, + self.kernel_size, + self.p_dropout, + ) = ( + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + ) + if n_vocab != 0: + self.emb = nn.Embedding(n_vocab, hidden_channels) + if emotion_embedding: + self.emo_proj = nn.Linear(1024, hidden_channels) + self.encoder = Encoder( + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + @TinyJit + def forward(self, x: Tensor, x_lengths: Tensor, emotion_embedding=None): + if self.n_vocab != 0: + x = self.emb(x) * math.sqrt(self.hidden_channels) + if emotion_embedding: + x = x + self.emo_proj(emotion_embedding).unsqueeze(1) + x = x.transpose(1, -1) # [b, t, h] -transpose-> [b, h, t] + x_mask = sequence_mask(x_lengths, x.shape[2]).unsqueeze(1).cast(x.dtype) + x = self.encoder.forward(x * x_mask, x_mask) + m, logs = split(self.proj(x) * x_mask, self.out_channels, dim=1) + return x.realize(), m.realize(), logs.realize(), x_mask.realize() + class ResidualCouplingBlock: - def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): - self.channels, self.hidden_channels, self.kernel_size, self.dilation_rate, self.n_layers, self.n_flows, self.gin_channels = channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows, gin_channels - self.flows = [] - for _ in range(n_flows): - self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) - self.flows.append(Flip()) - @TinyJit - def forward(self, x, x_mask, g=None, reverse=False): - for flow in reversed(self.flows) if reverse else self.flows: x = flow.forward(x, x_mask, g=g, reverse=reverse) - return x.realize() + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + ): + ( + self.channels, + self.hidden_channels, + self.kernel_size, + self.dilation_rate, + self.n_layers, + self.n_flows, + self.gin_channels, + ) = ( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows, + gin_channels, + ) + self.flows = [] + for _ in range(n_flows): + self.flows.append( + ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(Flip()) + + @TinyJit + def forward(self, x, x_mask, g=None, reverse=False): + for flow in reversed(self.flows) if reverse else self.flows: + x = flow.forward(x, x_mask, g=g, reverse=reverse) + return x.realize() + class PosteriorEncoder: - def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0): - self.in_channels, self.out_channels, self.hidden_channels, self.kernel_size, self.dilation_rate, self.n_layers, self.gin_channels = in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels - self.pre, self.proj = nn.Conv1d(in_channels, hidden_channels, 1), nn.Conv1d(hidden_channels, out_channels * 2, 1) - self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) - def forward(self, x, x_lengths, g=None): - x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).cast(x.dtype) - stats = self.proj(self.enc.forward(self.pre(x) * x_mask, x_mask, g=g)) * x_mask - m, logs = stats.split(self.out_channels, dim=1) - z = (m + Tensor.randn(m.shape, m.dtype) * logs.exp()) * x_mask - return z, m, logs, x_mask + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + ( + self.in_channels, + self.out_channels, + self.hidden_channels, + self.kernel_size, + self.dilation_rate, + self.n_layers, + self.gin_channels, + ) = ( + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels, + ) + self.pre, self.proj = nn.Conv1d(in_channels, hidden_channels, 1), nn.Conv1d( + hidden_channels, out_channels * 2, 1 + ) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + + def forward(self, x, x_lengths, g=None): + x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).cast(x.dtype) + stats = self.proj(self.enc.forward(self.pre(x) * x_mask, x_mask, g=g)) * x_mask + m, logs = stats.split(self.out_channels, dim=1) + z = (m + Tensor.randn(m.shape, m.dtype) * logs.exp()) * x_mask + return z, m, logs, x_mask + class Generator: - def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): - self.num_kernels, self.num_upsamples = len(resblock_kernel_sizes), len(upsample_rates) - self.conv_pre = nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) - resblock = ResBlock1 if resblock == '1' else ResBlock2 - self.ups = [nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), k, u, padding=(k-u)//2) for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes))] - self.resblocks = [] - self.upsample_rates = upsample_rates - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock(ch, k, d)) - self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False) - if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - @TinyJit - def forward(self, x: Tensor, g=None): - x = self.conv_pre(x) - if g is not None: x = x + self.cond(g) - for i in range(self.num_upsamples): - x = self.ups[i](x.leakyrelu(LRELU_SLOPE)) - xs = sum(self.resblocks[i * self.num_kernels + j].forward(x) for j in range(self.num_kernels)) - x = (xs / self.num_kernels).realize() - res = self.conv_post(x.leakyrelu()).tanh().realize() - return res + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + self.num_kernels, self.num_upsamples = len(resblock_kernel_sizes), len( + upsample_rates + ) + self.conv_pre = nn.Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = ResBlock1 if resblock == "1" else ResBlock2 + self.ups = [ + nn.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)) + ] + self.resblocks = [] + self.upsample_rates = upsample_rates + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + @TinyJit + def forward(self, x: Tensor, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + for i in range(self.num_upsamples): + x = self.ups[i](x.leakyrelu(LRELU_SLOPE)) + xs = sum( + self.resblocks[i * self.num_kernels + j].forward(x) + for j in range(self.num_kernels) + ) + x = (xs / self.num_kernels).realize() + res = self.conv_post(x.leakyrelu()).tanh().realize() + return res + class LayerNorm(nn.LayerNorm): - def __init__(self, channels, eps=1e-5): super().__init__(channels, eps, elementwise_affine=True) - def forward(self, x: Tensor): return self.__call__(x.transpose(1, -1)).transpose(1, -1) + def __init__(self, channels, eps=1e-5): + super().__init__(channels, eps, elementwise_affine=True) + + def forward(self, x: Tensor): + return self.__call__(x.transpose(1, -1)).transpose(1, -1) + class WN: - def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): - assert (kernel_size % 2 == 1) - self.hidden_channels, self.kernel_size, self.dilation_rate, self.n_layers, self.gin_channels, self.p_dropout = hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels, p_dropout - self.in_layers, self.res_skip_layers = [], [] - if gin_channels != 0: self.cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) - for i in range(n_layers): - dilation = dilation_rate ** i - self.in_layers.append(nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=int((kernel_size * dilation - dilation) / 2))) - self.res_skip_layers.append(nn.Conv1d(hidden_channels, 2 * hidden_channels if i < n_layers - 1 else hidden_channels, 1)) - def forward(self, x, x_mask, g=None, **kwargs): - output = Tensor.zeros_like(x) - if g is not None: g = self.cond_layer(g) - for i in range(self.n_layers): - x_in = self.in_layers[i](x) - if g is not None: - cond_offset = i * 2 * self.hidden_channels - g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] - else: - g_l = Tensor.zeros_like(x_in) - acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, self.hidden_channels) - res_skip_acts = self.res_skip_layers[i](acts) - if i < self.n_layers - 1: - x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask - output = output + res_skip_acts[:, self.hidden_channels:, :] - else: - output = output + res_skip_acts - return output * x_mask + def __init__( + self, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + p_dropout=0, + ): + assert kernel_size % 2 == 1 + ( + self.hidden_channels, + self.kernel_size, + self.dilation_rate, + self.n_layers, + self.gin_channels, + self.p_dropout, + ) = ( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels, + p_dropout, + ) + self.in_layers, self.res_skip_layers = [], [] + if gin_channels != 0: + self.cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + for i in range(n_layers): + dilation = dilation_rate**i + self.in_layers.append( + nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=int((kernel_size * dilation - dilation) / 2), + ) + ) + self.res_skip_layers.append( + nn.Conv1d( + hidden_channels, + 2 * hidden_channels if i < n_layers - 1 else hidden_channels, + 1, + ) + ) + + def forward(self, x, x_mask, g=None, **kwargs): + output = Tensor.zeros_like(x) + if g is not None: + g = self.cond_layer(g) + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = Tensor.zeros_like(x_in) + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, self.hidden_channels) + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + class ResBlock1: - def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): - self.convs1 = [nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[i], padding=get_padding(kernel_size, dilation[i])) for i in range(3)] - self.convs2 = [nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) for _ in range(3)] - def forward(self, x: Tensor, x_mask=None): - for c1, c2 in zip(self.convs1, self.convs2): - xt = x.leakyrelu(LRELU_SLOPE) - xt = c1(xt if x_mask is None else xt * x_mask).leakyrelu(LRELU_SLOPE) - x = c2(xt if x_mask is None else xt * x_mask) + x - return x if x_mask is None else x * x_mask + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + self.convs1 = [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[i], + padding=get_padding(kernel_size, dilation[i]), + ) + for i in range(3) + ] + self.convs2 = [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + for _ in range(3) + ] + + def forward(self, x: Tensor, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = x.leakyrelu(LRELU_SLOPE) + xt = c1(xt if x_mask is None else xt * x_mask).leakyrelu(LRELU_SLOPE) + x = c2(xt if x_mask is None else xt * x_mask) + x + return x if x_mask is None else x * x_mask + class ResBlock2: - def __init__(self, channels, kernel_size=3, dilation=(1, 3)): - self.convs = [nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[i], padding=get_padding(kernel_size, dilation[i])) for i in range(2)] - def forward(self, x, x_mask=None): - for c in self.convs: - xt = x.leaky_relu(LRELU_SLOPE) - xt = c(xt if x_mask is None else xt * x_mask) - x = xt + x - return x if x_mask is None else x * x_mask + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + self.convs = [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[i], + padding=get_padding(kernel_size, dilation[i]), + ) + for i in range(2) + ] + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = x.leaky_relu(LRELU_SLOPE) + xt = c(xt if x_mask is None else xt * x_mask) + x = xt + x + return x if x_mask is None else x * x_mask + + +class DDSConv: # Dilated and Depth-Separable Convolution + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + self.channels, self.kernel_size, self.n_layers, self.p_dropout = ( + channels, + kernel_size, + n_layers, + p_dropout, + ) + self.convs_sep, self.convs_1x1, self.norms_1, self.norms_2 = [], [], [], [] + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i].forward(y).gelu() + y = self.convs_1x1[i](y) + y = self.norms_2[i].forward(y).gelu() + x = x + y.dropout(self.p_dropout) + return x * x_mask -class DDSConv: # Dilated and Depth-Separable Convolution - def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): - self.channels, self.kernel_size, self.n_layers, self.p_dropout = channels, kernel_size, n_layers, p_dropout - self.convs_sep, self.convs_1x1, self.norms_1, self.norms_2 = [], [], [], [] - for i in range(n_layers): - dilation = kernel_size ** i - padding = (kernel_size * dilation - dilation) // 2 - self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)) - self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) - self.norms_1.append(LayerNorm(channels)) - self.norms_2.append(LayerNorm(channels)) - def forward(self, x, x_mask, g=None): - if g is not None: x = x + g - for i in range(self.n_layers): - y = self.convs_sep[i](x * x_mask) - y = self.norms_1[i].forward(y).gelu() - y = self.convs_1x1[i](y) - y = self.norms_2[i].forward(y).gelu() - x = x + y.dropout(self.p_dropout) - return x * x_mask class ConvFlow: - def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): - self.in_channels, self.filter_channels, self.kernel_size, self.n_layers, self.num_bins, self.tail_bound = in_channels, filter_channels, kernel_size, n_layers, num_bins, tail_bound - self.half_channels = in_channels // 2 - self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) - self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) - self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) - def forward(self, x, x_mask, g=None, reverse=False): - x0, x1 = split(x, [self.half_channels] * 2, 1) - h = self.proj(self.convs.forward(self.pre(x0), x_mask, g=g)) * x_mask - b, c, t = x0.shape - h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] - un_normalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) - un_normalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) - un_normalized_derivatives = h[..., 2 * self.num_bins:] - x1, log_abs_det = piecewise_rational_quadratic_transform(x1, un_normalized_widths, un_normalized_heights, un_normalized_derivatives, inverse=reverse, tails='linear', tail_bound=self.tail_bound) - x = x0.cat(x1, dim=1) * x_mask - return x if reverse else (x, Tensor.sum(log_abs_det * x_mask, [1,2])) + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + n_layers, + num_bins=10, + tail_bound=5.0, + ): + ( + self.in_channels, + self.filter_channels, + self.kernel_size, + self.n_layers, + self.num_bins, + self.tail_bound, + ) = (in_channels, filter_channels, kernel_size, n_layers, num_bins, tail_bound) + self.half_channels = in_channels // 2 + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) + self.proj = nn.Conv1d( + filter_channels, self.half_channels * (num_bins * 3 - 1), 1 + ) + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = split(x, [self.half_channels] * 2, 1) + h = self.proj(self.convs.forward(self.pre(x0), x_mask, g=g)) * x_mask + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + un_normalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) + un_normalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( + self.filter_channels + ) + un_normalized_derivatives = h[..., 2 * self.num_bins :] + x1, log_abs_det = piecewise_rational_quadratic_transform( + x1, + un_normalized_widths, + un_normalized_heights, + un_normalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + x = x0.cat(x1, dim=1) * x_mask + return x if reverse else (x, Tensor.sum(log_abs_det * x_mask, [1, 2])) + class ResidualCouplingLayer: - def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False): - assert channels % 2 == 0, "channels should be divisible by 2" - self.channels, self.hidden_channels, self.kernel_size, self.dilation_rate, self.n_layers, self.mean_only = channels, hidden_channels, kernel_size, dilation_rate, n_layers, mean_only - self.half_channels = channels // 2 - self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) - self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) - self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) - def forward(self, x, x_mask, g=None, reverse=False): - x0, x1 = split(x, [self.half_channels] * 2, 1) - stats = self.post(self.enc.forward(self.pre(x0) * x_mask, x_mask, g=g)) * x_mask - if not self.mean_only: - m, logs = split(stats, [self.half_channels] * 2, 1) - else: - m = stats - logs = Tensor.zeros_like(m) - if not reverse: return x0.cat((m + x1 * logs.exp() * x_mask), dim=1) - return x0.cat(((x1 - m) * (-logs).exp() * x_mask), dim=1) + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + ( + self.channels, + self.hidden_channels, + self.kernel_size, + self.dilation_rate, + self.n_layers, + self.mean_only, + ) = (channels, hidden_channels, kernel_size, dilation_rate, n_layers, mean_only) + self.half_channels = channels // 2 + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=p_dropout, + gin_channels=gin_channels, + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = split(x, [self.half_channels] * 2, 1) + stats = self.post(self.enc.forward(self.pre(x0) * x_mask, x_mask, g=g)) * x_mask + if not self.mean_only: + m, logs = split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = Tensor.zeros_like(m) + if not reverse: + return x0.cat((m + x1 * logs.exp() * x_mask), dim=1) + return x0.cat(((x1 - m) * (-logs).exp() * x_mask), dim=1) + class Log: - def forward(self, x : Tensor, x_mask, reverse=False): - if not reverse: - y = x.maximum(1e-5).log() * x_mask - return y, (-y).sum([1, 2]) - return x.exp() * x_mask + def forward(self, x: Tensor, x_mask, reverse=False): + if not reverse: + y = x.maximum(1e-5).log() * x_mask + return y, (-y).sum([1, 2]) + return x.exp() * x_mask + class Flip: - def forward(self, x: Tensor, *args, reverse=False, **kwargs): - return x.flip([1]) if reverse else (x.flip([1]), Tensor.zeros(x.shape[0], dtype=x.dtype).to(device=x.device)) + def forward(self, x: Tensor, *args, reverse=False, **kwargs): + return ( + x.flip([1]) + if reverse + else ( + x.flip([1]), + Tensor.zeros(x.shape[0], dtype=x.dtype).to(device=x.device), + ) + ) + class ElementwiseAffine: - def __init__(self, channels): self.m, self.logs = Tensor.zeros(channels, 1), Tensor.zeros(channels, 1) - def forward(self, x, x_mask, reverse=False, **kwargs): # x if reverse else y, logdet - return (x - self.m) * Tensor.exp(-self.logs) * x_mask if reverse \ - else ((self.m + Tensor.exp(self.logs) * x) * x_mask, Tensor.sum(self.logs * x_mask, [1, 2])) + def __init__(self, channels): + self.m, self.logs = Tensor.zeros(channels, 1), Tensor.zeros(channels, 1) + + def forward( + self, x, x_mask, reverse=False, **kwargs + ): # x if reverse else y, logdet + return ( + (x - self.m) * Tensor.exp(-self.logs) * x_mask + if reverse + else ( + (self.m + Tensor.exp(self.logs) * x) * x_mask, + Tensor.sum(self.logs * x_mask, [1, 2]), + ) + ) + class MultiHeadAttention: - def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): - assert channels % n_heads == 0 - self.channels, self.out_channels, self.n_heads, self.p_dropout, self.window_size, self.heads_share, self.block_length, self.proximal_bias, self.proximal_init = channels, out_channels, n_heads, p_dropout, window_size, heads_share, block_length, proximal_bias, proximal_init - self.attn, self.k_channels = None, channels // n_heads - self.conv_q, self.conv_k, self.conv_v = [nn.Conv1d(channels, channels, 1) for _ in range(3)] - self.conv_o = nn.Conv1d(channels, out_channels, 1) - if window_size is not None: self.emb_rel_k, self.emb_rel_v = [Tensor.randn(1 if heads_share else n_heads, window_size * 2 + 1, self.k_channels) * (self.k_channels ** -0.5) for _ in range(2)] - def forward(self, x, c, attn_mask=None): - q, k, v = self.conv_q(x), self.conv_k(c), self.conv_v(c) - x, self.attn = self.attention(q, k, v, mask=attn_mask) - return self.conv_o(x) - def attention(self, query: Tensor, key: Tensor, value: Tensor, mask=None):# reshape [b, d, t] -> [b, n_h, t, d_k] - b, d, t_s, t_t = key.shape[0], key.shape[1], key.shape[2], query.shape[2] - query = query.reshape(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) - key = key.reshape(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) - value = value.reshape(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) - scores = (query / math.sqrt(self.k_channels)) @ key.transpose(-2, -1) - if self.window_size is not None: - assert t_s == t_t, "Relative attention is only available for self-attention." - key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) - rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) - scores = scores + self._relative_position_to_absolute_position(rel_logits) - if mask is not None: - scores = Tensor.where(mask, scores, -1e4) - if self.block_length is not None: - assert t_s == t_t, "Local attention is only available for self-attention." - scores = Tensor.where(Tensor.ones_like(scores).triu(-self.block_length).tril(self.block_length), scores, -1e4) - p_attn = scores.softmax(axis=-1) # [b, n_h, t_t, t_s] - output = p_attn.matmul(value) - if self.window_size is not None: - relative_weights = self._absolute_position_to_relative_position(p_attn) - value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) - output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) - output = output.transpose(2, 3).contiguous().reshape(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] - return output, p_attn - def _matmul_with_relative_values(self, x, y): return x.matmul(y.unsqueeze(0)) # x: [b, h, l, m], y: [h or 1, m, d], ret: [b, h, l, d] - def _matmul_with_relative_keys(self, x, y): return x.matmul(y.unsqueeze(0).transpose(-2, -1)) # x: [b, h, l, d], y: [h or 1, m, d], re, : [b, h, l, m] - def _get_relative_embeddings(self, relative_embeddings, length): - pad_length, slice_start_position = max(length - (self.window_size + 1), 0), max((self.window_size + 1) - length, 0) - padded_relative_embeddings = relative_embeddings if pad_length <= 0\ - else relative_embeddings.pad(convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) - return padded_relative_embeddings[:, slice_start_position:(slice_start_position + 2 * length - 1)] #used_relative_embeddings - def _relative_position_to_absolute_position(self, x: Tensor): # x: [b, h, l, 2*l-1] -> [b, h, l, l] - batch, heads, length, _ = x.shape - x = x.pad(convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) - x_flat = x.reshape([batch, heads, length * 2 * length]).pad(convert_pad_shape([[0,0],[0,0],[0,length-1]])) - return x_flat.reshape([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] - def _absolute_position_to_relative_position(self, x: Tensor): # x: [b, h, l, l] -> [b, h, l, 2*l-1] - batch, heads, length, _ = x.shape - x = x.pad(convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) - x_flat = x.reshape([batch, heads, length**2 + length*(length -1)]).pad(convert_pad_shape([[0, 0], [0, 0], [length, 0]])) - return x_flat.reshape([batch, heads, length, 2*length])[:,:,:,1:] + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + assert channels % n_heads == 0 + ( + self.channels, + self.out_channels, + self.n_heads, + self.p_dropout, + self.window_size, + self.heads_share, + self.block_length, + self.proximal_bias, + self.proximal_init, + ) = ( + channels, + out_channels, + n_heads, + p_dropout, + window_size, + heads_share, + block_length, + proximal_bias, + proximal_init, + ) + self.attn, self.k_channels = None, channels // n_heads + self.conv_q, self.conv_k, self.conv_v = [ + nn.Conv1d(channels, channels, 1) for _ in range(3) + ] + self.conv_o = nn.Conv1d(channels, out_channels, 1) + if window_size is not None: + self.emb_rel_k, self.emb_rel_v = [ + Tensor.randn( + 1 if heads_share else n_heads, window_size * 2 + 1, self.k_channels + ) + * (self.k_channels**-0.5) + for _ in range(2) + ] + + def forward(self, x, c, attn_mask=None): + q, k, v = self.conv_q(x), self.conv_k(c), self.conv_v(c) + x, self.attn = self.attention(q, k, v, mask=attn_mask) + return self.conv_o(x) + + def attention( + self, query: Tensor, key: Tensor, value: Tensor, mask=None + ): # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = key.shape[0], key.shape[1], key.shape[2], query.shape[2] + query = query.reshape(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.reshape(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.reshape(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + scores = (query / math.sqrt(self.k_channels)) @ key.transpose(-2, -1) + if self.window_size is not None: + assert ( + t_s == t_t + ), "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query / math.sqrt(self.k_channels), key_relative_embeddings + ) + scores = scores + self._relative_position_to_absolute_position(rel_logits) + if mask is not None: + scores = Tensor.where(mask, scores, -1e4) + if self.block_length is not None: + assert ( + t_s == t_t + ), "Local attention is only available for self-attention." + scores = Tensor.where( + Tensor.ones_like(scores) + .triu(-self.block_length) + .tril(self.block_length), + scores, + -1e4, + ) + p_attn = scores.softmax(axis=-1) # [b, n_h, t_t, t_s] + output = p_attn.matmul(value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = ( + output.transpose(2, 3).contiguous().reshape(b, d, t_t) + ) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + return x.matmul( + y.unsqueeze(0) + ) # x: [b, h, l, m], y: [h or 1, m, d], ret: [b, h, l, d] + + def _matmul_with_relative_keys(self, x, y): + return x.matmul( + y.unsqueeze(0).transpose(-2, -1) + ) # x: [b, h, l, d], y: [h or 1, m, d], re, : [b, h, l, m] + + def _get_relative_embeddings(self, relative_embeddings, length): + pad_length, slice_start_position = max(length - (self.window_size + 1), 0), max( + (self.window_size + 1) - length, 0 + ) + padded_relative_embeddings = ( + relative_embeddings + if pad_length <= 0 + else relative_embeddings.pad( + convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]) + ) + ) + return padded_relative_embeddings[ + :, slice_start_position : (slice_start_position + 2 * length - 1) + ] # used_relative_embeddings + + def _relative_position_to_absolute_position( + self, x: Tensor + ): # x: [b, h, l, 2*l-1] -> [b, h, l, l] + batch, heads, length, _ = x.shape + x = x.pad(convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + x_flat = x.reshape([batch, heads, length * 2 * length]).pad( + convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) + ) + return x_flat.reshape([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + + def _absolute_position_to_relative_position( + self, x: Tensor + ): # x: [b, h, l, l] -> [b, h, l, 2*l-1] + batch, heads, length, _ = x.shape + x = x.pad(convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) + x_flat = x.reshape([batch, heads, length**2 + length * (length - 1)]).pad( + convert_pad_shape([[0, 0], [0, 0], [length, 0]]) + ) + return x_flat.reshape([batch, heads, length, 2 * length])[:, :, :, 1:] + class FFN: - def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): - self.in_channels, self.out_channels, self.filter_channels, self.kernel_size, self.p_dropout, self.activation, self.causal = in_channels, out_channels, filter_channels, kernel_size, p_dropout, activation, causal - self.padding = self._causal_padding if causal else self._same_padding - self.conv_1, self.conv_2 = nn.Conv1d(in_channels, filter_channels, kernel_size), nn.Conv1d(filter_channels, out_channels, kernel_size) - def forward(self, x, x_mask): - x = self.conv_1(self.padding(x * x_mask)) - x = x * (1.702 * x).sigmoid() if self.activation == "gelu" else x.relu() - return self.conv_2(self.padding(x.dropout(self.p_dropout) * x_mask)) * x_mask - def _causal_padding(self, x):return x if self.kernel_size == 1 else x.pad(convert_pad_shape([[0, 0], [0, 0], [self.kernel_size - 1, 0]])) - def _same_padding(self, x): return x if self.kernel_size == 1 else x.pad(convert_pad_shape([[0, 0], [0, 0], [(self.kernel_size - 1) // 2, self.kernel_size // 2]])) + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + ( + self.in_channels, + self.out_channels, + self.filter_channels, + self.kernel_size, + self.p_dropout, + self.activation, + self.causal, + ) = ( + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout, + activation, + causal, + ) + self.padding = self._causal_padding if causal else self._same_padding + self.conv_1, self.conv_2 = nn.Conv1d( + in_channels, filter_channels, kernel_size + ), nn.Conv1d(filter_channels, out_channels, kernel_size) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + x = x * (1.702 * x).sigmoid() if self.activation == "gelu" else x.relu() + return self.conv_2(self.padding(x.dropout(self.p_dropout) * x_mask)) * x_mask + + def _causal_padding(self, x): + return ( + x + if self.kernel_size == 1 + else x.pad(convert_pad_shape([[0, 0], [0, 0], [self.kernel_size - 1, 0]])) + ) + + def _same_padding(self, x): + return ( + x + if self.kernel_size == 1 + else x.pad( + convert_pad_shape( + [ + [0, 0], + [0, 0], + [(self.kernel_size - 1) // 2, self.kernel_size // 2], + ] + ) + ) + ) + class Encoder: - def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs): - self.hidden_channels, self.filter_channels, self.n_heads, self.n_layers, self.kernel_size, self.p_dropout, self.window_size = hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, window_size - self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2 = [], [], [], [] - for _ in range(n_layers): - self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) - self.norm_layers_1.append(LayerNorm(hidden_channels)) - self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) - self.norm_layers_2.append(LayerNorm(hidden_channels)) - def forward(self, x, x_mask): - attn_mask, x = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1), x * x_mask - for i in range(self.n_layers): - y = self.attn_layers[i].forward(x, x, attn_mask).dropout(self.p_dropout) - x = self.norm_layers_1[i].forward(x + y) - y = self.ffn_layers[i].forward(x, x_mask).dropout(self.p_dropout) - x = self.norm_layers_2[i].forward(x + y) - return x * x_mask + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + **kwargs, + ): + ( + self.hidden_channels, + self.filter_channels, + self.n_heads, + self.n_layers, + self.kernel_size, + self.p_dropout, + self.window_size, + ) = ( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + window_size, + ) + self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2 = ( + [], + [], + [], + [], + ) + for _ in range(n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask, x = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1), x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i].forward(x, x, attn_mask).dropout(self.p_dropout) + x = self.norm_layers_1[i].forward(x + y) + y = self.ffn_layers[i].forward(x, x_mask).dropout(self.p_dropout) + x = self.norm_layers_2[i].forward(x + y) + return x * x_mask + DEFAULT_MIN_BIN_WIDTH, DEFAULT_MIN_BIN_HEIGHT, DEFAULT_MIN_DERIVATIVE = 1e-3, 1e-3, 1e-3 -def piecewise_rational_quadratic_transform(inputs, un_normalized_widths, un_normalized_heights, un_normalized_derivatives, inverse=False, tails=None, tail_bound=1., min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_derivative=DEFAULT_MIN_DERIVATIVE): - if tails is None: spline_fn, spline_kwargs = rational_quadratic_spline, {} - else: spline_fn, spline_kwargs = unconstrained_rational_quadratic_spline, {'tails': tails, 'tail_bound': tail_bound} - return spline_fn(inputs=inputs, un_normalized_widths=un_normalized_widths, un_normalized_heights=un_normalized_heights, un_normalized_derivatives=un_normalized_derivatives, inverse=inverse, min_bin_width=min_bin_width, min_bin_height=min_bin_height, min_derivative=min_derivative, **spline_kwargs) -def unconstrained_rational_quadratic_spline(inputs, un_normalized_widths, un_normalized_heights, un_normalized_derivatives, inverse=False, tails='linear', tail_bound=1., min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_derivative=DEFAULT_MIN_DERIVATIVE): - if not tails == 'linear': raise RuntimeError('{} tails are not implemented.'.format(tails)) - constant = np.log(np.exp(1 - min_derivative) - 1) - un_normalized_derivatives = cat_lr(un_normalized_derivatives, constant, constant) - output, log_abs_det = rational_quadratic_spline(inputs=inputs.squeeze(dim=0).squeeze(dim=0), unnormalized_widths=un_normalized_widths.squeeze(dim=0).squeeze(dim=0), unnormalized_heights=un_normalized_heights.squeeze(dim=0).squeeze(dim=0), unnormalized_derivatives=un_normalized_derivatives.squeeze(dim=0).squeeze(dim=0), inverse=inverse, left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, min_bin_width=min_bin_width, min_bin_height=min_bin_height, min_derivative=min_derivative) - return output.unsqueeze(dim=0).unsqueeze(dim=0), log_abs_det.unsqueeze(dim=0).unsqueeze(dim=0) -def rational_quadratic_spline(inputs: Tensor, unnormalized_widths: Tensor, unnormalized_heights: Tensor, unnormalized_derivatives: Tensor, inverse=False, left=0., right=1., bottom=0., top=1., min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_derivative=DEFAULT_MIN_DERIVATIVE): - num_bins = unnormalized_widths.shape[-1] - if min_bin_width * num_bins > 1.0: raise ValueError('Minimal bin width too large for the number of bins') - if min_bin_height * num_bins > 1.0: raise ValueError('Minimal bin height too large for the number of bins') - widths = min_bin_width + (1 - min_bin_width * num_bins) * unnormalized_widths.softmax(axis=-1) - cum_widths = cat_lr(((right - left) * widths[..., :-1].cumsum(axis=1) + left), left, right + 1e-6 if not inverse else right) - widths = cum_widths[..., 1:] - cum_widths[..., :-1] - derivatives = min_derivative + (unnormalized_derivatives.exp()+1).log() - heights = min_bin_height + (1 - min_bin_height * num_bins) * unnormalized_heights.softmax(axis=-1) - cum_heights = cat_lr(((top - bottom) * heights[..., :-1].cumsum(axis=1) + bottom), bottom, top + 1e-6 if inverse else top) - heights = cum_heights[..., 1:] - cum_heights[..., :-1] - bin_idx = ((inputs[..., None] >= (cum_heights if inverse else cum_widths)).sum(axis=-1) - 1)[..., None] - input_cum_widths = gather(cum_widths, bin_idx, axis=-1)[..., 0] - input_bin_widths = gather(widths, bin_idx, axis=-1)[..., 0] - input_cum_heights = gather(cum_heights, bin_idx, axis=-1)[..., 0] - input_delta = gather(heights / widths, bin_idx, axis=-1)[..., 0] - input_derivatives = gather(derivatives, bin_idx, axis=-1)[..., 0] - input_derivatives_plus_one = gather(derivatives[..., 1:], bin_idx, axis=-1)[..., 0] - input_heights = gather(heights, bin_idx, axis=-1)[..., 0] - if inverse: - a = ((inputs - input_cum_heights) * (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + input_heights * (input_delta - input_derivatives)) - b = (input_heights * input_derivatives - (inputs - input_cum_heights) * (input_derivatives + input_derivatives_plus_one - 2 * input_delta)) - c = - input_delta * (inputs - input_cum_heights) - discriminant = b.square() - 4 * a * c - # assert (discriminant.numpy() >= 0).all() - root = (2 * c) / (-b - discriminant.sqrt()) - theta_one_minus_theta = root * (1 - root) - denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta) - derivative_numerator = input_delta.square() * (input_derivatives_plus_one * root.square() + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - root).square()) - return root * input_bin_widths + input_cum_widths, -(derivative_numerator.log() - 2 * denominator.log()) - theta = (inputs - input_cum_widths) / input_bin_widths - theta_one_minus_theta = theta * (1 - theta) - numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) - denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta) - derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - theta).pow(2)) - return input_cum_heights + numerator / denominator, derivative_numerator.log() - 2 * denominator.log() -def sequence_mask(length: Tensor, max_length): return Tensor.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0) < length.unsqueeze(1) -def generate_path(duration: Tensor, mask: Tensor): # duration: [b, 1, t_x], mask: [b, 1, t_y, t_x] - b, _, t_y, t_x = mask.shape - path = sequence_mask(duration.cumsum(axis=2).reshape(b * t_x), t_y).cast(mask.dtype).reshape(b, t_x, t_y) - path = path - path.pad(convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - return path.unsqueeze(1).transpose(2, 3) * mask + +def piecewise_rational_quadratic_transform( + inputs, + un_normalized_widths, + un_normalized_heights, + un_normalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn, spline_kwargs = rational_quadratic_spline, {} + else: + spline_fn, spline_kwargs = unconstrained_rational_quadratic_spline, { + "tails": tails, + "tail_bound": tail_bound, + } + return spline_fn( + inputs=inputs, + un_normalized_widths=un_normalized_widths, + un_normalized_heights=un_normalized_heights, + un_normalized_derivatives=un_normalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs, + ) + + +def unconstrained_rational_quadratic_spline( + inputs, + un_normalized_widths, + un_normalized_heights, + un_normalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if not tails == "linear": + raise RuntimeError("{} tails are not implemented.".format(tails)) + constant = np.log(np.exp(1 - min_derivative) - 1) + un_normalized_derivatives = cat_lr(un_normalized_derivatives, constant, constant) + output, log_abs_det = rational_quadratic_spline( + inputs=inputs.squeeze(dim=0).squeeze(dim=0), + unnormalized_widths=un_normalized_widths.squeeze(dim=0).squeeze(dim=0), + unnormalized_heights=un_normalized_heights.squeeze(dim=0).squeeze(dim=0), + unnormalized_derivatives=un_normalized_derivatives.squeeze(dim=0).squeeze( + dim=0 + ), + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + return output.unsqueeze(dim=0).unsqueeze(dim=0), log_abs_det.unsqueeze( + dim=0 + ).unsqueeze(dim=0) + + +def rational_quadratic_spline( + inputs: Tensor, + unnormalized_widths: Tensor, + unnormalized_heights: Tensor, + unnormalized_derivatives: Tensor, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + num_bins = unnormalized_widths.shape[-1] + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + widths = min_bin_width + ( + 1 - min_bin_width * num_bins + ) * unnormalized_widths.softmax(axis=-1) + cum_widths = cat_lr( + ((right - left) * widths[..., :-1].cumsum(axis=1) + left), + left, + right + 1e-6 if not inverse else right, + ) + widths = cum_widths[..., 1:] - cum_widths[..., :-1] + derivatives = min_derivative + (unnormalized_derivatives.exp() + 1).log() + heights = min_bin_height + ( + 1 - min_bin_height * num_bins + ) * unnormalized_heights.softmax(axis=-1) + cum_heights = cat_lr( + ((top - bottom) * heights[..., :-1].cumsum(axis=1) + bottom), + bottom, + top + 1e-6 if inverse else top, + ) + heights = cum_heights[..., 1:] - cum_heights[..., :-1] + bin_idx = ( + (inputs[..., None] >= (cum_heights if inverse else cum_widths)).sum(axis=-1) - 1 + )[..., None] + input_cum_widths = gather(cum_widths, bin_idx, axis=-1)[..., 0] + input_bin_widths = gather(widths, bin_idx, axis=-1)[..., 0] + input_cum_heights = gather(cum_heights, bin_idx, axis=-1)[..., 0] + input_delta = gather(heights / widths, bin_idx, axis=-1)[..., 0] + input_derivatives = gather(derivatives, bin_idx, axis=-1)[..., 0] + input_derivatives_plus_one = gather(derivatives[..., 1:], bin_idx, axis=-1)[..., 0] + input_heights = gather(heights, bin_idx, axis=-1)[..., 0] + if inverse: + a = (inputs - input_cum_heights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cum_heights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cum_heights) + discriminant = b.square() - 4 * a * c + # assert (discriminant.numpy() >= 0).all() + root = (2 * c) / (-b - discriminant.sqrt()) + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.square() * ( + input_derivatives_plus_one * root.square() + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).square() + ) + return root * input_bin_widths + input_cum_widths, -( + derivative_numerator.log() - 2 * denominator.log() + ) + theta = (inputs - input_cum_widths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + numerator = input_heights * ( + input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta + ) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + return ( + input_cum_heights + numerator / denominator, + derivative_numerator.log() - 2 * denominator.log(), + ) + + +def sequence_mask(length: Tensor, max_length): + return Tensor.arange( + max_length, dtype=length.dtype, device=length.device + ).unsqueeze(0) < length.unsqueeze(1) + + +def generate_path( + duration: Tensor, mask: Tensor +): # duration: [b, 1, t_x], mask: [b, 1, t_y, t_x] + b, _, t_y, t_x = mask.shape + path = ( + sequence_mask(duration.cumsum(axis=2).reshape(b * t_x), t_y) + .cast(mask.dtype) + .reshape(b, t_x, t_y) + ) + path = path - path.pad(convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + return path.unsqueeze(1).transpose(2, 3) * mask + + def fused_add_tanh_sigmoid_multiply(input_a: Tensor, input_b: Tensor, n_channels: int): - n_channels_int, in_act = n_channels, input_a + input_b - t_act, s_act = in_act[:, :n_channels_int, :].tanh(), in_act[:, n_channels_int:, :].sigmoid() - return t_act * s_act + n_channels_int, in_act = n_channels, input_a + input_b + t_act, s_act = ( + in_act[:, :n_channels_int, :].tanh(), + in_act[:, n_channels_int:, :].sigmoid(), + ) + return t_act * s_act + + +def cat_lr(t, left, right): + return ( + Tensor.full(get_shape(t), left) + .cat(t, dim=-1) + .cat(Tensor.full(get_shape(t), right), dim=-1) + ) + -def cat_lr(t, left, right): return Tensor.full(get_shape(t), left).cat(t, dim=-1).cat(Tensor.full(get_shape(t), right), dim=-1) def get_shape(tensor): - (shape := list(tensor.shape))[-1] = 1 - return tuple(shape) -def convert_pad_shape(pad_shape): return tuple(tuple(x) for x in pad_shape) -def get_padding(kernel_size, dilation=1): return int((kernel_size*dilation - dilation)/2) -def split(tensor, split_sizes, dim=0): # if split_sizes is an integer, convert it to a tuple of size split_sizes elements - if isinstance(split_sizes, int): split_sizes = (split_sizes,) * (tensor.shape[dim] // split_sizes) - assert sum(split_sizes) == tensor.shape[ - dim], "Sum of split_sizes must equal the dimension size of tensor along the given dimension." - start, slices = 0, [] - for size in split_sizes: - slice_range = [(start, start + size) if j == dim else None for j in range(len(tensor.shape))] - slices.append(slice_range) - start += size - return [tensor.slice(s) for s in slices] + (shape := list(tensor.shape))[-1] = 1 + return tuple(shape) + + +def convert_pad_shape(pad_shape): + return tuple(tuple(x) for x in pad_shape) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def split( + tensor, split_sizes, dim=0 +): # if split_sizes is an integer, convert it to a tuple of size split_sizes elements + if isinstance(split_sizes, int): + split_sizes = (split_sizes,) * (tensor.shape[dim] // split_sizes) + assert ( + sum(split_sizes) == tensor.shape[dim] + ), "Sum of split_sizes must equal the dimension size of tensor along the given dimension." + start, slices = 0, [] + for size in split_sizes: + slice_range = [ + (start, start + size) if j == dim else None + for j in range(len(tensor.shape)) + ] + slices.append(slice_range) + start += size + return [tensor.slice(s) for s in slices] + + def gather(x, indices, axis): - indices = (indices < 0).where(indices + x.shape[axis], indices).transpose(ax1=axis, ax2=0) - permute_args = list(range(x.ndim)) - permute_args[0], permute_args[axis] = permute_args[axis], permute_args[0] - permute_args.append(permute_args.pop(0)) - x = x.permute(*permute_args) - reshape_arg = [1] * x.ndim + [x.shape[-1]] - return ((indices.unsqueeze(indices.ndim).expand(*indices.shape, x.shape[-1]) == - Tensor.arange(x.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, x.shape[-1])) * x).sum(indices.ndim).transpose(ax1=0, ax2=axis) + indices = ( + (indices < 0).where(indices + x.shape[axis], indices).transpose(ax1=axis, ax2=0) + ) + permute_args = list(range(x.ndim)) + permute_args[0], permute_args[axis] = permute_args[axis], permute_args[0] + permute_args.append(permute_args.pop(0)) + x = x.permute(*permute_args) + reshape_arg = [1] * x.ndim + [x.shape[-1]] + return ( + ( + ( + indices.unsqueeze(indices.ndim).expand(*indices.shape, x.shape[-1]) + == Tensor.arange(x.shape[-1]) + .reshape(*reshape_arg) + .expand(*indices.shape, x.shape[-1]) + ) + * x + ) + .sum(indices.ndim) + .transpose(ax1=0, ax2=axis) + ) + def norm_except_dim(v, dim): - if dim == -1: return np.linalg.norm(v) - if dim == 0: - (output_shape := [1] * v.ndim)[0] = v.shape[0] - return np.linalg.norm(v.reshape(v.shape[0], -1), axis=1).reshape(output_shape) - if dim == v.ndim - 1: - (output_shape := [1] * v.ndim)[-1] = v.shape[-1] - return np.linalg.norm(v.reshape(-1, v.shape[-1]), axis=0).reshape(output_shape) - transposed_v = np.transpose(v, (dim,) + tuple(i for i in range(v.ndim) if i != dim)) - return np.transpose(norm_except_dim(transposed_v, 0), (dim,) + tuple(i for i in range(v.ndim) if i != dim)) + if dim == -1: + return np.linalg.norm(v) + if dim == 0: + (output_shape := [1] * v.ndim)[0] = v.shape[0] + return np.linalg.norm(v.reshape(v.shape[0], -1), axis=1).reshape(output_shape) + if dim == v.ndim - 1: + (output_shape := [1] * v.ndim)[-1] = v.shape[-1] + return np.linalg.norm(v.reshape(-1, v.shape[-1]), axis=0).reshape(output_shape) + transposed_v = np.transpose(v, (dim,) + tuple(i for i in range(v.ndim) if i != dim)) + return np.transpose( + norm_except_dim(transposed_v, 0), + (dim,) + tuple(i for i in range(v.ndim) if i != dim), + ) + + def weight_norm(v: Tensor, g: Tensor, dim): - v, g = v.numpy(), g.numpy() - return Tensor(v * (g / norm_except_dim(v, dim))) + v, g = v.numpy(), g.numpy() + return Tensor(v * (g / norm_except_dim(v, dim))) + # HPARAMS LOADING def get_hparams_from_file(path): - with open(path, "r") as f: - data = f.read() - return HParams(**json.loads(data)) + with open(path, "r") as f: + data = f.read() + return HParams(**json.loads(data)) + + class HParams: - def __init__(self, **kwargs): - for k, v in kwargs.items(): self[k] = v if type(v) != dict else HParams(**v) - def keys(self): return self.__dict__.keys() - def items(self): return self.__dict__.items() - def values(self): return self.__dict__.values() - def __len__(self): return len(self.__dict__) - def __getitem__(self, key): return getattr(self, key) - def __setitem__(self, key, value): return setattr(self, key, value) - def __contains__(self, key): return key in self.__dict__ - def __repr__(self): return self.__dict__.__repr__() + def __init__(self, **kwargs): + for k, v in kwargs.items(): + self[k] = v if type(v) != dict else HParams(**v) + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + # MODEL LOADING def load_model(symbols, hps, model) -> Synthesizer: - net_g = Synthesizer(len(symbols), hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers = hps.data.n_speakers, **hps.model) - _ = load_checkpoint(fetch(model[1]), net_g, None) - return net_g + net_g = Synthesizer( + len(symbols), + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ) + _ = load_checkpoint(fetch(model[1]), net_g, None) + return net_g + + def load_checkpoint(checkpoint_path, model: Synthesizer, optimizer=None, skip_list=[]): - assert Path(checkpoint_path).is_file() - start_time = time.time() - checkpoint_dict = torch_load(checkpoint_path) - iteration, learning_rate = checkpoint_dict['iteration'], checkpoint_dict['learning_rate'] - if optimizer: optimizer.load_state_dict(checkpoint_dict['optimizer']) - saved_state_dict = checkpoint_dict['model'] - weight_g, weight_v, parent = None, None, None - for key, v in saved_state_dict.items(): - if any(layer in key for layer in skip_list): continue - try: - obj, skip = model, False - for k in key.split('.'): - if k.isnumeric(): obj = obj[int(k)] - elif isinstance(obj, dict): obj = obj[k] - else: - if isinstance(obj, (LayerNorm, nn.LayerNorm)) and k in ["gamma", "beta"]: - k = "weight" if k == "gamma" else "bias" - elif k in ["weight_g", "weight_v"]: - parent, skip = obj, True - if k == "weight_g": weight_g = v - else: weight_v = v - if not skip: obj = getattr(obj, k) - if weight_g and weight_v: - setattr(obj, "weight_g", weight_g.numpy()) - setattr(obj, "weight_v", weight_v.numpy()) - obj, v = getattr(parent, "weight"), weight_norm(weight_v, weight_g, 0) - weight_g, weight_v, parent, skip = None, None, None, False - if not skip and obj.shape == v.shape: obj.assign(v.to(obj.device)) - elif not skip: logging.error(f"MISMATCH SHAPE IN {key}, {obj.shape} {v.shape}") - except Exception as e: raise e - logging.info(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration}) in {time.time() - start_time:.4f}s") - return model, optimizer, learning_rate, iteration + assert Path(checkpoint_path).is_file() + start_time = time.time() + checkpoint_dict = torch_load(checkpoint_path) + iteration, learning_rate = ( + checkpoint_dict["iteration"], + checkpoint_dict["learning_rate"], + ) + if optimizer: + optimizer.load_state_dict(checkpoint_dict["optimizer"]) + saved_state_dict = checkpoint_dict["model"] + weight_g, weight_v, parent = None, None, None + for key, v in saved_state_dict.items(): + if any(layer in key for layer in skip_list): + continue + try: + obj, skip = model, False + for k in key.split("."): + if k.isnumeric(): + obj = obj[int(k)] + elif isinstance(obj, dict): + obj = obj[k] + else: + if isinstance(obj, (LayerNorm, nn.LayerNorm)) and k in [ + "gamma", + "beta", + ]: + k = "weight" if k == "gamma" else "bias" + elif k in ["weight_g", "weight_v"]: + parent, skip = obj, True + if k == "weight_g": + weight_g = v + else: + weight_v = v + if not skip: + obj = getattr(obj, k) + if weight_g and weight_v: + setattr(obj, "weight_g", weight_g.numpy()) + setattr(obj, "weight_v", weight_v.numpy()) + obj, v = getattr(parent, "weight"), weight_norm(weight_v, weight_g, 0) + weight_g, weight_v, parent, skip = None, None, None, False + if not skip and obj.shape == v.shape: + obj.assign(v.to(obj.device)) + elif not skip: + logging.error(f"MISMATCH SHAPE IN {key}, {obj.shape} {v.shape}") + except Exception as e: + raise e + logging.info( + f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration}) in {time.time() - start_time:.4f}s" + ) + return model, optimizer, learning_rate, iteration + # Used for cleaning input text and mapping to symbols -class TextMapper: # Based on https://github.com/keithito/tacotron - def __init__(self, symbols, apply_cleaners=True): - self.apply_cleaners, self.symbols, self._inflect = apply_cleaners, symbols, None - self._symbol_to_id, _id_to_symbol = {s: i for i, s in enumerate(symbols)}, {i: s for i, s in enumerate(symbols)} - self._whitespace_re, self._abbreviations = re.compile(r'\s+'), [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [('mrs', 'misess'), ('mr', 'mister'), ('dr', 'doctor'), ('st', 'saint'), ('co', 'company'), ('jr', 'junior'), ('maj', 'major'), ('gen', 'general'), ('drs', 'doctors'), ('rev', 'reverend'), ('lt', 'lieutenant'), ('hon', 'honorable'), ('sgt', 'sergeant'), ('capt', 'captain'), ('esq', 'esquire'), ('ltd', 'limited'), ('col', 'colonel'), ('ft', 'fort'), ]] - self.phonemizer = EspeakBackend( - language="en-us", punctuation_marks=Punctuation.default_marks(), preserve_punctuation=True, with_stress=True, - ) - def text_to_sequence(self, text, cleaner_names): - if self.apply_cleaners: - for name in cleaner_names: - cleaner = getattr(self, name) - if not cleaner: raise ModuleNotFoundError('Unknown cleaner: %s' % name) - text = cleaner(text) - else: text = text.strip() - return [self._symbol_to_id[symbol] for symbol in text] - def get_text(self, text, add_blank=False, cleaners=('english_cleaners2',)): - text_norm = self.text_to_sequence(text, cleaners) - return Tensor(self.intersperse(text_norm, 0) if add_blank else text_norm, dtype=dtypes.int64) - def intersperse(self, lst, item): - (result := [item] * (len(lst) * 2 + 1))[1::2] = lst - return result - def phonemize(self, text, strip=True): return _phonemize(self.phonemizer, text, default_separator, strip, 1, False, False) - def filter_oov(self, text): return "".join(list(filter(lambda x: x in self._symbol_to_id, text))) - def base_english_cleaners(self, text): return self.collapse_whitespace(self.phonemize(self.expand_abbreviations(unidecode(text.lower())))) - def english_cleaners2(self, text): return self.base_english_cleaners(text) - def transliteration_cleaners(self, text): return self.collapse_whitespace(unidecode(text.lower())) - def cjke_cleaners(self, text): return re.sub(r'([^\.,!\?\-…~])$', r'\1.', re.sub(r'\s+$', '', self.english_to_ipa2(text).replace('ɑ', 'a').replace('ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u'))) - def cjke_cleaners2(self, text): return re.sub(r'([^\.,!\?\-…~])$', r'\1.', re.sub(r'\s+$', '', self.english_to_ipa2(text))) - def cjks_cleaners(self, text): return re.sub(r'([^\.,!\?\-…~])$', r'\1.', re.sub(r'\s+$', '', self.english_to_lazy_ipa(text))) - def english_to_ipa2(self, text): - _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ ('r', 'ɹ'), ('ʤ', 'dʒ'), ('ʧ', 'tʃ')]] - return reduce(lambda t, rx: re.sub(rx[0], rx[1], t), _ipa_to_ipa2, self.mark_dark_l(self.english_to_ipa(text))).replace('...', '…') - def mark_dark_l(self, text): return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ' + x.group(1), text) - def english_to_ipa(self, text): - import eng_to_ipa as ipa - return self.collapse_whitespace(ipa.convert(self.normalize_numbers(self.expand_abbreviations(unidecode(text).lower())))) - def english_to_lazy_ipa(self, text): - _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [('r', 'ɹ'), ('æ', 'e'), ('ɑ', 'a'), ('ɔ', 'o'), ('ð', 'z'), ('θ', 's'), ('ɛ', 'e'), ('ɪ', 'i'), ('ʊ', 'u'), ('ʒ', 'ʥ'), ('ʤ', 'ʥ'), ('ˈ', '↓')]] - return reduce(lambda t, rx: re.sub(rx[0], rx[1], t), _lazy_ipa, self.english_to_ipa(text)) - def expand_abbreviations(self, text): return reduce(lambda t, abbr: re.sub(abbr[0], abbr[1], t), self._abbreviations, text) - def collapse_whitespace(self, text): return re.sub(self._whitespace_re, ' ', text) - def normalize_numbers(self, text): - import inflect - self._inflect = inflect.engine() - text = re.sub(re.compile(r'([0-9][0-9\,]+[0-9])'), self._remove_commas, text) - text = re.sub(re.compile(r'£([0-9\,]*[0-9]+)'), r'\1 pounds', text) - text = re.sub(re.compile(r'\$([0-9\.\,]*[0-9]+)'), self._expand_dollars, text) - text = re.sub(re.compile(r'([0-9]+\.[0-9]+)'), self._expand_decimal_point, text) - text = re.sub(re.compile(r'[0-9]+(st|nd|rd|th)'), self._expand_ordinal, text) - text = re.sub(re.compile(r'[0-9]+'), self._expand_number, text) - return text - def _remove_commas(self, m): return m.group(1).replace(',', '') # george won't like this - def _expand_dollars(self, m): - match = m.group(1) - parts = match.split('.') - if len(parts) > 2: return match + ' dollars' # Unexpected format - dollars, cents = int(parts[0]) if parts[0] else 0, int(parts[1]) if len(parts) > 1 and parts[1] else 0 - if dollars and cents: return '%s %s, %s %s' % (dollars, 'dollar' if dollars == 1 else 'dollars', cents, 'cent' if cents == 1 else 'cents') - if dollars: return '%s %s' % (dollars, 'dollar' if dollars == 1 else 'dollars') - if cents: return '%s %s' % (cents, 'cent' if cents == 1 else 'cents') - return 'zero dollars' - def _expand_decimal_point(self, m): return m.group(1).replace('.', ' point ') - def _expand_ordinal(self, m): return self._inflect.number_to_words(m.group(0)) - def _expand_number(self, _inflect, m): - num = int(m.group(0)) - if 1000 < num < 3000: - if num == 2000: return 'two thousand' - if 2000 < num < 2010: return 'two thousand ' + self._inflect.number_to_words(num % 100) - if num % 100 == 0: return self._inflect.number_to_words(num // 100) + ' hundred' - return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') - return self._inflect.number_to_words(num, andword='') +class TextMapper: # Based on https://github.com/keithito/tacotron + def __init__(self, symbols, apply_cleaners=True): + self.apply_cleaners, self.symbols, self._inflect = apply_cleaners, symbols, None + self._symbol_to_id, _id_to_symbol = {s: i for i, s in enumerate(symbols)}, { + i: s for i, s in enumerate(symbols) + } + self._whitespace_re, self._abbreviations = re.compile(r"\s+"), [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] + ] + self.phonemizer = EspeakBackend( + language="en-us", + punctuation_marks=Punctuation.default_marks(), + preserve_punctuation=True, + with_stress=True, + ) + + def text_to_sequence(self, text, cleaner_names): + if self.apply_cleaners: + for name in cleaner_names: + cleaner = getattr(self, name) + if not cleaner: + raise ModuleNotFoundError("Unknown cleaner: %s" % name) + text = cleaner(text) + else: + text = text.strip() + return [self._symbol_to_id[symbol] for symbol in text] + + def get_text(self, text, add_blank=False, cleaners=("english_cleaners2",)): + text_norm = self.text_to_sequence(text, cleaners) + return Tensor( + self.intersperse(text_norm, 0) if add_blank else text_norm, + dtype=dtypes.int64, + ) + + def intersperse(self, lst, item): + (result := [item] * (len(lst) * 2 + 1))[1::2] = lst + return result + + def phonemize(self, text, strip=True): + return _phonemize( + self.phonemizer, text, default_separator, strip, 1, False, False + ) + + def filter_oov(self, text): + return "".join(list(filter(lambda x: x in self._symbol_to_id, text))) + + def base_english_cleaners(self, text): + return self.collapse_whitespace( + self.phonemize(self.expand_abbreviations(unidecode(text.lower()))) + ) + + def english_cleaners2(self, text): + return self.base_english_cleaners(text) + + def transliteration_cleaners(self, text): + return self.collapse_whitespace(unidecode(text.lower())) + + def cjke_cleaners(self, text): + return re.sub( + r"([^\.,!\?\-…~])$", + r"\1.", + re.sub( + r"\s+$", + "", + self.english_to_ipa2(text) + .replace("ɑ", "a") + .replace("ɔ", "o") + .replace("ɛ", "e") + .replace("ɪ", "i") + .replace("ʊ", "u"), + ), + ) + + def cjke_cleaners2(self, text): + return re.sub( + r"([^\.,!\?\-…~])$", r"\1.", re.sub(r"\s+$", "", self.english_to_ipa2(text)) + ) + + def cjks_cleaners(self, text): + return re.sub( + r"([^\.,!\?\-…~])$", + r"\1.", + re.sub(r"\s+$", "", self.english_to_lazy_ipa(text)), + ) + + def english_to_ipa2(self, text): + _ipa_to_ipa2 = [ + (re.compile("%s" % x[0]), x[1]) + for x in [("r", "ɹ"), ("ʤ", "dʒ"), ("ʧ", "tʃ")] + ] + return reduce( + lambda t, rx: re.sub(rx[0], rx[1], t), + _ipa_to_ipa2, + self.mark_dark_l(self.english_to_ipa(text)), + ).replace("...", "…") + + def mark_dark_l(self, text): + return re.sub(r"l([^aeiouæɑɔəɛɪʊ ]*(?: |$))", lambda x: "ɫ" + x.group(1), text) + + def english_to_ipa(self, text): + import eng_to_ipa as ipa + + return self.collapse_whitespace( + ipa.convert( + self.normalize_numbers( + self.expand_abbreviations(unidecode(text).lower()) + ) + ) + ) + + def english_to_lazy_ipa(self, text): + _lazy_ipa = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("r", "ɹ"), + ("æ", "e"), + ("ɑ", "a"), + ("ɔ", "o"), + ("ð", "z"), + ("θ", "s"), + ("ɛ", "e"), + ("ɪ", "i"), + ("ʊ", "u"), + ("ʒ", "ʥ"), + ("ʤ", "ʥ"), + ("ˈ", "↓"), + ] + ] + return reduce( + lambda t, rx: re.sub(rx[0], rx[1], t), _lazy_ipa, self.english_to_ipa(text) + ) + + def expand_abbreviations(self, text): + return reduce( + lambda t, abbr: re.sub(abbr[0], abbr[1], t), self._abbreviations, text + ) + + def collapse_whitespace(self, text): + return re.sub(self._whitespace_re, " ", text) + + def normalize_numbers(self, text): + import inflect + + self._inflect = inflect.engine() + text = re.sub(re.compile(r"([0-9][0-9\,]+[0-9])"), self._remove_commas, text) + text = re.sub(re.compile(r"£([0-9\,]*[0-9]+)"), r"\1 pounds", text) + text = re.sub(re.compile(r"\$([0-9\.\,]*[0-9]+)"), self._expand_dollars, text) + text = re.sub(re.compile(r"([0-9]+\.[0-9]+)"), self._expand_decimal_point, text) + text = re.sub(re.compile(r"[0-9]+(st|nd|rd|th)"), self._expand_ordinal, text) + text = re.sub(re.compile(r"[0-9]+"), self._expand_number, text) + return text + + def _remove_commas(self, m): + return m.group(1).replace(",", "") # george won't like this + + def _expand_dollars(self, m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" # Unexpected format + dollars, cents = ( + int(parts[0]) if parts[0] else 0, + int(parts[1]) if len(parts) > 1 and parts[1] else 0, + ) + if dollars and cents: + return "%s %s, %s %s" % ( + dollars, + "dollar" if dollars == 1 else "dollars", + cents, + "cent" if cents == 1 else "cents", + ) + if dollars: + return "%s %s" % (dollars, "dollar" if dollars == 1 else "dollars") + if cents: + return "%s %s" % (cents, "cent" if cents == 1 else "cents") + return "zero dollars" + + def _expand_decimal_point(self, m): + return m.group(1).replace(".", " point ") + + def _expand_ordinal(self, m): + return self._inflect.number_to_words(m.group(0)) + + def _expand_number(self, _inflect, m): + num = int(m.group(0)) + if 1000 < num < 3000: + if num == 2000: + return "two thousand" + if 2000 < num < 2010: + return "two thousand " + self._inflect.number_to_words(num % 100) + if num % 100 == 0: + return self._inflect.number_to_words(num // 100) + " hundred" + return _inflect.number_to_words( + num, andword="", zero="oh", group=2 + ).replace(", ", " ") + return self._inflect.number_to_words(num, andword="") + ######################################################################################### # PAPER: https://arxiv.org/abs/2106.06103 @@ -657,93 +1801,234 @@ class TextMapper: # Based on https://github.com/keithito/tacotron # anime lady 2 | --model_to_use uma_trilingual --speaker_id 121 ######################################################################################### VITS_PATH = Path(__file__).parents[1] / "weights/VITS/" -MODELS = { # config_url, weights_url - "ljs": ("https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/ljs_base.json", "https://drive.google.com/uc?export=download&id=1q86w74Ygw2hNzYP9cWkeClGT5X25PvBT&confirm=t"), - "vctk": ("https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/vctk_base.json", "https://drive.google.com/uc?export=download&id=11aHOlhnxzjpdWDpsz1vFDCzbeEfoIxru&confirm=t"), - "mmts-tts": ("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/config.json", "https://huggingface.co/facebook/mms-tts/resolve/main/full_models/eng/G_100000.pth"), - "uma_trilingual": ("https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/raw/main/configs/uma_trilingual.json", "https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/pretrained_models/G_trilingual.pth"), - "cjks": ("https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/config.json", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/model.pth"), - "voistock": ("https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/config.json", "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/model.pth"), +MODELS = { # config_url, weights_url + "ljs": ( + "https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/ljs_base.json", + "https://drive.google.com/uc?export=download&id=1q86w74Ygw2hNzYP9cWkeClGT5X25PvBT&confirm=t", + ), + "vctk": ( + "https://raw.githubusercontent.com/jaywalnut310/vits/main/configs/vctk_base.json", + "https://drive.google.com/uc?export=download&id=11aHOlhnxzjpdWDpsz1vFDCzbeEfoIxru&confirm=t", + ), + "mmts-tts": ( + "https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/config.json", + "https://huggingface.co/facebook/mms-tts/resolve/main/full_models/eng/G_100000.pth", + ), + "uma_trilingual": ( + "https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/raw/main/configs/uma_trilingual.json", + "https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/pretrained_models/G_trilingual.pth", + ), + "cjks": ( + "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/config.json", + "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/14/model.pth", + ), + "voistock": ( + "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/config.json", + "https://huggingface.co/spaces/skytnt/moe-tts/resolve/main/saved_model/15/model.pth", + ), } -Y_LENGTH_ESTIMATE_SCALARS = {"ljs": 2.8, "vctk": 1.74, "mmts-tts": 1.9, "uma_trilingual": 2.3, "cjks": 3.3, "voistock": 3.1} -if __name__ == '__main__': - logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) - parser = argparse.ArgumentParser() - parser.add_argument("--model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.") - parser.add_argument("--speaker_id", type=int, default=6, help="Specify the speaker ID. Default is 6.") - parser.add_argument("--out_path", default=None, help="Specify the full output path. Overrides the --out_dir and --name parameter.") - parser.add_argument("--out_dir", default=str(Path(__file__).parents[1] / "temp"), help="Specify the output path.") - parser.add_argument("--base_name", default="test", help="Specify the base of the output file name. Default is 'test'.") - parser.add_argument("--text_to_synthesize", default="""Hello person. If the code you are contributing isn't some of the highest quality code you've written in your life, either put in the effort to make it great, or don't bother.""", help="Specify the text to synthesize. Default is a greeting message.") - parser.add_argument("--noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.") - parser.add_argument("--noise_scale_w", type=float, default=0.8, help="Specify the noise scale w. Default is 0.8.") - parser.add_argument("--length_scale", type=float, default=1, help="Specify the length scale. Default is 1.") - parser.add_argument("--seed", type=int, default=1337, help="Specify the seed (set to None if no seed). Default is 1337.") - parser.add_argument("--num_channels", type=int, default=1, help="Specify the number of audio output channels. Default is 1.") - parser.add_argument("--sample_width", type=int, default=2, help="Specify the number of bytes per sample, adjust if necessary. Default is 2.") - parser.add_argument("--emotion_path", type=str, default=None, help="Specify the path to emotion reference.") - parser.add_argument("--estimate_max_y_length", type=str, default=False, help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.") - args = parser.parse_args() +Y_LENGTH_ESTIMATE_SCALARS = { + "ljs": 2.8, + "vctk": 1.74, + "mmts-tts": 1.9, + "uma_trilingual": 2.3, + "cjks": 3.3, + "voistock": 3.1, +} +if __name__ == "__main__": + logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_to_use", + default="vctk", + help="Specify the model to use. Default is 'vctk'.", + ) + parser.add_argument( + "--speaker_id", + type=int, + default=6, + help="Specify the speaker ID. Default is 6.", + ) + parser.add_argument( + "--out_path", + default=None, + help="Specify the full output path. Overrides the --out_dir and --name parameter.", + ) + parser.add_argument( + "--out_dir", + default=str(Path(__file__).parents[1] / "temp"), + help="Specify the output path.", + ) + parser.add_argument( + "--base_name", + default="test", + help="Specify the base of the output file name. Default is 'test'.", + ) + parser.add_argument( + "--text_to_synthesize", + default="""Hello person. If the code you are contributing isn't some of the highest quality code you've written in your life, either put in the effort to make it great, or don't bother.""", + help="Specify the text to synthesize. Default is a greeting message.", + ) + parser.add_argument( + "--noise_scale", + type=float, + default=0.667, + help="Specify the noise scale. Default is 0.667.", + ) + parser.add_argument( + "--noise_scale_w", + type=float, + default=0.8, + help="Specify the noise scale w. Default is 0.8.", + ) + parser.add_argument( + "--length_scale", + type=float, + default=1, + help="Specify the length scale. Default is 1.", + ) + parser.add_argument( + "--seed", + type=int, + default=1337, + help="Specify the seed (set to None if no seed). Default is 1337.", + ) + parser.add_argument( + "--num_channels", + type=int, + default=1, + help="Specify the number of audio output channels. Default is 1.", + ) + parser.add_argument( + "--sample_width", + type=int, + default=2, + help="Specify the number of bytes per sample, adjust if necessary. Default is 2.", + ) + parser.add_argument( + "--emotion_path", + type=str, + default=None, + help="Specify the path to emotion reference.", + ) + parser.add_argument( + "--estimate_max_y_length", + type=str, + default=False, + help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.", + ) + args = parser.parse_args() - model_config = MODELS[args.model_to_use] + model_config = MODELS[args.model_to_use] - # Load the hyperparameters from the config file. - hps = get_hparams_from_file(fetch(model_config[0])) + # Load the hyperparameters from the config file. + hps = get_hparams_from_file(fetch(model_config[0])) - # If model has multiple speakers, validate speaker id and retrieve name if available. - model_has_multiple_speakers = hps.data.n_speakers > 0 - if model_has_multiple_speakers: - logging.info(f"Model has {hps.data.n_speakers} speakers") - if args.speaker_id >= hps.data.n_speakers: raise ValueError(f"Speaker ID {args.speaker_id} is invalid for this model.") - speaker_name = "?" - if hps.__contains__("speakers"): # maps speaker ids to names - speakers = hps.speakers - if isinstance(speakers, List): speakers = {speaker: i for i, speaker in enumerate(speakers)} - speaker_name = next((key for key, value in speakers.items() if value == args.speaker_id), None) - logging.info(f"You selected speaker {args.speaker_id} (name: {speaker_name})") + # If model has multiple speakers, validate speaker id and retrieve name if available. + model_has_multiple_speakers = hps.data.n_speakers > 0 + if model_has_multiple_speakers: + logging.info(f"Model has {hps.data.n_speakers} speakers") + if args.speaker_id >= hps.data.n_speakers: + raise ValueError(f"Speaker ID {args.speaker_id} is invalid for this model.") + speaker_name = "?" + if hps.__contains__("speakers"): # maps speaker ids to names + speakers = hps.speakers + if isinstance(speakers, List): + speakers = {speaker: i for i, speaker in enumerate(speakers)} + speaker_name = next( + (key for key, value in speakers.items() if value == args.speaker_id), + None, + ) + logging.info(f"You selected speaker {args.speaker_id} (name: {speaker_name})") - # Load emotions if any. TODO: find an english model with emotions, this is untested atm. - emotion_embedding = None - if args.emotion_path is not None: - if args.emotion_path.endswith(".npy"): emotion_embedding = Tensor(np.load(args.emotion_path), dtype=dtypes.int64).unsqueeze(0) - else: raise ValueError("Emotion path must be a .npy file.") + # Load emotions if any. TODO: find an english model with emotions, this is untested atm. + emotion_embedding = None + if args.emotion_path is not None: + if args.emotion_path.endswith(".npy"): + emotion_embedding = Tensor( + np.load(args.emotion_path), dtype=dtypes.int64 + ).unsqueeze(0) + else: + raise ValueError("Emotion path must be a .npy file.") - # Load symbols, instantiate TextMapper and clean the text. - if hps.__contains__("symbols"): symbols = hps.symbols - elif args.model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()] - else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ") - text_mapper = TextMapper(apply_cleaners=True, symbols=symbols) + # Load symbols, instantiate TextMapper and clean the text. + if hps.__contains__("symbols"): + symbols = hps.symbols + elif args.model_to_use == "mmts-tts": + symbols = [ + x.replace("\n", "") + for x in fetch( + "https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt" + ) + .open(encoding="utf-8") + .readlines() + ] + else: + symbols = ( + ["_"] + + list(';:,.!?¡¿—…"«»“” ') + + list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") + + list( + "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" + ) + ) + text_mapper = TextMapper(apply_cleaners=True, symbols=symbols) - # Load the model. - Tensor.no_grad = True - if args.seed is not None: - Tensor.manual_seed(args.seed) - np.random.seed(args.seed) - net_g = load_model(text_mapper.symbols, hps, model_config) - logging.debug(f"Loaded model with hps: {hps}") + # Load the model. + Tensor.no_grad = True + if args.seed is not None: + Tensor.manual_seed(args.seed) + np.random.seed(args.seed) + net_g = load_model(text_mapper.symbols, hps, model_config) + logging.debug(f"Loaded model with hps: {hps}") - # Convert the input text to a tensor. - text_to_synthesize = args.text_to_synthesize - if args.model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower()) - stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners) - logging.debug(f"Converted input text to tensor \"{text_to_synthesize}\" -> Tensor({stn_tst.shape}): {stn_tst.numpy()}") - x_tst, x_tst_lengths = stn_tst.unsqueeze(0), Tensor([stn_tst.shape[0]], dtype=dtypes.int64) - sid = Tensor([args.speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None + # Convert the input text to a tensor. + text_to_synthesize = args.text_to_synthesize + if args.model_to_use == "mmts-tts": + text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower()) + stn_tst = text_mapper.get_text( + text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners + ) + logging.debug( + f'Converted input text to tensor "{text_to_synthesize}" -> Tensor({stn_tst.shape}): {stn_tst.numpy()}' + ) + x_tst, x_tst_lengths = stn_tst.unsqueeze(0), Tensor( + [stn_tst.shape[0]], dtype=dtypes.int64 + ) + sid = ( + Tensor([args.speaker_id], dtype=dtypes.int64) + if model_has_multiple_speakers + else None + ) - # Perform inference. - start_time = time.time() - audio_tensor = net_g.infer(x_tst, x_tst_lengths, sid, args.noise_scale, args.length_scale, args.noise_scale_w, emotion_embedding=emotion_embedding, - max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[args.model_to_use] if args.estimate_max_y_length else None)[0, 0].realize() - logging.info(f"Inference took {(time.time() - start_time):.2f}s") + # Perform inference. + start_time = time.time() + audio_tensor = net_g.infer( + x_tst, + x_tst_lengths, + sid, + args.noise_scale, + args.length_scale, + args.noise_scale_w, + emotion_embedding=emotion_embedding, + max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[args.model_to_use] + if args.estimate_max_y_length + else None, + )[0, 0].realize() + logging.info(f"Inference took {(time.time() - start_time):.2f}s") - # Save the audio output. - audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16) - out_path = Path(args.out_path or Path(args.out_dir)/f"{args.model_to_use}{f'_sid_{args.speaker_id}' if model_has_multiple_speakers else ''}_{args.base_name}.wav") - out_path.parent.mkdir(parents=True, exist_ok=True) - with wave.open(str(out_path), 'wb') as wav_file: - wav_file.setnchannels(args.num_channels) - wav_file.setsampwidth(args.sample_width) - wav_file.setframerate(hps.data.sampling_rate) - wav_file.setnframes(len(audio_data)) - wav_file.writeframes(audio_data.tobytes()) - logging.info(f"Saved audio output to {out_path}") + # Save the audio output. + audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16) + out_path = Path( + args.out_path + or Path(args.out_dir) + / f"{args.model_to_use}{f'_sid_{args.speaker_id}' if model_has_multiple_speakers else ''}_{args.base_name}.wav" + ) + out_path.parent.mkdir(parents=True, exist_ok=True) + with wave.open(str(out_path), "wb") as wav_file: + wav_file.setnchannels(args.num_channels) + wav_file.setsampwidth(args.sample_width) + wav_file.setframerate(hps.data.sampling_rate) + wav_file.setnframes(len(audio_data)) + wav_file.writeframes(audio_data.tobytes()) + logging.info(f"Saved audio output to {out_path}") diff --git a/examples/webgpu/stable_diffusion/compile.py b/examples/webgpu/stable_diffusion/compile.py index 7984e6313..68ef245ff 100644 --- a/examples/webgpu/stable_diffusion/compile.py +++ b/examples/webgpu/stable_diffusion/compile.py @@ -1,7 +1,13 @@ import os from extra.export_model import compile_net, jit_model from examples.stable_diffusion import StableDiffusion -from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict +from tinygrad.nn.state import ( + get_state_dict, + safe_save, + safe_load_metadata, + torch_load, + load_state_dict, +) from tinygrad.tensor import Tensor from tinygrad import Device from tinygrad.helpers import fetch @@ -10,102 +16,174 @@ from pathlib import Path import argparse import numpy as np + def convert_f32_to_f16(input_file, output_file): - with open(input_file, 'rb') as f: - metadata_length_bytes = f.read(8) - metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False) - metadata_json_bytes = f.read(metadata_length) - float32_values = np.fromfile(f, dtype=np.float32) + with open(input_file, "rb") as f: + metadata_length_bytes = f.read(8) + metadata_length = int.from_bytes( + metadata_length_bytes, byteorder="little", signed=False + ) + metadata_json_bytes = f.read(metadata_length) + float32_values = np.fromfile(f, dtype=np.float32) - first_text_model_offset = 3772703308 - num_elements = int((first_text_model_offset)/4) - front_float16_values = float32_values[:num_elements].astype(np.float16) - rest_float32_values = float32_values[num_elements:] + first_text_model_offset = 3772703308 + num_elements = int((first_text_model_offset) / 4) + front_float16_values = float32_values[:num_elements].astype(np.float16) + rest_float32_values = float32_values[num_elements:] + + with open(output_file, "wb") as f: + f.write(metadata_length_bytes) + f.write(metadata_json_bytes) + front_float16_values.tofile(f) + rest_float32_values.tofile(f) - with open(output_file, 'wb') as f: - f.write(metadata_length_bytes) - f.write(metadata_json_bytes) - front_float16_values.tofile(f) - rest_float32_values.tofile(f) def split_safetensor(fn): - _, json_len, metadata = safe_load_metadata(fn) - text_model_offset = 3772703308 - chunk_size = 536870912 + _, json_len, metadata = safe_load_metadata(fn) + text_model_offset = 3772703308 + chunk_size = 536870912 - for k in metadata: - # safetensor is in fp16, except for text moel - if (metadata[k]["data_offsets"][0] < text_model_offset): - metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0]/2) - metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1]/2) + for k in metadata: + # safetensor is in fp16, except for text moel + if metadata[k]["data_offsets"][0] < text_model_offset: + metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0] / 2) + metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1] / 2) - last_offset = 0 - part_end_offsets = [] + last_offset = 0 + part_end_offsets = [] - for k in metadata: - offset = metadata[k]['data_offsets'][0] + for k in metadata: + offset = metadata[k]["data_offsets"][0] - if offset == text_model_offset: - break + if offset == text_model_offset: + break - part_offset = offset - last_offset + part_offset = offset - last_offset - if (part_offset >= chunk_size): - part_end_offsets.append(8+json_len+offset) - last_offset = offset + if part_offset >= chunk_size: + part_end_offsets.append(8 + json_len + offset) + last_offset = offset - text_model_start = int(text_model_offset/2) - net_bytes = bytes(open(fn, 'rb').read()) - part_end_offsets.append(text_model_start+8+json_len) - cur_pos = 0 + text_model_start = int(text_model_offset / 2) + net_bytes = bytes(open(fn, "rb").read()) + part_end_offsets.append(text_model_start + 8 + json_len) + cur_pos = 0 - for i, end_pos in enumerate(part_end_offsets): - with open(f'./net_part{i}.safetensors', "wb+") as f: - f.write(net_bytes[cur_pos:end_pos]) - cur_pos = end_pos + for i, end_pos in enumerate(part_end_offsets): + with open(f"./net_part{i}.safetensors", "wb+") as f: + f.write(net_bytes[cur_pos:end_pos]) + cur_pos = end_pos - with open(f'./net_textmodel.safetensors', "wb+") as f: - f.write(net_bytes[text_model_start+8+json_len:]) + with open(f"./net_textmodel.safetensors", "wb+") as f: + f.write(net_bytes[text_model_start + 8 + json_len :]) + + return part_end_offsets - return part_end_offsets if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--remoteweights', action='store_true', help="Use safetensors from Huggingface, or from local") - args = parser.parse_args() - Device.DEFAULT = "WEBGPU" + parser = argparse.ArgumentParser( + description="Run Stable Diffusion", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--remoteweights", + action="store_true", + help="Use safetensors from Huggingface, or from local", + ) + args = parser.parse_args() + Device.DEFAULT = "WEBGPU" - Tensor.no_grad = True - model = StableDiffusion() + Tensor.no_grad = True + model = StableDiffusion() - # load in weights - load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False) + # load in weights + load_state_dict( + model, + torch_load( + fetch( + "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt", + "sd-v1-4.ckpt", + ) + )["state_dict"], + strict=False, + ) - class Step(NamedTuple): - name: str = "" - input: List[Tensor] = [] - forward: Any = None + class Step(NamedTuple): + name: str = "" + input: List[Tensor] = [] + forward: Any = None - sub_steps = [ - Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model), - Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model), - Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode) - ] + sub_steps = [ + Step( + name="textModel", + input=[Tensor.randn(1, 77)], + forward=model.cond_stage_model.transformer.text_model, + ), + Step( + name="diffusor", + input=[ + Tensor.randn(1, 77, 768), + Tensor.randn(1, 77, 768), + Tensor.randn(1, 4, 64, 64), + Tensor.rand(1), + Tensor.randn(1), + Tensor.randn(1), + Tensor.randn(1), + ], + forward=model, + ), + Step(name="decoder", input=[Tensor.randn(1, 4, 64, 64)], forward=model.decode), + ] - prg = "" + prg = "" - def compile_step(model, step: Step): - run, special_names = jit_model(step, *step.input) - functions, statements, bufs, _ = compile_net(run, special_names) - state = get_state_dict(model) - weights = {id(x.lazydata.realized): name for name, x in state.items()} - kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()]) - kernel_names = ', '.join([name for (name, _, _, _) in statements]) - kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ]) - bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()]) - gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value]) - input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"]) - return f"""\n var {step.name} = function() {{ + def compile_step(model, step: Step): + run, special_names = jit_model(step, *step.input) + functions, statements, bufs, _ = compile_net(run, special_names) + state = get_state_dict(model) + weights = {id(x.lazydata.realized): name for name, x in state.items()} + kernel_code = "\n\n".join( + [ + f"const {key} = `{code.replace(key, 'main')}`;" + for key, code in functions.items() + ] + ) + kernel_names = ", ".join([name for (name, _, _, _) in statements]) + kernel_calls = "\n ".join( + [ + f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" + for i, (_name, args, global_size, _local_size) in enumerate(statements) + ] + ) + bufs = "\n ".join( + [ + f"const {name} = " + + ( + f"createEmptyBuf(device, {size});" + if _key not in weights + else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))" + ) + + ";" + for name, (size, dtype, _key) in bufs.items() + ] + ) + gpu_write_bufs = "\n ".join( + [ + f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" + for i, (_, value) in enumerate(special_names.items()) + if "output" not in value + ] + ) + input_writer = "\n ".join( + [ + f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + + f"data{i});" + + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" + for i, (_, value) in enumerate(special_names.items()) + if value != "output0" + ] + ) + return f"""\n var {step.name} = function() {{ {kernel_code} @@ -142,23 +220,25 @@ if __name__ == "__main__": }} """ - for step in sub_steps: - print(f'Executing step={step.name}') - prg += compile_step(model, step) + for step in sub_steps: + print(f"Executing step={step.name}") + prg += compile_step(model, step) - if step.name == "diffusor": - if args.remoteweights: - base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main" - else: - state = get_state_dict(model) - safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors")) - convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors") - split_safetensor("./net_conv.safetensors") - os.remove("net.safetensors") - os.remove("net_conv.safetensors") - base_url = "." + if step.name == "diffusor": + if args.remoteweights: + base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main" + else: + state = get_state_dict(model) + safe_save( + state, os.path.join(os.path.dirname(__file__), "net.safetensors") + ) + convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors") + split_safetensor("./net_conv.safetensors") + os.remove("net.safetensors") + os.remove("net_conv.safetensors") + base_url = "." - prekernel = f""" + prekernel = f""" window.MODEL_BASE_URL= "{base_url}"; const getTensorMetadata = (safetensorBuffer) => {{ const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true)); @@ -227,5 +307,5 @@ if __name__ == "__main__": passEncoder.end(); }};""" - with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as text_file: - text_file.write(prekernel + prg) + with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as text_file: + text_file.write(prekernel + prg) diff --git a/examples/whisper.py b/examples/whisper.py index e31f904b2..558b030f4 100644 --- a/examples/whisper.py +++ b/examples/whisper.py @@ -15,338 +15,562 @@ from tinygrad.tensor import Tensor import itertools import librosa + class MultiHeadAttention: - def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self']=None, max_self_attn_cache_len=None): - self.n_head = n_head - self.query = nn.Linear(n_state, n_state) - self.key = nn.Linear(n_state, n_state, bias=False) - self.value = nn.Linear(n_state, n_state) - self.out = nn.Linear(n_state, n_state) + def __init__( + self, + n_state, + n_head, + kv_caching: Literal["cross", "self"] = None, + max_self_attn_cache_len=None, + ): + self.n_head = n_head + self.query = nn.Linear(n_state, n_state) + self.key = nn.Linear(n_state, n_state, bias=False) + self.value = nn.Linear(n_state, n_state) + self.out = nn.Linear(n_state, n_state) - self.kv_caching = kv_caching - self.max_self_attn_cache_len = max_self_attn_cache_len + self.kv_caching = kv_caching + self.max_self_attn_cache_len = max_self_attn_cache_len - def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None, len: Union[Variable,int]=None): - if self.kv_caching == 'cross': - if xa is not None: - k, v = self.key(xa), self.value(xa) - if not hasattr(self, 'cache_k'): - self.cache_k, self.cache_v = k, v + def __call__( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + len: Union[Variable, int] = None, + ): + if self.kv_caching == "cross": + if xa is not None: + k, v = self.key(xa), self.value(xa) + if not hasattr(self, "cache_k"): + self.cache_k, self.cache_v = k, v + else: + # see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994 + self.cache_k.assign(k + 1 - 1).realize() + self.cache_v.assign(v + 1 - 1).realize() + else: + k, v = self.cache_k, self.cache_v else: - # see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994 - self.cache_k.assign(k+1-1).realize() - self.cache_v.assign(v+1-1).realize() - else: - k, v = self.cache_k, self.cache_v - else: - k, v = self.key(x), self.value(x) - if self.kv_caching == 'self': - if not hasattr(self, 'cache_k'): - self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2]) - self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2]) - k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1) - v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1) - padding = self.max_self_attn_cache_len-len-x.shape[1] - self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize() - self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize() + k, v = self.key(x), self.value(x) + if self.kv_caching == "self": + if not hasattr(self, "cache_k"): + self.cache_k = Tensor.zeros( + x.shape[0], self.max_self_attn_cache_len, x.shape[2] + ) + self.cache_v = Tensor.zeros( + x.shape[0], self.max_self_attn_cache_len, x.shape[2] + ) + k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1) + v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1) + padding = self.max_self_attn_cache_len - len - x.shape[1] + self.cache_k.assign( + k.pad((None, (0, padding), None)).contiguous() + ).realize() + self.cache_v.assign( + v.pad((None, (0, padding), None)).contiguous() + ).realize() - q = self.query(x) - n_ctx = q.shape[1] - assert(q.shape[-1] == k.shape[-1] == v.shape[-1]) - head_dim = q.shape[-1] // self.n_head - q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3) - k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3) - v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3) - attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx,:n_ctx] if mask is not None else None) - wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2) - return self.out(wv) + q = self.query(x) + n_ctx = q.shape[1] + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + head_dim = q.shape[-1] // self.n_head + q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3) + k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3) + v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3) + attn = Tensor.scaled_dot_product_attention( + q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None + ) + wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2) + return self.out(wv) class ResidualAttentionBlock: - def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None): - self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len) - self.attn_ln = nn.LayerNorm(n_state) + def __init__( + self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None + ): + self.attn = MultiHeadAttention( + n_state, + n_head, + kv_caching="self" if is_decoder_block else None, + max_self_attn_cache_len=max_self_attn_cache_len, + ) + self.attn_ln = nn.LayerNorm(n_state) - self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None - self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None + self.cross_attn = ( + MultiHeadAttention(n_state, n_head, kv_caching="cross") + if is_decoder_block + else None + ) + self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None - self.mlp = [nn.Linear(n_state, n_state*4), Tensor.gelu, nn.Linear(n_state*4, n_state)] - self.mlp_ln = nn.LayerNorm(n_state) + self.mlp = [ + nn.Linear(n_state, n_state * 4), + Tensor.gelu, + nn.Linear(n_state * 4, n_state), + ] + self.mlp_ln = nn.LayerNorm(n_state) + + def __call__(self, x, xa=None, mask=None, len: Union[Variable, int] = None): + x = x + self.attn(self.attn_ln(x), mask=mask, len=len) + if self.cross_attn: + x = x + self.cross_attn(self.cross_attn_ln(x), xa) + x = x + self.mlp_ln(x).sequential(self.mlp) + return x.realize() - def __call__(self, x, xa=None, mask=None, len: Union[Variable, int]=None): - x = x + self.attn(self.attn_ln(x), mask=mask, len=len) - if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa) - x = x + self.mlp_ln(x).sequential(self.mlp) - return x.realize() class AudioEncoder: - def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_): - self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1) - self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1) - self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)] - self.ln_post = nn.LayerNorm(n_audio_state) - self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state) - self.encode = TinyJit(self.__call__) + def __init__( + self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_ + ): + self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d( + n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1 + ) + self.blocks = [ + ResidualAttentionBlock(n_audio_state, n_audio_head) + for _ in range(n_audio_layer) + ] + self.ln_post = nn.LayerNorm(n_audio_state) + self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state) + self.encode = TinyJit(self.__call__) + + def __call__(self, x): + x = self.conv1(x).gelu() + x = self.conv2(x).gelu() + x = x.permute(0, 2, 1) + x = x + self.positional_embedding[: x.shape[1]] + x = x.sequential(self.blocks) + x = self.ln_post(x) + return x.realize() - def __call__(self, x): - x = self.conv1(x).gelu() - x = self.conv2(x).gelu() - x = x.permute(0, 2, 1) - x = x + self.positional_embedding[:x.shape[1]] - x = x.sequential(self.blocks) - x = self.ln_post(x) - return x.realize() class TextDecoder: - def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_): - self.max_tokens_to_sample = n_text_ctx // 2 - self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 + 5 # roughly prompt + start toks + max_tokens_to_sample + def __init__( + self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_ + ): + self.max_tokens_to_sample = n_text_ctx // 2 + self.max_self_attn_cache_len = ( + self.max_tokens_to_sample * 2 + 5 + ) # roughly prompt + start toks + max_tokens_to_sample - self.token_embedding = nn.Embedding(n_vocab, n_text_state) - self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state) - self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)] - self.ln = nn.LayerNorm(n_text_state) - self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize() - self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks] - self.blocks_after_start_tok = [TinyJit(block.__call__) for block in self.blocks] - self.start_output_tok = TinyJit(self.output_tok) - self.after_start_output_tok = TinyJit(self.output_tok) + self.token_embedding = nn.Embedding(n_vocab, n_text_state) + self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state) + self.blocks = [ + ResidualAttentionBlock( + n_text_state, + n_text_head, + is_decoder_block=True, + max_self_attn_cache_len=self.max_self_attn_cache_len, + ) + for _ in range(n_text_layer) + ] + self.ln = nn.LayerNorm(n_text_state) + self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize() + self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks] + self.blocks_after_start_tok = [TinyJit(block.__call__) for block in self.blocks] + self.start_output_tok = TinyJit(self.output_tok) + self.after_start_output_tok = TinyJit(self.output_tok) - # if layernorm supported symbolic shapes, we wouldn't need this hacky 'streaming' param (which should be called something more descriptive like 'x_is_start_toks_only') - def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False): - seqlen = x.shape[-1] - x = self.token_embedding(x) + self.positional_embedding[pos:pos+seqlen] - if pos == 0: - for block in (self.blocks if streaming else self.blocks_start_tok): - x = block(x, xa=encoded_audio, mask=self.mask, len=0) # pass xa for cross attn kv caching - return self.output_tok(x) if streaming else self.start_output_tok(x) - else: - for block in self.blocks_after_start_tok: - len_v = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos) - x = block(x, mask=self.mask, len=len_v) - return self.after_start_output_tok(x) + # if layernorm supported symbolic shapes, we wouldn't need this hacky 'streaming' param (which should be called something more descriptive like 'x_is_start_toks_only') + def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False): + seqlen = x.shape[-1] + x = self.token_embedding(x) + self.positional_embedding[pos : pos + seqlen] + if pos == 0: + for block in self.blocks if streaming else self.blocks_start_tok: + x = block( + x, xa=encoded_audio, mask=self.mask, len=0 + ) # pass xa for cross attn kv caching + return self.output_tok(x) if streaming else self.start_output_tok(x) + else: + for block in self.blocks_after_start_tok: + len_v = Variable( + "self_attn_cache_len", 1, self.max_self_attn_cache_len + ).bind(pos) + x = block(x, mask=self.mask, len=len_v) + return self.after_start_output_tok(x) + + def output_tok(self, x): + return (self.ln(x) @ self.token_embedding.weight.T).realize() - def output_tok(self, x): - return (self.ln(x) @ self.token_embedding.weight.T).realize() class Whisper: - def __init__(self, dims, batch_size=1): - self.encoder = AudioEncoder(**dims) - self.decoder = TextDecoder(**dims) - self.is_multilingual = dims["n_vocab"] == 51865 - self.batch_size = batch_size + def __init__(self, dims, batch_size=1): + self.encoder = AudioEncoder(**dims) + self.decoder = TextDecoder(**dims) + self.is_multilingual = dims["n_vocab"] == 51865 + self.batch_size = batch_size RATE = 16000 -SEGMENT_SECONDS=30 -SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS # 480000 +SEGMENT_SECONDS = 30 +SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS # 480000 N_FFT = 400 HOP_LENGTH = 160 N_MELS = 80 -FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000 +FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000 -def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) -> np.ndarray: - """ - :param waveforms: A list of possibly variable length 16000Hz audio samples - :param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio. - Used to prevent JIT mismatch errors since the encoder does not accept symbolic shapes - :param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass - :return: mel spectrogram of the given waveforms - """ - def pad_or_trim(arr, target_len): - curr_len = len(arr) - if curr_len == target_len: - return arr - elif curr_len < target_len: - return np.pad(arr, (0, target_len - curr_len), 'constant') - else: - return arr[:target_len] - max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms) - if (r := max_len % SAMPLES_PER_SEGMENT) > 0: max_len += SAMPLES_PER_SEGMENT - r - waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms))) - assert waveforms.shape[0] <= batch_size - if waveforms.shape[0] < batch_size: - # we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes - waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0))) +def prep_audio( + waveforms: List[np.ndarray], batch_size: int, truncate=False +) -> np.ndarray: + """ + :param waveforms: A list of possibly variable length 16000Hz audio samples + :param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio. + Used to prevent JIT mismatch errors since the encoder does not accept symbolic shapes + :param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass + :return: mel spectrogram of the given waveforms + """ - stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle) - magnitudes = np.absolute(stft[..., :-1]) ** 2 - mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes + def pad_or_trim(arr, target_len): + curr_len = len(arr) + if curr_len == target_len: + return arr + elif curr_len < target_len: + return np.pad(arr, (0, target_len - curr_len), "constant") + else: + return arr[:target_len] - log_spec = np.log10(np.clip(mel_spec, 1e-10, None)) - log_spec = np.maximum(log_spec, log_spec.max() - 8.0) - log_spec = (log_spec + 4.0) / 4.0 + max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms) + if (r := max_len % SAMPLES_PER_SEGMENT) > 0: + max_len += SAMPLES_PER_SEGMENT - r + waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms))) + assert waveforms.shape[0] <= batch_size + if waveforms.shape[0] < batch_size: + # we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes + waveforms = np.pad( + waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)) + ) + + stft = librosa.stft( + waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window="hann", dtype=np.csingle + ) + magnitudes = np.absolute(stft[..., :-1]) ** 2 + mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes + + log_spec = np.log10(np.clip(mel_spec, 1e-10, None)) + log_spec = np.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + + return log_spec - return log_spec LANGUAGES = { - "en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish", - "pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese", - "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian", - "th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu", - "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian", - "br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili", - "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian", - "be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole", - "ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy", - "as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese", + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", } + def get_encoding(encoding_name): - with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f: - ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)} - n_vocab = len(ranks) - specials = [ - "<|endoftext|>", - "<|startoftranscript|>", - *[f"<|{lang}|>" for lang in LANGUAGES.keys()], - "<|translate|>", - "<|transcribe|>", - "<|startoflm|>", - "<|startofprev|>", - "<|nospeech|>", - "<|notimestamps|>", - *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], - ] - special_tokens = dict(zip(specials, itertools.count(n_vocab))) - n_vocab += len(specials) - import tiktoken - return tiktoken.Encoding( - name=encoding_name, - explicit_n_vocab=n_vocab, - pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", - mergeable_ranks=ranks, - special_tokens=special_tokens) + with fetch( + f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken" + ).open() as f: + ranks = { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in f if line) + } + n_vocab = len(ranks) + specials = [ + "<|endoftext|>", + "<|startoftranscript|>", + *[f"<|{lang}|>" for lang in LANGUAGES.keys()], + "<|translate|>", + "<|transcribe|>", + "<|startoflm|>", + "<|startofprev|>", + "<|nospeech|>", + "<|notimestamps|>", + *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], + ] + special_tokens = dict(zip(specials, itertools.count(n_vocab))) + n_vocab += len(specials) + import tiktoken + + return tiktoken.Encoding( + name=encoding_name, + explicit_n_vocab=n_vocab, + pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", + mergeable_ranks=ranks, + special_tokens=special_tokens, + ) + MODEL_URLS = { - "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", - "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", - "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", - "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", - "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", - "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", - "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", - "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", - "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", - "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", - "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", + "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", + "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", + "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", + "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", + "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", + "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", + "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", + "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", + "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", + "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", + "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", } -def init_whisper(model_name="tiny.en", batch_size=1): - assert MODEL_URLS[model_name] is not None - filename = fetch(MODEL_URLS[model_name]) - state = torch_load(filename) - model = Whisper(state['dims'], batch_size) - load_state_dict(model, state['model_state_dict'], strict=False) - enc = get_encoding("multilingual" if model.is_multilingual else "gpt2") - return model, enc + +def init_whisper(model_name="tiny.en", batch_size=1): + assert MODEL_URLS[model_name] is not None + + filename = fetch(MODEL_URLS[model_name]) + state = torch_load(filename) + model = Whisper(state["dims"], batch_size) + load_state_dict(model, state["model_state_dict"], strict=False) + enc = get_encoding("multilingual" if model.is_multilingual else "gpt2") + return model, enc + def load_file_waveform(filename): - waveform, _ = librosa.load(filename, sr=RATE) - return waveform + waveform, _ = librosa.load(filename, sr=RATE) + return waveform + def transcribe_file(model, enc, filename): - return transcribe_waveform(model, enc, [load_file_waveform(filename)]) + return transcribe_waveform(model, enc, [load_file_waveform(filename)]) + def transcribe_waveform(model, enc, waveforms, truncate=False): - """ - Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples - Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided - """ - N_audio = len(waveforms) - log_spec = prep_audio(waveforms, model.batch_size, truncate) + """ + Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples + Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided + """ + N_audio = len(waveforms) + log_spec = prep_audio(waveforms, model.batch_size, truncate) - if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1: - # we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch - # if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent - raise Exception("Multi-segment transcription not supported with batch audio input") + if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1: + # we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch + # if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent + raise Exception( + "Multi-segment transcription not supported with batch audio input" + ) - start_tokens = [enc._special_tokens["<|startoftranscript|>"]] - if model.is_multilingual: - # TODO detect language - language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en") - start_tokens.append(language_token) - start_tokens.append(enc._special_tokens["<|transcribe|>"]) - start_tokens.append(enc._special_tokens["<|notimestamps|>"]) - transcription_start_index = len(start_tokens) - eot = enc._special_tokens["<|endoftext|>"] - transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0] + start_tokens = [enc._special_tokens["<|startoftranscript|>"]] + if model.is_multilingual: + # TODO detect language + language_token = ( + enc._special_tokens["<|startoftranscript|>"] + + 1 + + tuple(LANGUAGES.keys()).index("en") + ) + start_tokens.append(language_token) + start_tokens.append(enc._special_tokens["<|transcribe|>"]) + start_tokens.append(enc._special_tokens["<|notimestamps|>"]) + transcription_start_index = len(start_tokens) + eot = enc._special_tokens["<|endoftext|>"] + transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0] - for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT): - encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT])) - pos = 0 - curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1)) - if curr_frame > 0: - # pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 - prompt = np.concatenate(( - [enc._special_tokens["<|startofprev|>"]], - transcription_tokens[0][-model.decoder.max_tokens_to_sample+1:], - start_tokens)) - curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1)) - transcription_start_index = len(curr_segment_tokens[0]) + for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT): + encoded_audio = model.encoder.encode( + Tensor(log_spec[:, :, curr_frame : curr_frame + FRAMES_PER_SEGMENT]) + ) + pos = 0 + curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1)) + if curr_frame > 0: + # pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 + prompt = np.concatenate( + ( + [enc._special_tokens["<|startofprev|>"]], + transcription_tokens[0][-model.decoder.max_tokens_to_sample + 1 :], + start_tokens, + ) + ) + curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1)) + transcription_start_index = len(curr_segment_tokens[0]) - for i in range(model.decoder.max_tokens_to_sample): - out = model.decoder(Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]), pos, encoded_audio, streaming=curr_frame > 0) - next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32) - next_tokens[curr_segment_tokens[:, -1] == eot] = eot - curr_segment_tokens = np.concatenate((curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1) - pos = curr_segment_tokens.shape[-1] - 1 - if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens))) - if (curr_segment_tokens[:, -1] == eot).all(): - break + for i in range(model.decoder.max_tokens_to_sample): + out = model.decoder( + Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]), + pos, + encoded_audio, + streaming=curr_frame > 0, + ) + next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32) + next_tokens[curr_segment_tokens[:, -1] == eot] = eot + curr_segment_tokens = np.concatenate( + (curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1 + ) + pos = curr_segment_tokens.shape[-1] - 1 + if DEBUG >= 1: + print( + i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens)) + ) + if (curr_segment_tokens[:, -1] == eot).all(): + break - for i, t in enumerate(curr_segment_tokens): - eot_index = np.where(t == eot)[0] - eot_index = None if len(eot_index) == 0 else eot_index[0] - transcription_tokens[i] = np.concatenate((transcription_tokens[i], t[transcription_start_index:eot_index])) + for i, t in enumerate(curr_segment_tokens): + eot_index = np.where(t == eot)[0] + eot_index = None if len(eot_index) == 0 else eot_index[0] + transcription_tokens[i] = np.concatenate( + (transcription_tokens[i], t[transcription_start_index:eot_index]) + ) + + transcriptions = list( + map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens) + ) + return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0] - transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens)) - return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0] CHUNK = 1600 RECORD_SECONDS = 10 + def listener(q): - import pyaudio - p = pyaudio.PyAudio() - stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK) - print("listening") - for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)): - data = stream.read(CHUNK) - waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3) - q.put(waveform) - print("done listening") + import pyaudio + + p = pyaudio.PyAudio() + stream = p.open( + format=pyaudio.paInt16, + channels=1, + rate=RATE, + input=True, + frames_per_buffer=CHUNK, + ) + print("listening") + for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)): + data = stream.read(CHUNK) + waveform = (np.frombuffer(data, np.int16) / 32768).astype(np.float32) * 3 + q.put(waveform) + print("done listening") + if __name__ == "__main__": - model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1) + model, enc = init_whisper( + "small.en" if getenv("SMALL") else "tiny.en", batch_size=1 + ) - if len(sys.argv) > 1: - print(transcribe_file(model, enc, sys.argv[1])) - else: - # online - q = multiprocessing.Queue() - p = multiprocessing.Process(target=listener, args=(q,)) - p.daemon = True - p.start() + if len(sys.argv) > 1: + print(transcribe_file(model, enc, sys.argv[1])) + else: + # online + q = multiprocessing.Queue() + p = multiprocessing.Process(target=listener, args=(q,)) + p.daemon = True + p.start() - lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]] - total = None - did_read = False - for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)): - while not q.empty() or total is None: - waveform = q.get() - if total is None: total = waveform - else: total = np.concatenate([total, waveform]) - did_read = True - if did_read: - log_spec = prep_audio(total.reshape(1, -1), model.batch_size, truncate=True) - encoded_audio = model.encoder.encode(Tensor(log_spec)) - # pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 - out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize() - idx = int(out[0,-1].argmax().numpy().item()) - lst.append(idx) - dec = enc.decode(lst) - print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT - if dec.endswith("<|endoftext|>"): - lst.pop() + lst = [ + enc._special_tokens["<|startoftranscript|>"], + enc._special_tokens["<|notimestamps|>"], + ] + total = None + did_read = False + for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)): + while not q.empty() or total is None: + waveform = q.get() + if total is None: + total = waveform + else: + total = np.concatenate([total, waveform]) + did_read = True + if did_read: + log_spec = prep_audio( + total.reshape(1, -1), model.batch_size, truncate=True + ) + encoded_audio = model.encoder.encode(Tensor(log_spec)) + # pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 + out = model.decoder( + Tensor([lst]), 0, encoded_audio, streaming=True + ).realize() + idx = int(out[0, -1].argmax().numpy().item()) + lst.append(idx) + dec = enc.decode(lst) + print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT + if dec.endswith("<|endoftext|>"): + lst.pop() diff --git a/examples/yolov3.py b/examples/yolov3.py index ce137e6c8..9d4cced19 100755 --- a/examples/yolov3.py +++ b/examples/yolov3.py @@ -10,397 +10,462 @@ from tinygrad.tensor import Tensor from tinygrad.nn import BatchNorm2d, Conv2d from tinygrad.helpers import fetch + def show_labels(prediction, confidence=0.5, num_classes=80): - coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_bytes() - coco_labels = coco_labels.decode('utf-8').split('\n') - prediction = prediction.detach().numpy() - conf_mask = (prediction[:,:,4] > confidence) - prediction *= np.expand_dims(conf_mask, 2) - labels = [] - # Iterate over batches - for img_pred in prediction: - max_conf = np.amax(img_pred[:,5:5+num_classes], axis=1) - max_conf_score = np.argmax(img_pred[:,5:5+num_classes], axis=1) - max_conf_score = np.expand_dims(max_conf_score, axis=1) - max_conf = np.expand_dims(max_conf, axis=1) - seq = (img_pred[:,:5], max_conf, max_conf_score) - image_pred = np.concatenate(seq, axis=1) - non_zero_ind = np.nonzero(image_pred[:,4])[0] - assert all(image_pred[non_zero_ind,0] > 0) - image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7)) - classes, indexes = np.unique(image_pred_[:, -1], return_index=True) - for index, coco_class in enumerate(classes): - label, probability = coco_labels[int(coco_class)], image_pred_[indexes[index]][4] * 100 - print(f"Detected {label} {probability:.2f}") - labels.append(label) - return labels + coco_labels = fetch( + "https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names" + ).read_bytes() + coco_labels = coco_labels.decode("utf-8").split("\n") + prediction = prediction.detach().numpy() + conf_mask = prediction[:, :, 4] > confidence + prediction *= np.expand_dims(conf_mask, 2) + labels = [] + # Iterate over batches + for img_pred in prediction: + max_conf = np.amax(img_pred[:, 5 : 5 + num_classes], axis=1) + max_conf_score = np.argmax(img_pred[:, 5 : 5 + num_classes], axis=1) + max_conf_score = np.expand_dims(max_conf_score, axis=1) + max_conf = np.expand_dims(max_conf, axis=1) + seq = (img_pred[:, :5], max_conf, max_conf_score) + image_pred = np.concatenate(seq, axis=1) + non_zero_ind = np.nonzero(image_pred[:, 4])[0] + assert all(image_pred[non_zero_ind, 0] > 0) + image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind), :], (-1, 7)) + classes, indexes = np.unique(image_pred_[:, -1], return_index=True) + for index, coco_class in enumerate(classes): + label, probability = ( + coco_labels[int(coco_class)], + image_pred_[indexes[index]][4] * 100, + ) + print(f"Detected {label} {probability:.2f}") + labels.append(label) + return labels + def add_boxes(img, prediction): - if isinstance(prediction, int): # no predictions + if isinstance(prediction, int): # no predictions + return img + coco_labels = fetch( + "https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names" + ) + coco_labels = coco_labels.decode("utf-8").split("\n") + height, width = img.shape[0:2] + scale_factor = 608 / width + prediction[:, [1, 3]] -= (608 - scale_factor * width) / 2 + prediction[:, [2, 4]] -= (608 - scale_factor * height) / 2 + for pred in prediction: + corner1 = tuple(pred[1:3].astype(int)) + corner2 = tuple(pred[3:5].astype(int)) + w = corner2[0] - corner1[0] + h = corner2[1] - corner1[1] + corner2 = (corner2[0] + w, corner2[1] + h) + label = coco_labels[int(pred[-1])] + img = cv2.rectangle(img, corner1, corner2, (255, 0, 0), 2) + t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0] + c2 = corner1[0] + t_size[0] + 3, corner1[1] + t_size[1] + 4 + img = cv2.rectangle(img, corner1, c2, (255, 0, 0), -1) + img = cv2.putText( + img, + label, + (corner1[0], corner1[1] + t_size[1] + 4), + cv2.FONT_HERSHEY_PLAIN, + 1, + [225, 255, 255], + 1, + ) return img - coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names') - coco_labels = coco_labels.decode('utf-8').split('\n') - height, width = img.shape[0:2] - scale_factor = 608 / width - prediction[:,[1,3]] -= (608 - scale_factor * width) / 2 - prediction[:,[2,4]] -= (608 - scale_factor * height) / 2 - for pred in prediction: - corner1 = tuple(pred[1:3].astype(int)) - corner2 = tuple(pred[3:5].astype(int)) - w = corner2[0] - corner1[0] - h = corner2[1] - corner1[1] - corner2 = (corner2[0] + w, corner2[1] + h) - label = coco_labels[int(pred[-1])] - img = cv2.rectangle(img, corner1, corner2, (255, 0, 0), 2) - t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1 , 1)[0] - c2 = corner1[0] + t_size[0] + 3, corner1[1] + t_size[1] + 4 - img = cv2.rectangle(img, corner1, c2, (255, 0, 0), -1) - img = cv2.putText(img, label, (corner1[0], corner1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,255,255], 1) - return img + def bbox_iou(box1, box2): - """ - Returns the IoU of two bounding boxes - IoU: IoU = Area Of Overlap / Area of Union -> How close the predicted bounding box is - to the ground truth bounding box. Higher IoU = Better accuracy - In training, used to track accuracy. with inference, using to remove duplicate bounding boxes - """ - # Get the coordinates of bounding boxes - b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,0], box1[:,1], box1[:,2], box1[:,3] - b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3] - # get the coordinates of the intersection rectangle - inter_rect_x1 = np.maximum(b1_x1, b2_x1) - inter_rect_y1 = np.maximum(b1_y1, b2_y1) - inter_rect_x2 = np.maximum(b1_x2, b2_x2) - inter_rect_y2 = np.maximum(b1_y2, b2_y2) - #Intersection area - inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, 99999) * np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, 99999) - #Union Area - b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1) - b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1) - iou = inter_area / (b1_area + b2_area - inter_area) - return iou + """ + Returns the IoU of two bounding boxes + IoU: IoU = Area Of Overlap / Area of Union -> How close the predicted bounding box is + to the ground truth bounding box. Higher IoU = Better accuracy + In training, used to track accuracy. with inference, using to remove duplicate bounding boxes + """ + # Get the coordinates of bounding boxes + b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3] + # get the coordinates of the intersection rectangle + inter_rect_x1 = np.maximum(b1_x1, b2_x1) + inter_rect_y1 = np.maximum(b1_y1, b2_y1) + inter_rect_x2 = np.maximum(b1_x2, b2_x2) + inter_rect_y2 = np.maximum(b1_y2, b2_y2) + # Intersection area + inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, 99999) * np.clip( + inter_rect_y2 - inter_rect_y1 + 1, 0, 99999 + ) + # Union Area + b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1) + b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1) + iou = inter_area / (b1_area + b2_area - inter_area) + return iou + def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4): - prediction = prediction.detach().numpy() - conf_mask = (prediction[:,:,4] > confidence) - conf_mask = np.expand_dims(conf_mask, 2) - prediction = prediction * conf_mask - # Non max suppression - box_corner = prediction - box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2) - box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2) - box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2) - box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2) - prediction[:,:,:4] = box_corner[:,:,:4] - write = False - # Process img - img_pred = prediction[0] - max_conf = np.amax(img_pred[:,5:5+num_classes], axis=1) - max_conf_score = np.argmax(img_pred[:,5:5+num_classes], axis=1) - max_conf_score = np.expand_dims(max_conf_score, axis=1) - max_conf = np.expand_dims(max_conf, axis=1) - seq = (img_pred[:,:5], max_conf, max_conf_score) - image_pred = np.concatenate(seq, axis=1) - non_zero_ind = np.nonzero(image_pred[:,4])[0] - assert all(image_pred[non_zero_ind,0] > 0) - image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7)) - if image_pred_.shape[0] == 0: - print("No detections found!") - return 0 - for cls in np.unique(image_pred_[:, -1]): - # perform NMS, get the detections with one particular class - cls_mask = image_pred_*np.expand_dims(image_pred_[:, -1] == cls, axis=1) - class_mask_ind = np.squeeze(np.nonzero(cls_mask[:,-2])) - # class_mask_ind = np.nonzero() - image_pred_class = np.reshape(image_pred_[class_mask_ind], (-1, 7)) - # sort the detections such that the entry with the maximum objectness - # confidence is at the top - conf_sort_index = np.argsort(image_pred_class[:,4]) - image_pred_class = image_pred_class[conf_sort_index] - for i in range(image_pred_class.shape[0]): - # Get the IOUs of all boxes that come after the one we are looking at in the loop - try: - ious = bbox_iou(np.expand_dims(image_pred_class[i], axis=0), image_pred_class[i+1:]) - except: - break - # Zero out all the detections that have IoU > threshold - iou_mask = np.expand_dims((ious < nms_conf), axis=1) - image_pred_class[i+1:] *= iou_mask - # Remove the non-zero entries - non_zero_ind = np.squeeze(np.nonzero(image_pred_class[:,4])) - image_pred_class = np.reshape(image_pred_class[non_zero_ind], (-1, 7)) - batch_ind = np.array([[0]]) - seq = (batch_ind, image_pred_class) - if not write: - output, write = np.concatenate(seq, axis=1), True - else: - out = np.concatenate(seq, axis=1) - output = np.concatenate((output,out)) - return output + prediction = prediction.detach().numpy() + conf_mask = prediction[:, :, 4] > confidence + conf_mask = np.expand_dims(conf_mask, 2) + prediction = prediction * conf_mask + # Non max suppression + box_corner = prediction + box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 + box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 + box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 + box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 + prediction[:, :, :4] = box_corner[:, :, :4] + write = False + # Process img + img_pred = prediction[0] + max_conf = np.amax(img_pred[:, 5 : 5 + num_classes], axis=1) + max_conf_score = np.argmax(img_pred[:, 5 : 5 + num_classes], axis=1) + max_conf_score = np.expand_dims(max_conf_score, axis=1) + max_conf = np.expand_dims(max_conf, axis=1) + seq = (img_pred[:, :5], max_conf, max_conf_score) + image_pred = np.concatenate(seq, axis=1) + non_zero_ind = np.nonzero(image_pred[:, 4])[0] + assert all(image_pred[non_zero_ind, 0] > 0) + image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind), :], (-1, 7)) + if image_pred_.shape[0] == 0: + print("No detections found!") + return 0 + for cls in np.unique(image_pred_[:, -1]): + # perform NMS, get the detections with one particular class + cls_mask = image_pred_ * np.expand_dims(image_pred_[:, -1] == cls, axis=1) + class_mask_ind = np.squeeze(np.nonzero(cls_mask[:, -2])) + # class_mask_ind = np.nonzero() + image_pred_class = np.reshape(image_pred_[class_mask_ind], (-1, 7)) + # sort the detections such that the entry with the maximum objectness + # confidence is at the top + conf_sort_index = np.argsort(image_pred_class[:, 4]) + image_pred_class = image_pred_class[conf_sort_index] + for i in range(image_pred_class.shape[0]): + # Get the IOUs of all boxes that come after the one we are looking at in the loop + try: + ious = bbox_iou( + np.expand_dims(image_pred_class[i], axis=0), + image_pred_class[i + 1 :], + ) + except: + break + # Zero out all the detections that have IoU > threshold + iou_mask = np.expand_dims((ious < nms_conf), axis=1) + image_pred_class[i + 1 :] *= iou_mask + # Remove the non-zero entries + non_zero_ind = np.squeeze(np.nonzero(image_pred_class[:, 4])) + image_pred_class = np.reshape(image_pred_class[non_zero_ind], (-1, 7)) + batch_ind = np.array([[0]]) + seq = (batch_ind, image_pred_class) + if not write: + output, write = np.concatenate(seq, axis=1), True + else: + out = np.concatenate(seq, axis=1) + output = np.concatenate((output, out)) + return output + def infer(model, img): - img = np.array(Image.fromarray(img).resize((608, 608))) - img = img[:,:,::-1].transpose((2,0,1)) - img = img[np.newaxis,:,:,:]/255.0 - prediction = model.forward(Tensor(img.astype(np.float32))) - return prediction + img = np.array(Image.fromarray(img).resize((608, 608))) + img = img[:, :, ::-1].transpose((2, 0, 1)) + img = img[np.newaxis, :, :, :] / 255.0 + prediction = model.forward(Tensor(img.astype(np.float32))) + return prediction def parse_cfg(cfg): - # Return a list of blocks - lines = cfg.decode("utf-8").split('\n') - lines = [x for x in lines if len(x) > 0] - lines = [x for x in lines if x[0] != '#'] - lines = [x.rstrip().lstrip() for x in lines] - block, blocks = {}, [] - for line in lines: - if line[0] == "[": - if len(block) != 0: - blocks.append(block) - block = {} - block["type"] = line[1:-1].rstrip() - else: - key,value = line.split("=") - block[key.rstrip()] = value.lstrip() - blocks.append(block) - return blocks + # Return a list of blocks + lines = cfg.decode("utf-8").split("\n") + lines = [x for x in lines if len(x) > 0] + lines = [x for x in lines if x[0] != "#"] + lines = [x.rstrip().lstrip() for x in lines] + block, blocks = {}, [] + for line in lines: + if line[0] == "[": + if len(block) != 0: + blocks.append(block) + block = {} + block["type"] = line[1:-1].rstrip() + else: + key, value = line.split("=") + block[key.rstrip()] = value.lstrip() + blocks.append(block) + return blocks + # TODO: Speed up this function, avoid copying stuff from GPU to CPU def predict_transform(prediction, inp_dim, anchors, num_classes): - batch_size = prediction.shape[0] - stride = inp_dim // prediction.shape[2] - grid_size = inp_dim // stride - bbox_attrs = 5 + num_classes - num_anchors = len(anchors) - prediction = prediction.reshape(shape=(batch_size, bbox_attrs*num_anchors, grid_size*grid_size)) - prediction = prediction.transpose(1, 2) - prediction = prediction.reshape(shape=(batch_size, grid_size*grid_size*num_anchors, bbox_attrs)) - prediction_cpu = prediction.numpy() - for i in (0, 1, 4): - prediction_cpu[:,:,i] = 1 / (1 + np.exp(-prediction_cpu[:,:,i])) - # Add the center offsets - grid = np.arange(grid_size) - a, b = np.meshgrid(grid, grid) - x_offset = a.reshape((-1, 1)) - y_offset = b.reshape((-1, 1)) - x_y_offset = np.concatenate((x_offset, y_offset), 1) - x_y_offset = np.tile(x_y_offset, (1, num_anchors)) - x_y_offset = x_y_offset.reshape((-1,2)) - x_y_offset = np.expand_dims(x_y_offset, 0) - anchors = [(a[0]/stride, a[1]/stride) for a in anchors] - anchors = np.tile(anchors, (grid_size*grid_size, 1)) - anchors = np.expand_dims(anchors, 0) - prediction_cpu[:,:,:2] += x_y_offset - prediction_cpu[:,:,2:4] = np.exp(prediction_cpu[:,:,2:4])*anchors - prediction_cpu[:,:,5:5+num_classes] = 1 / (1 + np.exp(-prediction_cpu[:,:,5:5+num_classes])) - prediction_cpu[:,:,:4] *= stride - return Tensor(prediction_cpu) + batch_size = prediction.shape[0] + stride = inp_dim // prediction.shape[2] + grid_size = inp_dim // stride + bbox_attrs = 5 + num_classes + num_anchors = len(anchors) + prediction = prediction.reshape( + shape=(batch_size, bbox_attrs * num_anchors, grid_size * grid_size) + ) + prediction = prediction.transpose(1, 2) + prediction = prediction.reshape( + shape=(batch_size, grid_size * grid_size * num_anchors, bbox_attrs) + ) + prediction_cpu = prediction.numpy() + for i in (0, 1, 4): + prediction_cpu[:, :, i] = 1 / (1 + np.exp(-prediction_cpu[:, :, i])) + # Add the center offsets + grid = np.arange(grid_size) + a, b = np.meshgrid(grid, grid) + x_offset = a.reshape((-1, 1)) + y_offset = b.reshape((-1, 1)) + x_y_offset = np.concatenate((x_offset, y_offset), 1) + x_y_offset = np.tile(x_y_offset, (1, num_anchors)) + x_y_offset = x_y_offset.reshape((-1, 2)) + x_y_offset = np.expand_dims(x_y_offset, 0) + anchors = [(a[0] / stride, a[1] / stride) for a in anchors] + anchors = np.tile(anchors, (grid_size * grid_size, 1)) + anchors = np.expand_dims(anchors, 0) + prediction_cpu[:, :, :2] += x_y_offset + prediction_cpu[:, :, 2:4] = np.exp(prediction_cpu[:, :, 2:4]) * anchors + prediction_cpu[:, :, 5 : 5 + num_classes] = 1 / ( + 1 + np.exp(-prediction_cpu[:, :, 5 : 5 + num_classes]) + ) + prediction_cpu[:, :, :4] *= stride + return Tensor(prediction_cpu) class Darknet: - def __init__(self, cfg): - self.blocks = parse_cfg(cfg) - self.net_info, self.module_list = self.create_modules(self.blocks) - print("Modules length:", len(self.module_list)) + def __init__(self, cfg): + self.blocks = parse_cfg(cfg) + self.net_info, self.module_list = self.create_modules(self.blocks) + print("Modules length:", len(self.module_list)) - def create_modules(self, blocks): - net_info = blocks[0] # Info about model hyperparameters - prev_filters, filters = 3, None - output_filters, module_list = [], [] - ## module - for index, x in enumerate(blocks[1:]): - module_type = x["type"] - module = [] - if module_type == "convolutional": - try: - batch_normalize, bias = int(x["batch_normalize"]), False - except: - batch_normalize, bias = 0, True - # layer - activation = x["activation"] - filters = int(x["filters"]) - padding = int(x["pad"]) - pad = (int(x["size"]) - 1) // 2 if padding else 0 - module.append(Conv2d(prev_filters, filters, int(x["size"]), int(x["stride"]), pad, bias=bias)) - # BatchNorm2d - if batch_normalize: - module.append(BatchNorm2d(filters, eps=1e-05, track_running_stats=True)) - # LeakyReLU activation - if activation == "leaky": - module.append(lambda x: x.leakyrelu(0.1)) - elif module_type == "maxpool": - size, stride = int(x["size"]), int(x["stride"]) - module.append(lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride)) - elif module_type == "upsample": - module.append(lambda x: Tensor(x.numpy().repeat(2, axis=-2).repeat(2, axis=-1))) - elif module_type == "route": - x["layers"] = x["layers"].split(",") - # Start of route - start = int(x["layers"][0]) - # End if it exists - try: - end = int(x["layers"][1]) - except: - end = 0 - if start > 0: start -= index - if end > 0: end -= index - module.append(lambda x: x) - if end < 0: - filters = output_filters[index + start] + output_filters[index + end] - else: - filters = output_filters[index + start] - # Shortcut corresponds to skip connection - elif module_type == "shortcut": - module.append(lambda x: x) - elif module_type == "yolo": - mask = list(map(int, x["mask"].split(","))) - anchors = [int(a) for a in x["anchors"].split(",")] - anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors), 2)] - module.append([anchors[i] for i in mask]) - # Append to module_list - module_list.append(module) - if filters is not None: - prev_filters = filters - output_filters.append(filters) - return (net_info, module_list) + def create_modules(self, blocks): + net_info = blocks[0] # Info about model hyperparameters + prev_filters, filters = 3, None + output_filters, module_list = [], [] + ## module + for index, x in enumerate(blocks[1:]): + module_type = x["type"] + module = [] + if module_type == "convolutional": + try: + batch_normalize, bias = int(x["batch_normalize"]), False + except: + batch_normalize, bias = 0, True + # layer + activation = x["activation"] + filters = int(x["filters"]) + padding = int(x["pad"]) + pad = (int(x["size"]) - 1) // 2 if padding else 0 + module.append( + Conv2d( + prev_filters, + filters, + int(x["size"]), + int(x["stride"]), + pad, + bias=bias, + ) + ) + # BatchNorm2d + if batch_normalize: + module.append( + BatchNorm2d(filters, eps=1e-05, track_running_stats=True) + ) + # LeakyReLU activation + if activation == "leaky": + module.append(lambda x: x.leakyrelu(0.1)) + elif module_type == "maxpool": + size, stride = int(x["size"]), int(x["stride"]) + module.append( + lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride) + ) + elif module_type == "upsample": + module.append( + lambda x: Tensor(x.numpy().repeat(2, axis=-2).repeat(2, axis=-1)) + ) + elif module_type == "route": + x["layers"] = x["layers"].split(",") + # Start of route + start = int(x["layers"][0]) + # End if it exists + try: + end = int(x["layers"][1]) + except: + end = 0 + if start > 0: + start -= index + if end > 0: + end -= index + module.append(lambda x: x) + if end < 0: + filters = ( + output_filters[index + start] + output_filters[index + end] + ) + else: + filters = output_filters[index + start] + # Shortcut corresponds to skip connection + elif module_type == "shortcut": + module.append(lambda x: x) + elif module_type == "yolo": + mask = list(map(int, x["mask"].split(","))) + anchors = [int(a) for a in x["anchors"].split(",")] + anchors = [ + (anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2) + ] + module.append([anchors[i] for i in mask]) + # Append to module_list + module_list.append(module) + if filters is not None: + prev_filters = filters + output_filters.append(filters) + return (net_info, module_list) - def dump_weights(self): - for i in range(len(self.module_list)): - module_type = self.blocks[i + 1]["type"] - if module_type == "convolutional": - print(self.blocks[i + 1]["type"], "weights", i) - model = self.module_list[i] - conv = model[0] - print(conv.weight.numpy()[0][0][0]) - if conv.bias is not None: - print("biases") - print(conv.bias.shape) - print(conv.bias.numpy()[0][0:5]) - else: - print("None biases for layer", i) + def dump_weights(self): + for i in range(len(self.module_list)): + module_type = self.blocks[i + 1]["type"] + if module_type == "convolutional": + print(self.blocks[i + 1]["type"], "weights", i) + model = self.module_list[i] + conv = model[0] + print(conv.weight.numpy()[0][0][0]) + if conv.bias is not None: + print("biases") + print(conv.bias.shape) + print(conv.bias.numpy()[0][0:5]) + else: + print("None biases for layer", i) - def load_weights(self, url): - weights = np.frombuffer(fetch(url), dtype=np.float32)[5:] - ptr = 0 - for i in range(len(self.module_list)): - module_type = self.blocks[i + 1]["type"] - if module_type == "convolutional": - model = self.module_list[i] - try: # we have batchnorm, load conv weights without biases, and batchnorm values - batch_normalize = int(self.blocks[i+1]["batch_normalize"]) - except: # no batchnorm, load conv weights + biases - batch_normalize = 0 - conv = model[0] - if batch_normalize: - bn = model[1] - # Get the number of weights of batchnorm - num_bn_biases = math.prod(bn.bias.shape) - # Load weights - bn_biases = Tensor(weights[ptr:ptr + num_bn_biases]) - ptr += num_bn_biases - bn_weights = Tensor(weights[ptr:ptr+num_bn_biases]) - ptr += num_bn_biases - bn_running_mean = Tensor(weights[ptr:ptr+num_bn_biases]) - ptr += num_bn_biases - bn_running_var = Tensor(weights[ptr:ptr+num_bn_biases]) - ptr += num_bn_biases - # Cast the loaded weights into dims of model weights - bn_biases = bn_biases.reshape(shape=tuple(bn.bias.shape)) - bn_weights = bn_weights.reshape(shape=tuple(bn.weight.shape)) - bn_running_mean = bn_running_mean.reshape(shape=tuple(bn.running_mean.shape)) - bn_running_var = bn_running_var.reshape(shape=tuple(bn.running_var.shape)) - # Copy data - bn.bias = bn_biases - bn.weight = bn_weights - bn.running_mean = bn_running_mean - bn.running_var = bn_running_var - else: - # load biases of the conv layer - num_biases = math.prod(conv.bias.shape) - # Load weights - conv_biases = Tensor(weights[ptr: ptr+num_biases]) - ptr += num_biases - # Reshape - conv_biases = conv_biases.reshape(shape=tuple(conv.bias.shape)) - # Copy - conv.bias = conv_biases - # Load weighys for conv layers - num_weights = math.prod(conv.weight.shape) - conv_weights = Tensor(weights[ptr:ptr+num_weights]) - ptr += num_weights - conv_weights = conv_weights.reshape(shape=tuple(conv.weight.shape)) - conv.weight = conv_weights + def load_weights(self, url): + weights = np.frombuffer(fetch(url), dtype=np.float32)[5:] + ptr = 0 + for i in range(len(self.module_list)): + module_type = self.blocks[i + 1]["type"] + if module_type == "convolutional": + model = self.module_list[i] + try: # we have batchnorm, load conv weights without biases, and batchnorm values + batch_normalize = int(self.blocks[i + 1]["batch_normalize"]) + except: # no batchnorm, load conv weights + biases + batch_normalize = 0 + conv = model[0] + if batch_normalize: + bn = model[1] + # Get the number of weights of batchnorm + num_bn_biases = math.prod(bn.bias.shape) + # Load weights + bn_biases = Tensor(weights[ptr : ptr + num_bn_biases]) + ptr += num_bn_biases + bn_weights = Tensor(weights[ptr : ptr + num_bn_biases]) + ptr += num_bn_biases + bn_running_mean = Tensor(weights[ptr : ptr + num_bn_biases]) + ptr += num_bn_biases + bn_running_var = Tensor(weights[ptr : ptr + num_bn_biases]) + ptr += num_bn_biases + # Cast the loaded weights into dims of model weights + bn_biases = bn_biases.reshape(shape=tuple(bn.bias.shape)) + bn_weights = bn_weights.reshape(shape=tuple(bn.weight.shape)) + bn_running_mean = bn_running_mean.reshape( + shape=tuple(bn.running_mean.shape) + ) + bn_running_var = bn_running_var.reshape( + shape=tuple(bn.running_var.shape) + ) + # Copy data + bn.bias = bn_biases + bn.weight = bn_weights + bn.running_mean = bn_running_mean + bn.running_var = bn_running_var + else: + # load biases of the conv layer + num_biases = math.prod(conv.bias.shape) + # Load weights + conv_biases = Tensor(weights[ptr : ptr + num_biases]) + ptr += num_biases + # Reshape + conv_biases = conv_biases.reshape(shape=tuple(conv.bias.shape)) + # Copy + conv.bias = conv_biases + # Load weighys for conv layers + num_weights = math.prod(conv.weight.shape) + conv_weights = Tensor(weights[ptr : ptr + num_weights]) + ptr += num_weights + conv_weights = conv_weights.reshape(shape=tuple(conv.weight.shape)) + conv.weight = conv_weights + + def forward(self, x): + modules = self.blocks[1:] + outputs = {} # Cached outputs for route layer + detections, write = None, False + for i, module in enumerate(modules): + module_type = module["type"] + if module_type == "convolutional" or module_type == "upsample": + for layer in self.module_list[i]: + x = layer(x) + elif module_type == "route": + layers = module["layers"] + layers = [int(a) for a in layers] + if (layers[0]) > 0: + layers[0] = layers[0] - i + if len(layers) == 1: + x = outputs[i + (layers[0])] + else: + if (layers[1]) > 0: + layers[1] = layers[1] - i + map1 = outputs[i + layers[0]] + map2 = outputs[i + layers[1]] + x = Tensor(np.concatenate((map1.numpy(), map2.numpy()), axis=1)) + elif module_type == "shortcut": + from_ = int(module["from"]) + x = outputs[i - 1] + outputs[i + from_] + elif module_type == "yolo": + anchors = self.module_list[i][0] + inp_dim = int(self.net_info["height"]) # 416 + num_classes = int(module["classes"]) + x = predict_transform(x, inp_dim, anchors, num_classes) + if not write: + detections, write = x, True + else: + detections = Tensor( + np.concatenate((detections.numpy(), x.numpy()), axis=1) + ) + outputs[i] = x + return detections - def forward(self, x): - modules = self.blocks[1:] - outputs = {} # Cached outputs for route layer - detections, write = None, False - for i, module in enumerate(modules): - module_type = (module["type"]) - if module_type == "convolutional" or module_type == "upsample": - for layer in self.module_list[i]: - x = layer(x) - elif module_type == "route": - layers = module["layers"] - layers = [int(a) for a in layers] - if (layers[0]) > 0: - layers[0] = layers[0] - i - if len(layers) == 1: - x = outputs[i + (layers[0])] - else: - if (layers[1]) > 0: layers[1] = layers[1] - i - map1 = outputs[i + layers[0]] - map2 = outputs[i + layers[1]] - x = Tensor(np.concatenate((map1.numpy(), map2.numpy()), axis=1)) - elif module_type == "shortcut": - from_ = int(module["from"]) - x = outputs[i - 1] + outputs[i + from_] - elif module_type == "yolo": - anchors = self.module_list[i][0] - inp_dim = int(self.net_info["height"]) # 416 - num_classes = int(module["classes"]) - x = predict_transform(x, inp_dim, anchors, num_classes) - if not write: - detections, write = x, True - else: - detections = Tensor(np.concatenate((detections.numpy(), x.numpy()), axis=1)) - outputs[i] = x - return detections if __name__ == "__main__": - model = Darknet(fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg')) - print("Loading weights file (237MB). This might take a while…") - model.load_weights('https://pjreddie.com/media/files/yolov3.weights') - if len(sys.argv) > 1: - url = sys.argv[1] - else: - url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png" - if url == 'webcam': - cap = cv2.VideoCapture(0) - cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) - while 1: - _ = cap.grab() # discard one frame to circumvent capture buffering - ret, frame = cap.read() - prediction = process_results(infer(model, frame)) - img = Image.fromarray(frame[:, :, [2,1,0]]) - boxes = add_boxes(np.array(img.resize((608, 608))), prediction) - boxes = cv2.cvtColor(boxes, cv2.COLOR_RGB2BGR) - cv2.imshow('yolo', boxes) - if cv2.waitKey(1) & 0xFF == ord('q'): - break - cap.release() - cv2.destroyAllWindows() - elif url.startswith('http'): - img_stream = io.BytesIO(fetch(url)) - img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1) - else: - img = cv2.imread(url) - st = time.time() - print('running inference…') - prediction = infer(model, img) - print(f'did inference in {(time.time() - st):2f}s') - show_labels(prediction) - prediction = process_results(prediction) - boxes = add_boxes(np.array(Image.fromarray(img).resize((608, 608))), prediction) - cv2.imwrite('boxes.jpg', boxes) + model = Darknet( + fetch( + "https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg" + ) + ) + print("Loading weights file (237MB). This might take a while…") + model.load_weights("https://pjreddie.com/media/files/yolov3.weights") + if len(sys.argv) > 1: + url = sys.argv[1] + else: + url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png" + if url == "webcam": + cap = cv2.VideoCapture(0) + cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) + while 1: + _ = cap.grab() # discard one frame to circumvent capture buffering + ret, frame = cap.read() + prediction = process_results(infer(model, frame)) + img = Image.fromarray(frame[:, :, [2, 1, 0]]) + boxes = add_boxes(np.array(img.resize((608, 608))), prediction) + boxes = cv2.cvtColor(boxes, cv2.COLOR_RGB2BGR) + cv2.imshow("yolo", boxes) + if cv2.waitKey(1) & 0xFF == ord("q"): + break + cap.release() + cv2.destroyAllWindows() + elif url.startswith("http"): + img_stream = io.BytesIO(fetch(url)) + img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1) + else: + img = cv2.imread(url) + st = time.time() + print("running inference…") + prediction = infer(model, img) + print(f"did inference in {(time.time() - st):2f}s") + show_labels(prediction) + prediction = process_results(prediction) + boxes = add_boxes(np.array(Image.fromarray(img).resize((608, 608))), prediction) + cv2.imwrite("boxes.jpg", boxes) diff --git a/examples/yolov8-onnx.py b/examples/yolov8-onnx.py index f75b5cb33..756e259ab 100644 --- a/examples/yolov8-onnx.py +++ b/examples/yolov8-onnx.py @@ -8,11 +8,14 @@ from tinygrad.tensor import Tensor os.chdir("/tmp") if not Path("yolov8n-seg.onnx").is_file(): - model = YOLO("yolov8n-seg.pt") - model.export(format="onnx", imgsz=[480,640]) + model = YOLO("yolov8n-seg.pt") + model.export(format="onnx", imgsz=[480, 640]) onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb")) # TODO: move get example inputs to onnx -input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} +input_shapes = { + inp.name: tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) + for inp in onnx_model.graph.input +} print(input_shapes) run_onnx = get_run_onnx(onnx_model) -run_onnx({"images": Tensor.zeros(1,3,480,640)}, debug=True) +run_onnx({"images": Tensor.zeros(1, 3, 480, 640)}, debug=True) diff --git a/examples/yolov8.py b/examples/yolov8.py index c01dc5c13..81df968eb 100644 --- a/examples/yolov8.py +++ b/examples/yolov8.py @@ -9,424 +9,646 @@ import time, sys from tinygrad.helpers import fetch from tinygrad.nn.state import safe_load, load_state_dict -#Model architecture from https://github.com/ultralytics/ultralytics/issues/189 -#The upsampling class has been taken from this pull request https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Now 2(?) models use upsampling. (retinet and this) +# Model architecture from https://github.com/ultralytics/ultralytics/issues/189 +# The upsampling class has been taken from this pull request https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Now 2(?) models use upsampling. (retinet and this) + + +# Pre processing image functions. +def compute_transform( + image, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32 +): + shape = image.shape[:2] # current shape [height, width] + new_shape = (new_shape, new_shape) if isinstance(new_shape, int) else new_shape + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + r = min(r, 1.0) if not scaleup else r + new_unpad = (int(round(shape[1] * r)), int(round(shape[0] * r))) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] + dw, dh = (np.mod(dw, stride), np.mod(dh, stride)) if auto else (0.0, 0.0) + new_unpad = (new_shape[1], new_shape[0]) if scaleFill else new_unpad + dw /= 2 + dh /= 2 + image = ( + cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR) + if shape[::-1] != new_unpad + else image + ) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + image = cv2.copyMakeBorder( + image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114) + ) + return image -#Pre processing image functions. -def compute_transform(image, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32): - shape = image.shape[:2] # current shape [height, width] - new_shape = (new_shape, new_shape) if isinstance(new_shape, int) else new_shape - r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) - r = min(r, 1.0) if not scaleup else r - new_unpad = (int(round(shape[1] * r)), int(round(shape[0] * r))) - dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] - dw, dh = (np.mod(dw, stride), np.mod(dh, stride)) if auto else (0.0, 0.0) - new_unpad = (new_shape[1], new_shape[0]) if scaleFill else new_unpad - dw /= 2 - dh /= 2 - image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR) if shape[::-1] != new_unpad else image - top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) - left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) - image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) - return image def preprocess(im, imgsz=640, model_stride=32, model_pt=True): - same_shapes = all(x.shape == im[0].shape for x in im) - auto = same_shapes and model_pt - im = Tensor([compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride) for x in im]) - im = Tensor.stack(im) if im.shape[0] > 1 else im - im = im[..., ::-1].permute(0, 3, 1, 2) # BGR to RGB, BHWC to BCHW, (n, 3, h, w) - im /= 255 # 0 - 255 to 0.0 - 1.0 - return im + same_shapes = all(x.shape == im[0].shape for x in im) + auto = same_shapes and model_pt + im = Tensor( + [ + compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride) + for x in im + ] + ) + im = Tensor.stack(im) if im.shape[0] > 1 else im + im = im[..., ::-1].permute(0, 3, 1, 2) # BGR to RGB, BHWC to BCHW, (n, 3, h, w) + im /= 255 # 0 - 255 to 0.0 - 1.0 + return im + # Post Processing functions def box_area(box): - return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]) + return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]) + def box_iou(box1, box2): - lt = np.maximum(box1[:, None, :2], box2[:, :2]) - rb = np.minimum(box1[:, None, 2:], box2[:, 2:]) - wh = np.clip(rb - lt, 0, None) - inter = wh[:, :, 0] * wh[:, :, 1] - area1 = box_area(box1)[:, None] - area2 = box_area(box2)[None, :] - iou = inter / (area1 + area2 - inter) - return iou + lt = np.maximum(box1[:, None, :2], box2[:, :2]) + rb = np.minimum(box1[:, None, 2:], box2[:, 2:]) + wh = np.clip(rb - lt, 0, None) + inter = wh[:, :, 0] * wh[:, :, 1] + area1 = box_area(box1)[:, None] + area2 = box_area(box2)[None, :] + iou = inter / (area1 + area2 - inter) + return iou + def compute_nms(boxes, scores, iou_threshold): - order, keep = scores.argsort()[::-1], [] - while order.size > 0: - i = order[0] - keep.append(i) - if order.size == 1: - break - iou = box_iou(boxes[i][None, :], boxes[order[1:]]) - inds = np.where(iou.squeeze() <= iou_threshold)[0] - order = order[inds + 1] - return np.array(keep) + order, keep = scores.argsort()[::-1], [] + while order.size > 0: + i = order[0] + keep.append(i) + if order.size == 1: + break + iou = box_iou(boxes[i][None, :], boxes[order[1:]]) + inds = np.where(iou.squeeze() <= iou_threshold)[0] + order = order[inds + 1] + return np.array(keep) -def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=False, max_det=300, nc=0, max_wh=7680): - prediction = prediction[0] if isinstance(prediction, (list, tuple)) else prediction - bs, nc = prediction.shape[0], nc or (prediction.shape[1] - 4) - xc = np.amax(prediction[:, 4:4 + nc], axis=1) > conf_thres - nm = prediction.shape[1] - nc - 4 - output = [np.zeros((0, 6 + nm))] * bs - for xi, x in enumerate(prediction): - x = x.swapaxes(0, -1)[xc[xi]] - if not x.shape[0]: continue - box, cls, mask = np.split(x, [4, 4 + nc], axis=1) - conf, j = np.max(cls, axis=1, keepdims=True), np.argmax(cls, axis=1, keepdims=True) - x = np.concatenate((xywh2xyxy(box), conf, j.astype(np.float32), mask), axis=1) - x = x[conf.ravel() > conf_thres] - if not x.shape[0]: continue - x = x[np.argsort(-x[:, 4])] - c = x[:, 5:6] * (0 if agnostic else max_wh) - boxes, scores = x[:, :4] + c, x[:, 4] - i = compute_nms(boxes, scores, iou_thres)[:max_det] - output[xi] = x[i] - return output +def non_max_suppression( + prediction, + conf_thres=0.25, + iou_thres=0.45, + agnostic=False, + max_det=300, + nc=0, + max_wh=7680, +): + prediction = prediction[0] if isinstance(prediction, (list, tuple)) else prediction + bs, nc = prediction.shape[0], nc or (prediction.shape[1] - 4) + xc = np.amax(prediction[:, 4 : 4 + nc], axis=1) > conf_thres + nm = prediction.shape[1] - nc - 4 + output = [np.zeros((0, 6 + nm))] * bs + + for xi, x in enumerate(prediction): + x = x.swapaxes(0, -1)[xc[xi]] + if not x.shape[0]: + continue + box, cls, mask = np.split(x, [4, 4 + nc], axis=1) + conf, j = np.max(cls, axis=1, keepdims=True), np.argmax( + cls, axis=1, keepdims=True + ) + x = np.concatenate((xywh2xyxy(box), conf, j.astype(np.float32), mask), axis=1) + x = x[conf.ravel() > conf_thres] + if not x.shape[0]: + continue + x = x[np.argsort(-x[:, 4])] + c = x[:, 5:6] * (0 if agnostic else max_wh) + boxes, scores = x[:, :4] + c, x[:, 4] + i = compute_nms(boxes, scores, iou_thres)[:max_det] + output[xi] = x[i] + return output + def postprocess(preds, img, orig_imgs): - print('copying to CPU now for post processing') - #if you are on CPU, this causes an overflow runtime error. doesn't "seem" to make any difference in the predictions though. - # TODO: make non_max_suppression in tinygrad - to make this faster - preds = preds.numpy() if isinstance(preds, Tensor) else preds - preds = non_max_suppression(prediction=preds, conf_thres=0.25, iou_thres=0.7, agnostic=False, max_det=300) - all_preds = [] - for i, pred in enumerate(preds): - orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs - if not isinstance(orig_imgs, Tensor): - pred[:, :4] = scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) - all_preds.append(pred) - return all_preds + print("copying to CPU now for post processing") + # if you are on CPU, this causes an overflow runtime error. doesn't "seem" to make any difference in the predictions though. + # TODO: make non_max_suppression in tinygrad - to make this faster + preds = preds.numpy() if isinstance(preds, Tensor) else preds + preds = non_max_suppression( + prediction=preds, conf_thres=0.25, iou_thres=0.7, agnostic=False, max_det=300 + ) + all_preds = [] + for i, pred in enumerate(preds): + orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs + if not isinstance(orig_imgs, Tensor): + pred[:, :4] = scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + all_preds.append(pred) + return all_preds -def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictions, class_labels, iou_threshold=0.5): - color_dict = {label: tuple((((i+1) * 50) % 256, ((i+1) * 100) % 256, ((i+1) * 150) % 256)) for i, label in enumerate(class_labels)} - font = cv2.FONT_HERSHEY_SIMPLEX - def is_bright_color(color): - r, g, b = color - brightness = (r * 299 + g * 587 + b * 114) / 1000 - return brightness > 127 +def draw_bounding_boxes_and_save( + orig_img_paths, output_img_paths, all_predictions, class_labels, iou_threshold=0.5 +): + color_dict = { + label: tuple( + (((i + 1) * 50) % 256, ((i + 1) * 100) % 256, ((i + 1) * 150) % 256) + ) + for i, label in enumerate(class_labels) + } + font = cv2.FONT_HERSHEY_SIMPLEX - for img_idx, (orig_img_path, output_img_path, predictions) in enumerate(zip(orig_img_paths, output_img_paths, all_predictions)): - predictions = np.array(predictions) - orig_img = cv2.imread(orig_img_path) if not isinstance(orig_img_path, np.ndarray) else cv2.imdecode(orig_img_path, 1) - height, width, _ = orig_img.shape - box_thickness = int((height + width) / 400) - font_scale = (height + width) / 2500 + def is_bright_color(color): + r, g, b = color + brightness = (r * 299 + g * 587 + b * 114) / 1000 + return brightness > 127 - grouped_preds = defaultdict(list) - object_count = defaultdict(int) + for img_idx, (orig_img_path, output_img_path, predictions) in enumerate( + zip(orig_img_paths, output_img_paths, all_predictions) + ): + predictions = np.array(predictions) + orig_img = ( + cv2.imread(orig_img_path) + if not isinstance(orig_img_path, np.ndarray) + else cv2.imdecode(orig_img_path, 1) + ) + height, width, _ = orig_img.shape + box_thickness = int((height + width) / 400) + font_scale = (height + width) / 2500 - for pred_np in predictions: - grouped_preds[int(pred_np[-1])].append(pred_np) + grouped_preds = defaultdict(list) + object_count = defaultdict(int) - def draw_box_and_label(pred, color): - x1, y1, x2, y2, conf, _ = pred - x1, y1, x2, y2 = map(int, (x1, y1, x2, y2)) - cv2.rectangle(orig_img, (x1, y1), (x2, y2), color, box_thickness) - label = f"{class_labels[class_id]} {conf:.2f}" - text_size, _ = cv2.getTextSize(label, font, font_scale, 1) - label_y, bg_y = (y1 - 4, y1 - text_size[1] - 4) if y1 - text_size[1] - 4 > 0 else (y1 + text_size[1], y1) - cv2.rectangle(orig_img, (x1, bg_y), (x1 + text_size[0], bg_y + text_size[1]), color, -1) - font_color = (0, 0, 0) if is_bright_color(color) else (255, 255, 255) - cv2.putText(orig_img, label, (x1, label_y), font, font_scale, font_color, 1, cv2.LINE_AA) + for pred_np in predictions: + grouped_preds[int(pred_np[-1])].append(pred_np) - for class_id, pred_list in grouped_preds.items(): - pred_list = np.array(pred_list) - while len(pred_list) > 0: - max_conf_idx = np.argmax(pred_list[:, 4]) - max_conf_pred = pred_list[max_conf_idx] - pred_list = np.delete(pred_list, max_conf_idx, axis=0) - color = color_dict[class_labels[class_id]] - draw_box_and_label(max_conf_pred, color) - object_count[class_labels[class_id]] += 1 - iou_scores = box_iou(np.array([max_conf_pred[:4]]), pred_list[:, :4]) - low_iou_indices = np.where(iou_scores[0] < iou_threshold)[0] - pred_list = pred_list[low_iou_indices] - for low_conf_pred in pred_list: - draw_box_and_label(low_conf_pred, color) + def draw_box_and_label(pred, color): + x1, y1, x2, y2, conf, _ = pred + x1, y1, x2, y2 = map(int, (x1, y1, x2, y2)) + cv2.rectangle(orig_img, (x1, y1), (x2, y2), color, box_thickness) + label = f"{class_labels[class_id]} {conf:.2f}" + text_size, _ = cv2.getTextSize(label, font, font_scale, 1) + label_y, bg_y = ( + (y1 - 4, y1 - text_size[1] - 4) + if y1 - text_size[1] - 4 > 0 + else (y1 + text_size[1], y1) + ) + cv2.rectangle( + orig_img, + (x1, bg_y), + (x1 + text_size[0], bg_y + text_size[1]), + color, + -1, + ) + font_color = (0, 0, 0) if is_bright_color(color) else (255, 255, 255) + cv2.putText( + orig_img, + label, + (x1, label_y), + font, + font_scale, + font_color, + 1, + cv2.LINE_AA, + ) - print(f"Image {img_idx + 1}:") - print("Objects detected:") - for obj, count in object_count.items(): - print(f"- {obj}: {count}") + for class_id, pred_list in grouped_preds.items(): + pred_list = np.array(pred_list) + while len(pred_list) > 0: + max_conf_idx = np.argmax(pred_list[:, 4]) + max_conf_pred = pred_list[max_conf_idx] + pred_list = np.delete(pred_list, max_conf_idx, axis=0) + color = color_dict[class_labels[class_id]] + draw_box_and_label(max_conf_pred, color) + object_count[class_labels[class_id]] += 1 + iou_scores = box_iou(np.array([max_conf_pred[:4]]), pred_list[:, :4]) + low_iou_indices = np.where(iou_scores[0] < iou_threshold)[0] + pred_list = pred_list[low_iou_indices] + for low_conf_pred in pred_list: + draw_box_and_label(low_conf_pred, color) + + print(f"Image {img_idx + 1}:") + print("Objects detected:") + for obj, count in object_count.items(): + print(f"- {obj}: {count}") + + cv2.imwrite(output_img_path, orig_img) + print(f"saved detections at {output_img_path}") - cv2.imwrite(output_img_path, orig_img) - print(f'saved detections at {output_img_path}') # utility functions for forward pass. def dist2bbox(distance, anchor_points, xywh=True, dim=-1): - lt, rb = distance.chunk(2, dim) - x1y1 = anchor_points - lt - x2y2 = anchor_points + rb - if xywh: - c_xy = (x1y1 + x2y2) / 2 - wh = x2y2 - x1y1 - return c_xy.cat(wh, dim=1) - return x1y1.cat(x2y2, dim=1) + lt, rb = distance.chunk(2, dim) + x1y1 = anchor_points - lt + x2y2 = anchor_points + rb + if xywh: + c_xy = (x1y1 + x2y2) / 2 + wh = x2y2 - x1y1 + return c_xy.cat(wh, dim=1) + return x1y1.cat(x2y2, dim=1) + def make_anchors(feats, strides, grid_cell_offset=0.5): - anchor_points, stride_tensor = [], [] - assert feats is not None - for i, stride in enumerate(strides): - _, _, h, w = feats[i].shape - sx = Tensor.arange(w) + grid_cell_offset - sy = Tensor.arange(h) + grid_cell_offset + anchor_points, stride_tensor = [], [] + assert feats is not None + for i, stride in enumerate(strides): + _, _, h, w = feats[i].shape + sx = Tensor.arange(w) + grid_cell_offset + sy = Tensor.arange(h) + grid_cell_offset - # this is np.meshgrid but in tinygrad - sx = sx.reshape(1, -1).repeat([h, 1]).reshape(-1) - sy = sy.reshape(-1, 1).repeat([1, w]).reshape(-1) + # this is np.meshgrid but in tinygrad + sx = sx.reshape(1, -1).repeat([h, 1]).reshape(-1) + sy = sy.reshape(-1, 1).repeat([1, w]).reshape(-1) + + anchor_points.append(Tensor.stack((sx, sy), -1).reshape(-1, 2)) + stride_tensor.append(Tensor.full((h * w), stride)) + anchor_points = anchor_points[0].cat(anchor_points[1], anchor_points[2]) + stride_tensor = ( + stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1) + ) + return anchor_points, stride_tensor - anchor_points.append(Tensor.stack((sx, sy), -1).reshape(-1, 2)) - stride_tensor.append(Tensor.full((h * w), stride)) - anchor_points = anchor_points[0].cat(anchor_points[1], anchor_points[2]) - stride_tensor = stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1) - return anchor_points, stride_tensor # this function is from the original implementation def autopad(k, p=None, d=1): # kernel, padding, dilation - if d > 1: - k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size - if p is None: - p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad - return p + if d > 1: + k = ( + d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] + ) # actual kernel-size + if p is None: + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad + return p + def clip_boxes(boxes, shape): - boxes[..., [0, 2]] = np.clip(boxes[..., [0, 2]], 0, shape[1]) # x1, x2 - boxes[..., [1, 3]] = np.clip(boxes[..., [1, 3]], 0, shape[0]) # y1, y2 - return boxes + boxes[..., [0, 2]] = np.clip(boxes[..., [0, 2]], 0, shape[1]) # x1, x2 + boxes[..., [1, 3]] = np.clip(boxes[..., [1, 3]], 0, shape[0]) # y1, y2 + return boxes + def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None): - gain = ratio_pad if ratio_pad else min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) - pad = ((img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2) - boxes_np = boxes.numpy() if isinstance(boxes, Tensor) else boxes - boxes_np[..., [0, 2]] -= pad[0] - boxes_np[..., [1, 3]] -= pad[1] - boxes_np[..., :4] /= gain - boxes_np = clip_boxes(boxes_np, img0_shape) - return boxes_np + gain = ( + ratio_pad + if ratio_pad + else min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) + ) + pad = ( + (img1_shape[1] - img0_shape[1] * gain) / 2, + (img1_shape[0] - img0_shape[0] * gain) / 2, + ) + boxes_np = boxes.numpy() if isinstance(boxes, Tensor) else boxes + boxes_np[..., [0, 2]] -= pad[0] + boxes_np[..., [1, 3]] -= pad[1] + boxes_np[..., :4] /= gain + boxes_np = clip_boxes(boxes_np, img0_shape) + return boxes_np + def xywh2xyxy(x): - xy = x[..., :2] # center x, y - wh = x[..., 2:4] # width, height - xy1 = xy - wh / 2 # top left x, y - xy2 = xy + wh / 2 # bottom right x, y - result = np.concatenate((xy1, xy2), axis=-1) - return Tensor(result) if isinstance(x, Tensor) else result + xy = x[..., :2] # center x, y + wh = x[..., 2:4] # width, height + xy1 = xy - wh / 2 # top left x, y + xy2 = xy + wh / 2 # bottom right x, y + result = np.concatenate((xy1, xy2), axis=-1) + return Tensor(result) if isinstance(x, Tensor) else result + def get_variant_multiples(variant): - return {'n':(0.33, 0.25, 2.0), 's':(0.33, 0.50, 2.0), 'm':(0.67, 0.75, 1.5), 'l':(1.0, 1.0, 1.0), 'x':(1, 1.25, 1.0) }.get(variant, None) + return { + "n": (0.33, 0.25, 2.0), + "s": (0.33, 0.50, 2.0), + "m": (0.67, 0.75, 1.5), + "l": (1.0, 1.0, 1.0), + "x": (1, 1.25, 1.0), + }.get(variant, None) + def label_predictions(all_predictions): - class_index_count = defaultdict(int) - for predictions in all_predictions: - predictions = np.array(predictions) - for pred_np in predictions: - class_id = int(pred_np[-1]) - class_index_count[class_id] += 1 + class_index_count = defaultdict(int) + for predictions in all_predictions: + predictions = np.array(predictions) + for pred_np in predictions: + class_id = int(pred_np[-1]) + class_index_count[class_id] += 1 - return dict(class_index_count) + return dict(class_index_count) -#this is taken from https://github.com/tinygrad/tinygrad/pull/784/files by dc-dc-dc (Now 2 models use upsampling) + +# this is taken from https://github.com/tinygrad/tinygrad/pull/784/files by dc-dc-dc (Now 2 models use upsampling) class Upsample: - def __init__(self, scale_factor:int, mode: str = "nearest") -> None: - assert mode == "nearest" # only mode supported for now - self.mode = mode - self.scale_factor = scale_factor + def __init__(self, scale_factor: int, mode: str = "nearest") -> None: + assert mode == "nearest" # only mode supported for now + self.mode = mode + self.scale_factor = scale_factor + + def __call__(self, x: Tensor) -> Tensor: + assert len(x.shape) > 2 and len(x.shape) <= 5 + (b, c), _lens = x.shape[:2], len(x.shape[2:]) + tmp = x.reshape([b, c, -1] + [1] * _lens) * Tensor.ones( + *[1, 1, 1] + [self.scale_factor] * _lens + ) + return ( + tmp.reshape(list(x.shape) + [self.scale_factor] * _lens) + .permute( + [0, 1] + + list( + chain.from_iterable([[y + 2, y + 2 + _lens] for y in range(_lens)]) + ) + ) + .reshape([b, c] + [x * self.scale_factor for x in x.shape[2:]]) + ) - def __call__(self, x: Tensor) -> Tensor: - assert len(x.shape) > 2 and len(x.shape) <= 5 - (b, c), _lens = x.shape[:2], len(x.shape[2:]) - tmp = x.reshape([b, c, -1] + [1] * _lens) * Tensor.ones(*[1, 1, 1] + [self.scale_factor] * _lens) - return tmp.reshape(list(x.shape) + [self.scale_factor] * _lens).permute([0, 1] + list(chain.from_iterable([[y+2, y+2+_lens] for y in range(_lens)]))).reshape([b, c] + [x * self.scale_factor for x in x.shape[2:]]) class Conv_Block: - def __init__(self, c1, c2, kernel_size=1, stride=1, groups=1, dilation=1, padding=None): - self.conv = Conv2d(c1,c2, kernel_size, stride, padding=autopad(kernel_size, padding, dilation), bias=False, groups=groups, dilation=dilation) - self.bn = BatchNorm2d(c2, eps=0.001) + def __init__( + self, c1, c2, kernel_size=1, stride=1, groups=1, dilation=1, padding=None + ): + self.conv = Conv2d( + c1, + c2, + kernel_size, + stride, + padding=autopad(kernel_size, padding, dilation), + bias=False, + groups=groups, + dilation=dilation, + ) + self.bn = BatchNorm2d(c2, eps=0.001) + + def __call__(self, x): + return self.bn(self.conv(x)).silu() - def __call__(self, x): - return self.bn(self.conv(x)).silu() class Bottleneck: - def __init__(self, c1, c2 , shortcut: bool, g=1, kernels: list = (3,3), channel_factor=0.5): - c_ = int(c2 * channel_factor) - self.cv1 = Conv_Block(c1, c_, kernel_size=kernels[0], stride=1, padding=None) - self.cv2 = Conv_Block(c_, c2, kernel_size=kernels[1], stride=1, padding=None, groups=g) - self.residual = c1 == c2 and shortcut + def __init__( + self, c1, c2, shortcut: bool, g=1, kernels: list = (3, 3), channel_factor=0.5 + ): + c_ = int(c2 * channel_factor) + self.cv1 = Conv_Block(c1, c_, kernel_size=kernels[0], stride=1, padding=None) + self.cv2 = Conv_Block( + c_, c2, kernel_size=kernels[1], stride=1, padding=None, groups=g + ) + self.residual = c1 == c2 and shortcut + + def __call__(self, x): + return x + self.cv2(self.cv1(x)) if self.residual else self.cv2(self.cv1(x)) - def __call__(self, x): - return x + self.cv2(self.cv1(x)) if self.residual else self.cv2(self.cv1(x)) class C2f: - def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): - self.c = int(c2 * e) - self.cv1 = Conv_Block(c1, 2 * self.c, 1,) - self.cv2 = Conv_Block((2 + n) * self.c, c2, 1) - self.bottleneck = [Bottleneck(self.c, self.c, shortcut, g, kernels=[(3, 3), (3, 3)], channel_factor=1.0) for _ in range(n)] + def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): + self.c = int(c2 * e) + self.cv1 = Conv_Block( + c1, + 2 * self.c, + 1, + ) + self.cv2 = Conv_Block((2 + n) * self.c, c2, 1) + self.bottleneck = [ + Bottleneck( + self.c, + self.c, + shortcut, + g, + kernels=[(3, 3), (3, 3)], + channel_factor=1.0, + ) + for _ in range(n) + ] + + def __call__(self, x): + y = list(self.cv1(x).chunk(2, 1)) + y.extend(m(y[-1]) for m in self.bottleneck) + z = y[0] + for i in y[1:]: + z = z.cat(i, dim=1) + return self.cv2(z) - def __call__(self, x): - y= list(self.cv1(x).chunk(2, 1)) - y.extend(m(y[-1]) for m in self.bottleneck) - z = y[0] - for i in y[1:]: z = z.cat(i, dim=1) - return self.cv2(z) class SPPF: - def __init__(self, c1, c2, k=5): - c_ = c1 // 2 # hidden channels - self.cv1 = Conv_Block(c1, c_, 1, 1, padding=None) - self.cv2 = Conv_Block(c_ * 4, c2, 1, 1, padding=None) + def __init__(self, c1, c2, k=5): + c_ = c1 // 2 # hidden channels + self.cv1 = Conv_Block(c1, c_, 1, 1, padding=None) + self.cv2 = Conv_Block(c_ * 4, c2, 1, 1, padding=None) - # TODO: this pads with 0s, whereas torch function pads with -infinity. This results in a < 2% difference in prediction which does not make a difference visually. - self.maxpool = lambda x : x.pad2d((k // 2, k // 2, k // 2, k // 2)).max_pool2d(kernel_size=k, stride=1) + # TODO: this pads with 0s, whereas torch function pads with -infinity. This results in a < 2% difference in prediction which does not make a difference visually. + self.maxpool = lambda x: x.pad2d((k // 2, k // 2, k // 2, k // 2)).max_pool2d( + kernel_size=k, stride=1 + ) + + def __call__(self, x): + x = self.cv1(x) + x2 = self.maxpool(x) + x3 = self.maxpool(x2) + x4 = self.maxpool(x3) + return self.cv2(x.cat(x2, x3, x4, dim=1)) - def __call__(self, x): - x = self.cv1(x) - x2 = self.maxpool(x) - x3 = self.maxpool(x2) - x4 = self.maxpool(x3) - return self.cv2(x.cat(x2, x3, x4, dim=1)) class DFL: - def __init__(self, c1=16): - self.conv = Conv2d(c1, 1, 1, bias=False) - x = Tensor.arange(c1) - self.conv.weight.assign(x.reshape(1, c1, 1, 1)) - self.c1 = c1 + def __init__(self, c1=16): + self.conv = Conv2d(c1, 1, 1, bias=False) + x = Tensor.arange(c1) + self.conv.weight.assign(x.reshape(1, c1, 1, 1)) + self.c1 = c1 - def __call__(self, x): - b, c, a = x.shape # batch, channels, anchors - return self.conv(x.reshape(b, 4, self.c1, a).transpose(2, 1).softmax(1)).reshape(b, 4, a) + def __call__(self, x): + b, c, a = x.shape # batch, channels, anchors + return self.conv( + x.reshape(b, 4, self.c1, a).transpose(2, 1).softmax(1) + ).reshape(b, 4, a) -#backbone + +# backbone class Darknet: - def __init__(self, w, r, d): - self.b1 = [Conv_Block(c1=3, c2= int(64*w), kernel_size=3, stride=2, padding=1), Conv_Block(int(64*w), int(128*w), kernel_size=3, stride=2, padding=1)] - self.b2 = [C2f(c1=int(128*w), c2=int(128*w), n=round(3*d), shortcut=True), Conv_Block(int(128*w), int(256*w), 3, 2, 1), C2f(int(256*w), int(256*w), round(6*d), True)] - self.b3 = [Conv_Block(int(256*w), int(512*w), kernel_size=3, stride=2, padding=1), C2f(int(512*w), int(512*w), round(6*d), True)] - self.b4 = [Conv_Block(int(512*w), int(512*w*r), kernel_size=3, stride=2, padding=1), C2f(int(512*w*r), int(512*w*r), round(3*d), True)] - self.b5 = [SPPF(int(512*w*r), int(512*w*r), 5)] + def __init__(self, w, r, d): + self.b1 = [ + Conv_Block(c1=3, c2=int(64 * w), kernel_size=3, stride=2, padding=1), + Conv_Block(int(64 * w), int(128 * w), kernel_size=3, stride=2, padding=1), + ] + self.b2 = [ + C2f(c1=int(128 * w), c2=int(128 * w), n=round(3 * d), shortcut=True), + Conv_Block(int(128 * w), int(256 * w), 3, 2, 1), + C2f(int(256 * w), int(256 * w), round(6 * d), True), + ] + self.b3 = [ + Conv_Block(int(256 * w), int(512 * w), kernel_size=3, stride=2, padding=1), + C2f(int(512 * w), int(512 * w), round(6 * d), True), + ] + self.b4 = [ + Conv_Block( + int(512 * w), int(512 * w * r), kernel_size=3, stride=2, padding=1 + ), + C2f(int(512 * w * r), int(512 * w * r), round(3 * d), True), + ] + self.b5 = [SPPF(int(512 * w * r), int(512 * w * r), 5)] - def return_modules(self): - return [*self.b1, *self.b2, *self.b3, *self.b4, *self.b5] + def return_modules(self): + return [*self.b1, *self.b2, *self.b3, *self.b4, *self.b5] - def __call__(self, x): - x1 = x.sequential(self.b1) - x2 = x1.sequential(self.b2) - x3 = x2.sequential(self.b3) - x4 = x3.sequential(self.b4) - x5 = x4.sequential(self.b5) - return (x2, x3, x5) + def __call__(self, x): + x1 = x.sequential(self.b1) + x2 = x1.sequential(self.b2) + x3 = x2.sequential(self.b3) + x4 = x3.sequential(self.b4) + x5 = x4.sequential(self.b5) + return (x2, x3, x5) -#yolo fpn (neck) + +# yolo fpn (neck) class Yolov8NECK: - def __init__(self, w, r, d): #width_multiple, ratio_multiple, depth_multiple - self.up = Upsample(2, mode='nearest') - self.n1 = C2f(c1=int(512*w*(1+r)), c2=int(512*w), n=round(3*d), shortcut=False) - self.n2 = C2f(c1=int(768*w), c2=int(256*w), n=round(3*d), shortcut=False) - self.n3 = Conv_Block(c1=int(256*w), c2=int(256*w), kernel_size=3, stride=2, padding=1) - self.n4 = C2f(c1=int(768*w), c2=int(512*w), n=round(3*d), shortcut=False) - self.n5 = Conv_Block(c1=int(512* w), c2=int(512 * w), kernel_size=3, stride=2, padding=1) - self.n6 = C2f(c1=int(512*w*(1+r)), c2=int(512*w*r), n=round(3*d), shortcut=False) + def __init__(self, w, r, d): # width_multiple, ratio_multiple, depth_multiple + self.up = Upsample(2, mode="nearest") + self.n1 = C2f( + c1=int(512 * w * (1 + r)), c2=int(512 * w), n=round(3 * d), shortcut=False + ) + self.n2 = C2f(c1=int(768 * w), c2=int(256 * w), n=round(3 * d), shortcut=False) + self.n3 = Conv_Block( + c1=int(256 * w), c2=int(256 * w), kernel_size=3, stride=2, padding=1 + ) + self.n4 = C2f(c1=int(768 * w), c2=int(512 * w), n=round(3 * d), shortcut=False) + self.n5 = Conv_Block( + c1=int(512 * w), c2=int(512 * w), kernel_size=3, stride=2, padding=1 + ) + self.n6 = C2f( + c1=int(512 * w * (1 + r)), + c2=int(512 * w * r), + n=round(3 * d), + shortcut=False, + ) - def return_modules(self): - return [self.n1, self.n2, self.n3, self.n4, self.n5, self.n6] + def return_modules(self): + return [self.n1, self.n2, self.n3, self.n4, self.n5, self.n6] - def __call__(self, p3, p4, p5): - x = self.n1(self.up(p5).cat(p4, dim=1)) - head_1 = self.n2(self.up(x).cat(p3, dim=1)) - head_2 = self.n4(self.n3(head_1).cat(x, dim=1)) - head_3 = self.n6(self.n5(head_2).cat(p5, dim=1)) - return [head_1, head_2, head_3] + def __call__(self, p3, p4, p5): + x = self.n1(self.up(p5).cat(p4, dim=1)) + head_1 = self.n2(self.up(x).cat(p3, dim=1)) + head_2 = self.n4(self.n3(head_1).cat(x, dim=1)) + head_3 = self.n6(self.n5(head_2).cat(p5, dim=1)) + return [head_1, head_2, head_3] -#task specific head. + +# task specific head. class DetectionHead: - def __init__(self, nc=80, filters=()): - self.ch = 16 - self.nc = nc # number of classes - self.nl = len(filters) - self.no = nc + self.ch * 4 # - self.stride = [8, 16, 32] - c1 = max(filters[0], self.nc) - c2 = max((filters[0] // 4, self.ch * 4)) - self.dfl = DFL(self.ch) - self.cv3 = [[Conv_Block(x, c1, 3), Conv_Block(c1, c1, 3), Conv2d(c1, self.nc, 1)] for x in filters] - self.cv2 = [[Conv_Block(x, c2, 3), Conv_Block(c2, c2, 3), Conv2d(c2, 4 * self.ch, 1)] for x in filters] + def __init__(self, nc=80, filters=()): + self.ch = 16 + self.nc = nc # number of classes + self.nl = len(filters) + self.no = nc + self.ch * 4 # + self.stride = [8, 16, 32] + c1 = max(filters[0], self.nc) + c2 = max((filters[0] // 4, self.ch * 4)) + self.dfl = DFL(self.ch) + self.cv3 = [ + [Conv_Block(x, c1, 3), Conv_Block(c1, c1, 3), Conv2d(c1, self.nc, 1)] + for x in filters + ] + self.cv2 = [ + [Conv_Block(x, c2, 3), Conv_Block(c2, c2, 3), Conv2d(c2, 4 * self.ch, 1)] + for x in filters + ] + + def __call__(self, x): + for i in range(self.nl): + x[i] = x[i].sequential(self.cv2[i]).cat(x[i].sequential(self.cv3[i]), dim=1) + self.anchors, self.strides = ( + x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5) + ) + y = [(i.reshape(x[0].shape[0], self.no, -1)) for i in x] + x_cat = y[0].cat(y[1], y[2], dim=2) + box, cls = x_cat[:, : self.ch * 4], x_cat[:, self.ch * 4 :] + dbox = ( + dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) + * self.strides + ) + z = dbox.cat(cls.sigmoid(), dim=1) + return z - def __call__(self, x): - for i in range(self.nl): - x[i] = (x[i].sequential(self.cv2[i]).cat(x[i].sequential(self.cv3[i]), dim=1)) - self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) - y = [(i.reshape(x[0].shape[0], self.no, -1)) for i in x] - x_cat = y[0].cat(y[1], y[2], dim=2) - box, cls = x_cat[:, :self.ch * 4], x_cat[:, self.ch * 4:] - dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides - z = dbox.cat(cls.sigmoid(), dim=1) - return z class YOLOv8: - def __init__(self, w, r, d, num_classes): #width_multiple, ratio_multiple, depth_multiple - self.net = Darknet(w, r, d) - self.fpn = Yolov8NECK(w, r, d) - self.head = DetectionHead(num_classes, filters=(int(256*w), int(512*w), int(512*w*r))) + def __init__( + self, w, r, d, num_classes + ): # width_multiple, ratio_multiple, depth_multiple + self.net = Darknet(w, r, d) + self.fpn = Yolov8NECK(w, r, d) + self.head = DetectionHead( + num_classes, filters=(int(256 * w), int(512 * w), int(512 * w * r)) + ) - def __call__(self, x): - x = self.net(x) - x = self.fpn(*x) - return self.head(x) + def __call__(self, x): + x = self.net(x) + x = self.fpn(*x) + return self.head(x) - def return_all_trainable_modules(self): - backbone_modules = [*range(10)] - yolov8neck_modules = [12, 15, 16, 18, 19, 21] - yolov8_head_weights = [(22, self.head)] - return [*zip(backbone_modules, self.net.return_modules()), *zip(yolov8neck_modules, self.fpn.return_modules()), *yolov8_head_weights] + def return_all_trainable_modules(self): + backbone_modules = [*range(10)] + yolov8neck_modules = [12, 15, 16, 18, 19, 21] + yolov8_head_weights = [(22, self.head)] + return [ + *zip(backbone_modules, self.net.return_modules()), + *zip(yolov8neck_modules, self.fpn.return_modules()), + *yolov8_head_weights, + ] -if __name__ == '__main__': - # usage : python3 yolov8.py "image_URL OR image_path" "v8 variant" (optional, n is default) - if len(sys.argv) < 2: - print("Error: Image URL or path not provided.") - sys.exit(1) +if __name__ == "__main__": + # usage : python3 yolov8.py "image_URL OR image_path" "v8 variant" (optional, n is default) + if len(sys.argv) < 2: + print("Error: Image URL or path not provided.") + sys.exit(1) - img_path = sys.argv[1] - yolo_variant = sys.argv[2] if len(sys.argv) >= 3 else (print("No variant given, so choosing 'n' as the default. Yolov8 has different variants, you can choose from ['n', 's', 'm', 'l', 'x']") or 'n') - print(f'running inference for YOLO version {yolo_variant}') + img_path = sys.argv[1] + yolo_variant = ( + sys.argv[2] + if len(sys.argv) >= 3 + else ( + print( + "No variant given, so choosing 'n' as the default. Yolov8 has different variants, you can choose from ['n', 's', 'm', 'l', 'x']" + ) + or "n" + ) + ) + print(f"running inference for YOLO version {yolo_variant}") - output_folder_path = Path('./outputs_yolov8') - output_folder_path.mkdir(parents=True, exist_ok=True) - #absolute image path or URL - image_location = [np.frombuffer(fetch(img_path).read_bytes(), np.uint8)] - image = [cv2.imdecode(image_location[0], 1)] - out_paths = [(output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix}").as_posix()] - if not isinstance(image[0], np.ndarray): - print('Error in image loading. Check your image file.') - sys.exit(1) - pre_processed_image = preprocess(image) + output_folder_path = Path("./outputs_yolov8") + output_folder_path.mkdir(parents=True, exist_ok=True) + # absolute image path or URL + image_location = [np.frombuffer(fetch(img_path).read_bytes(), np.uint8)] + image = [cv2.imdecode(image_location[0], 1)] + out_paths = [ + ( + output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix}" + ).as_posix() + ] + if not isinstance(image[0], np.ndarray): + print("Error in image loading. Check your image file.") + sys.exit(1) + pre_processed_image = preprocess(image) - # Different YOLOv8 variants use different w , r, and d multiples. For a list , refer to this yaml file (the scales section) https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/v8/yolov8.yaml - depth, width, ratio = get_variant_multiples(yolo_variant) - yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) + # Different YOLOv8 variants use different w , r, and d multiples. For a list , refer to this yaml file (the scales section) https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/v8/yolov8.yaml + depth, width, ratio = get_variant_multiples(yolo_variant) + yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) - state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors')) - load_state_dict(yolo_infer, state_dict) + state_dict = safe_load( + fetch( + f"https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors" + ) + ) + load_state_dict(yolo_infer, state_dict) - st = time.time() - predictions = yolo_infer(pre_processed_image) - print(f'did inference in {int(round(((time.time() - st) * 1000)))}ms') + st = time.time() + predictions = yolo_infer(pre_processed_image) + print(f"did inference in {int(round(((time.time() - st) * 1000)))}ms") - post_predictions = postprocess(preds=predictions, img=pre_processed_image, orig_imgs=image) + post_predictions = postprocess( + preds=predictions, img=pre_processed_image, orig_imgs=image + ) - #v8 and v3 have same 80 class names for Object Detection - class_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_text().split("\n") + # v8 and v3 have same 80 class names for Object Detection + class_labels = ( + fetch( + "https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names" + ) + .read_text() + .split("\n") + ) - draw_bounding_boxes_and_save(orig_img_paths=image_location, output_img_paths=out_paths, all_predictions=post_predictions, class_labels=class_labels) + draw_bounding_boxes_and_save( + orig_img_paths=image_location, + output_img_paths=out_paths, + all_predictions=post_predictions, + class_labels=class_labels, + ) # TODO for later: # 1. Fix SPPF minor difference due to maxpool # 2. AST exp overflow warning while on cpu # 3. Make NMS faster -# 4. Add video inference and webcam support \ No newline at end of file +# 4. Add video inference and webcam support diff --git a/extra/accel/ane/1_build/coreml_ane.py b/extra/accel/ane/1_build/coreml_ane.py index cefd14391..e44a9e4a5 100755 --- a/extra/accel/ane/1_build/coreml_ane.py +++ b/extra/accel/ane/1_build/coreml_ane.py @@ -6,25 +6,32 @@ from coremltools.models.neural_network import datatypes, NeuralNetworkBuilder # KxK GEMM with bias K = 64 -input_features = [('image', datatypes.Array(K))] -input_features2 = [('image2', datatypes.Array(K))] -output_features = [('probs', datatypes.Array(K))] +input_features = [("image", datatypes.Array(K))] +input_features2 = [("image2", datatypes.Array(K))] +output_features = [("probs", datatypes.Array(K))] weights = np.zeros((K, K)) + 3 bias = np.ones(K) -builder = NeuralNetworkBuilder(input_features+input_features2, output_features) +builder = NeuralNetworkBuilder(input_features + input_features2, output_features) -#builder.add_inner_product(name='ip_layer', W=weights, b=None, input_channels=K, output_channels=K, has_bias=False, input_name='image', output_name='med') -#builder.add_inner_product(name='ip_layer_2', W=weights, b=None, input_channels=3, output_channels=3, has_bias=False, input_name='med', output_name='probs') -builder.add_elementwise(name='element', input_names=['image', 'image2'], output_name='probs', mode='ADD') -#builder.add_bias(name='bias', b=bias, input_name='med', output_name='probs', shape_bias=(K,)) -#builder.add_activation(name='act_layer', non_linearity='SIGMOID', input_name='med', output_name='probs') +# builder.add_inner_product(name='ip_layer', W=weights, b=None, input_channels=K, output_channels=K, has_bias=False, input_name='image', output_name='med') +# builder.add_inner_product(name='ip_layer_2', W=weights, b=None, input_channels=3, output_channels=3, has_bias=False, input_name='med', output_name='probs') +builder.add_elementwise( + name="element", input_names=["image", "image2"], output_name="probs", mode="ADD" +) +# builder.add_bias(name='bias', b=bias, input_name='med', output_name='probs', shape_bias=(K,)) +# builder.add_activation(name='act_layer', non_linearity='SIGMOID', input_name='med', output_name='probs') # compile the spec mlmodel = ct.models.MLModel(builder.spec) # trigger the ANE! -out = mlmodel.predict({"image": np.zeros(K, dtype=np.float32)+1, "image2": np.zeros(K, dtype=np.float32)+2}) +out = mlmodel.predict( + { + "image": np.zeros(K, dtype=np.float32) + 1, + "image2": np.zeros(K, dtype=np.float32) + 2, + } +) print(out) -mlmodel.save('test.mlmodel') +mlmodel.save("test.mlmodel") diff --git a/extra/accel/ane/2_compile/dcompile.py b/extra/accel/ane/2_compile/dcompile.py index 7afdbf588..b717854ee 100755 --- a/extra/accel/ane/2_compile/dcompile.py +++ b/extra/accel/ane/2_compile/dcompile.py @@ -5,13 +5,13 @@ import networkx as nx import pylab as plt from networkx.drawing.nx_pydot import read_dot -ret = os.system("./a.out "+sys.argv[1]+" debug") -assert(ret == 0) +ret = os.system("./a.out " + sys.argv[1] + " debug") +assert ret == 0 df = "debug/model.hwx.zinir_graph_after_reg_spill.dot" -#from graphviz import render -#render('dot', 'png', df) +# from graphviz import render +# render('dot', 'png', df) -#plt = Image(pdot.create_png() -#display(plt) +# plt = Image(pdot.create_png() +# display(plt) diff --git a/extra/accel/ane/2_compile/hwx_parse.py b/extra/accel/ane/2_compile/hwx_parse.py index 88ba4ea22..a3269c608 100755 --- a/extra/accel/ane/2_compile/hwx_parse.py +++ b/extra/accel/ane/2_compile/hwx_parse.py @@ -3,138 +3,155 @@ import sys from hexdump import hexdump from macholib import MachO from tinygrad.helpers import getenv + + def get_macho(fn): - # mod to make the header okay - # MH_CIGAM_64 is good - dat = open(fn, "rb").read() - dat = b"\xcf\xfa\xed\xfe"+dat[4:] - from tempfile import NamedTemporaryFile - with NamedTemporaryFile(delete=False) as f: - f.write(dat) - f.close() - return MachO.MachO(f.name) + # mod to make the header okay + # MH_CIGAM_64 is good + dat = open(fn, "rb").read() + dat = b"\xcf\xfa\xed\xfe" + dat[4:] + from tempfile import NamedTemporaryFile + + with NamedTemporaryFile(delete=False) as f: + f.write(dat) + f.close() + return MachO.MachO(f.name) + a = get_macho("model.hwx.golden") # load commands for c in a.headers[0].commands: - print("command", c[0], c[1]) - if c[0].cmd == 4: - hexdump(c[2]) - pass - if c[0].cmd == 6: - print("name:", c[2].decode('utf-8')) - if c[0].cmd == 8: - print(c[2].decode('utf-8')) - if c[0].cmd == 25: - for section in c[2]: - print(section.segname.strip(b'\0'), section.sectname.strip(b'\0'), hex(section.addr), hex(section.size), "@", hex(c[1].fileoff)) - #print(dir(section)) - if c[1].filesize > 0: - if len(section.section_data) < 0x100: - hexdump(section.section_data) - else: - print("in file, not dumping 0x%x" % len(section.section_data)) + print("command", c[0], c[1]) + if c[0].cmd == 4: + hexdump(c[2]) + pass + if c[0].cmd == 6: + print("name:", c[2].decode("utf-8")) + if c[0].cmd == 8: + print(c[2].decode("utf-8")) + if c[0].cmd == 25: + for section in c[2]: + print( + section.segname.strip(b"\0"), + section.sectname.strip(b"\0"), + hex(section.addr), + hex(section.size), + "@", + hex(c[1].fileoff), + ) + # print(dir(section)) + if c[1].filesize > 0: + if len(section.section_data) < 0x100: + hexdump(section.section_data) + else: + print("in file, not dumping 0x%x" % len(section.section_data)) # this parser is wrong (fixed with 64-bit one) from macholib import SymbolTable + sym = SymbolTable.SymbolTable(a) syms = {} for l in sym.nlists: - print(l) - if l[0].n_value != 0: - syms[l[1]] = l[0].n_value + print(l) + if l[0].n_value != 0: + syms[l[1]] = l[0].n_value -for k,v in syms.items(): - print(k, hex(v)) +for k, v in syms.items(): + print(k, hex(v)) # **** document what we know *** from ane import ANE_Struct, ANE + ane = ANE() aneb = set() for typ, num, nam in ANE_Struct: - ltyp = {"u32": 4, "u16": 2, "u8": 1}[typ] - for l in range(num, num+ltyp): - aneb.add(l) + ltyp = {"u32": 4, "u16": 2, "u8": 1}[typ] + for l in range(num, num + ltyp): + aneb.add(l) # we understand these too for l in range(0x34, 0xF4): - aneb.add(l) + aneb.add(l) from termcolor import colored + + def compare(x, y): - ss = [] - ln = [] - ln2 = [] + ss = [] + ln = [] + ln2 = [] - ll = (max(len(x), len(y)) + 0xF)//0x10 * 0x10 + ll = (max(len(x), len(y)) + 0xF) // 0x10 * 0x10 - highlight = False - next_highlight = 0x2b - for i in range(ll+1): - if i == next_highlight: - highlight = True - if i < len(y): - next_highlight += y[i]+8 - else: - next_highlight = None - else: - highlight = False - a = "%02X" % x[i] if i < len(x) else "--", \ - "%02X" % y[i] if i < len(y) else "--" - def fj(x): - ss = [] - for i in range(0, 0x10, 4): - ss.append(' '.join(x[i:i+4])) - return ' '.join(ss) - - if i!=0 and i%0x10 == 0: - ss.append("%8X: " % (i-0x10)+fj(ln)+" | "+fj(ln2)+"\n") - ln = [] - ln2 = [] - if a[0] != a[1] and a[0] != "--" and a[1] != "--": - ln.append(colored(a[0], 'green')) - ln2.append(colored(a[1], 'red')) - else: - if highlight: - ln.append(colored(a[0], 'yellow')) - ln2.append(colored(a[1], 'yellow')) - else: - if i in aneb: - ln.append(colored(a[0], 'white')) - ln2.append(colored(a[1], 'white')) + highlight = False + next_highlight = 0x2B + for i in range(ll + 1): + if i == next_highlight: + highlight = True + if i < len(y): + next_highlight += y[i] + 8 + else: + next_highlight = None else: - ln.append(a[0]) - ln2.append(a[1]) - return ''.join(ss) + highlight = False + a = "%02X" % x[i] if i < len(x) else "--", "%02X" % y[i] if i < len(y) else "--" + + def fj(x): + ss = [] + for i in range(0, 0x10, 4): + ss.append(" ".join(x[i : i + 4])) + return " ".join(ss) + + if i != 0 and i % 0x10 == 0: + ss.append("%8X: " % (i - 0x10) + fj(ln) + " | " + fj(ln2) + "\n") + ln = [] + ln2 = [] + if a[0] != a[1] and a[0] != "--" and a[1] != "--": + ln.append(colored(a[0], "green")) + ln2.append(colored(a[1], "red")) + else: + if highlight: + ln.append(colored(a[0], "yellow")) + ln2.append(colored(a[1], "yellow")) + else: + if i in aneb: + ln.append(colored(a[0], "white")) + ln2.append(colored(a[1], "white")) + else: + ln.append(a[0]) + ln2.append(a[1]) + return "".join(ss) + import json + aneregs = dict(json.load(open("aneregs.json"))) g = get_macho("model.hwx.golden" if len(sys.argv) < 2 else sys.argv[1]) f1 = g.headers[0].commands[1][2][0].section_data f2 = a.headers[0].commands[1][2][0].section_data for i in range(0, len(f2), 0x300): - print("===== op %d =====" % (i//0x300)) - if len(f1) < 0x300: - c1, c2 = f1, f2[i:i+0x300] - else: - c1, c2 = f1[i:i+0x300], f2[i:i+0x300] - dbg1 = ane.debug(c1, 16) - dbg2 = ane.debug(c2, 16) - if getenv("PRINTALL"): - for k in dbg2: - if k in aneregs: - rr = aneregs[k] if k in aneregs else (-1,-1,-1) - print("0x%3x %d %2d" % tuple(rr), k, dbg1[k], "->", dbg2[k]) - else: - for k in dbg1: - if dbg1[k] != dbg2[k]: - rr = aneregs[k] if k in aneregs else (-1,-1,-1) - print("0x%3x %d %2d" % tuple(rr), k, dbg1[k], "->", dbg2[k]) + print("===== op %d =====" % (i // 0x300)) + if len(f1) < 0x300: + c1, c2 = f1, f2[i : i + 0x300] + else: + c1, c2 = f1[i : i + 0x300], f2[i : i + 0x300] + dbg1 = ane.debug(c1, 16) + dbg2 = ane.debug(c2, 16) + if getenv("PRINTALL"): + for k in dbg2: + if k in aneregs: + rr = aneregs[k] if k in aneregs else (-1, -1, -1) + print("0x%3x %d %2d" % tuple(rr), k, dbg1[k], "->", dbg2[k]) + else: + for k in dbg1: + if dbg1[k] != dbg2[k]: + rr = aneregs[k] if k in aneregs else (-1, -1, -1) + print("0x%3x %d %2d" % tuple(rr), k, dbg1[k], "->", dbg2[k]) - print(compare(c1, c2)) -#open("/tmp/data.section", "wb").write(f2) -#print(compare(open("model.hwx.golden", "rb").read(), open("model.hwx", "rb").read())) + print(compare(c1, c2)) +# open("/tmp/data.section", "wb").write(f2) +# print(compare(open("model.hwx.golden", "rb").read(), open("model.hwx", "rb").read())) diff --git a/extra/accel/ane/2_compile/struct_recover.py b/extra/accel/ane/2_compile/struct_recover.py index 86e859bc3..07fe49cff 100755 --- a/extra/accel/ane/2_compile/struct_recover.py +++ b/extra/accel/ane/2_compile/struct_recover.py @@ -1,36 +1,37 @@ #!/usr/bin/env python3 from ane import ANE + ane = ANE() lens = {} -dat = b"\xff"*0x300 +dat = b"\xff" * 0x300 ret = ane.debug(dat, 16) -for k,v in ret.items(): - found = None - for i in range(33): - #print(v, (1 << i) - 1) - if v == (1 << i) - 1: - found = i - break - #print(k, hex(v), found) - lens[k] = found +for k, v in ret.items(): + found = None + for i in range(33): + # print(v, (1 << i) - 1) + if v == (1 << i) - 1: + found = i + break + # print(k, hex(v), found) + lens[k] = found pos = [] -dat = b"\x00"*0x300 +dat = b"\x00" * 0x300 for i in range(0x300): - for j in range(8): - dat = b"\x00"*i - dat += bytes([1 << j]) - dat += b"\x00"*(0x300-len(dat)) - ret = ane.debug(dat, 16) - for k,v in ret.items(): - if v == 1: - print("0x%3x %d %2d" % (i, j, lens[k]), k) - pos.append((k, (i,j, lens[k]))) + for j in range(8): + dat = b"\x00" * i + dat += bytes([1 << j]) + dat += b"\x00" * (0x300 - len(dat)) + ret = ane.debug(dat, 16) + for k, v in ret.items(): + if v == 1: + print("0x%3x %d %2d" % (i, j, lens[k]), k) + pos.append((k, (i, j, lens[k]))) import json + jpos = json.dumps(pos, indent=2) with open("aneregs.json", "w") as f: - f.write(jpos) - + f.write(jpos) diff --git a/extra/accel/ane/amfi/new_patch.py b/extra/accel/ane/amfi/new_patch.py index 5fcd1d379..abbb3ad22 100644 --- a/extra/accel/ane/amfi/new_patch.py +++ b/extra/accel/ane/amfi/new_patch.py @@ -2,15 +2,18 @@ import ctypes from subprocess import check_output from hexdump import hexdump + def get_pid(name): - try: - output = check_output(["pgrep", name]) - return int(output) - except: - return None + try: + output = check_output(["pgrep", name]) + return int(output) + except: + return None + from ctypes.util import find_library -libc = ctypes.CDLL(find_library('c')) + +libc = ctypes.CDLL(find_library("c")) amfid_pid = get_pid("amfid") @@ -19,25 +22,28 @@ mytask = libc.mach_task_self() ret = libc.task_for_pid(mytask, ctypes.c_int(amfid_pid), ctypes.pointer(task)) print(amfid_pid, ret, task, mytask) -#myport = libc.mach_task_self() +# myport = libc.mach_task_self() + class vm_region_submap_short_info_data_64(ctypes.Structure): - _pack_ = 1 - _fields_ = [ - ("protection", ctypes.c_uint32), - ("max_protection", ctypes.c_uint32), - ("inheritance", ctypes.c_uint32), - ("offset", ctypes.c_ulonglong), - ("user_tag", ctypes.c_uint32), - ("ref_count", ctypes.c_uint32), - ("shadow_depth", ctypes.c_uint16), - ("external_pager", ctypes.c_byte), - ("share_mode", ctypes.c_byte), - ("is_submap", ctypes.c_uint32), - ("behavior", ctypes.c_uint32), - ("object_id", ctypes.c_uint32), - ("user_wired_count", ctypes.c_uint32), - ] + _pack_ = 1 + _fields_ = [ + ("protection", ctypes.c_uint32), + ("max_protection", ctypes.c_uint32), + ("inheritance", ctypes.c_uint32), + ("offset", ctypes.c_ulonglong), + ("user_tag", ctypes.c_uint32), + ("ref_count", ctypes.c_uint32), + ("shadow_depth", ctypes.c_uint16), + ("external_pager", ctypes.c_byte), + ("share_mode", ctypes.c_byte), + ("is_submap", ctypes.c_uint32), + ("behavior", ctypes.c_uint32), + ("object_id", ctypes.c_uint32), + ("user_wired_count", ctypes.c_uint32), + ] + + submap_info_size = ctypes.sizeof(vm_region_submap_short_info_data_64) // 4 address = ctypes.c_ulong(0) @@ -48,27 +54,37 @@ depth = 0 c_depth = ctypes.c_uint32(depth) for i in range(1): - ret = libc.mach_vm_region_recurse(task, - ctypes.pointer(address), ctypes.pointer(mapsize), - ctypes.pointer(c_depth), ctypes.pointer(sub_info), - ctypes.pointer(count)) - print("aslr", hex(ret), hex(address.value), mapsize, count, sub_info.protection) - #address.value += mapsize.value -#exit(0) + ret = libc.mach_vm_region_recurse( + task, + ctypes.pointer(address), + ctypes.pointer(mapsize), + ctypes.pointer(c_depth), + ctypes.pointer(sub_info), + ctypes.pointer(count), + ) + print("aslr", hex(ret), hex(address.value), mapsize, count, sub_info.protection) + # address.value += mapsize.value +# exit(0) -patch_address = address.value + 0x8e38 +patch_address = address.value + 0x8E38 patch = b"\x00\x00\x80\xd2" pdata = ctypes.c_void_p(0) data_cnt = ctypes.c_uint32(0) -ret = libc.mach_vm_read(task, ctypes.c_ulong(patch_address), 4, ctypes.pointer(pdata), ctypes.pointer(data_cnt)) +ret = libc.mach_vm_read( + task, + ctypes.c_ulong(patch_address), + 4, + ctypes.pointer(pdata), + ctypes.pointer(data_cnt), +) buf = ctypes.string_at(pdata.value, data_cnt.value) hexdump(buf) -#ret = libc.mach_vm_wire(mytask, task, patch_address, 4, 3) -#print(ret) -#exit(0) +# ret = libc.mach_vm_wire(mytask, task, patch_address, 4, 3) +# print(ret) +# exit(0) """ ret = libc.mach_vm_read(task, address, mapsize, ctypes.pointer(pdata), ctypes.pointer(data_cnt)) @@ -86,17 +102,17 @@ ret = libc.mach_vm_protect(task, ctypes.c_ulong(patch_address), 4, True, 3) print("protect", ret) longptr = ctypes.POINTER(ctypes.c_ulong) -#shellcodePtr = ctypes.cast(buf, longptr) -#ret = libc.mach_vm_write(task, address, shellcodePtr, len(buf)) -#print("write", ret) +# shellcodePtr = ctypes.cast(buf, longptr) +# ret = libc.mach_vm_write(task, address, shellcodePtr, len(buf)) +# print("write", ret) shellcodePtr = ctypes.cast(patch, longptr) ret = libc.mach_vm_write(task, ctypes.c_ulong(patch_address), shellcodePtr, len(buf)) print("write", ret) -#libc.mach_vm_write.argtypes = [ctypes.c_uint32, ctypes.c_ulong, longptr, ctypes.c_uint32] -#libc.mach_vm_write.restype = ctypes.c_uint32 -#ret = libc.mach_vm_write(task, ctypes.c_ulong(patch_address), shellcodePtr, len(patch)) +# libc.mach_vm_write.argtypes = [ctypes.c_uint32, ctypes.c_ulong, longptr, ctypes.c_uint32] +# libc.mach_vm_write.restype = ctypes.c_uint32 +# ret = libc.mach_vm_write(task, ctypes.c_ulong(patch_address), shellcodePtr, len(patch)) ret = libc.mach_vm_protect(task, ctypes.c_ulong(patch_address), 4, False, 5) -print("protect", ret) \ No newline at end of file +print("protect", ret) diff --git a/extra/accel/ane/lib/ane.py b/extra/accel/ane/lib/ane.py index 2e430c0f5..a1dda6d8e 100755 --- a/extra/accel/ane/lib/ane.py +++ b/extra/accel/ane/lib/ane.py @@ -6,217 +6,214 @@ import collections import numpy as np import faulthandler import struct + faulthandler.enable() basedir = Path(__file__).resolve().parent libane = None aneregs = None + + def init_libane(): - global libane, aneregs - libane = cdll.LoadLibrary((basedir / "libane.dylib").as_posix()) + global libane, aneregs + libane = cdll.LoadLibrary((basedir / "libane.dylib").as_posix()) - libane.ANE_Compile.argtypes = [c_char_p, c_int] - libane.ANE_Compile.restype = c_void_p + libane.ANE_Compile.argtypes = [c_char_p, c_int] + libane.ANE_Compile.restype = c_void_p - libane.ANE_TensorCreate.restype = c_void_p + libane.ANE_TensorCreate.restype = c_void_p - libane.ANE_TensorData.argtypes = [c_void_p] - libane.ANE_TensorData.restype = POINTER(c_uint16) + libane.ANE_TensorData.argtypes = [c_void_p] + libane.ANE_TensorData.restype = POINTER(c_uint16) - libane.ANE_Run.argtypes = [c_void_p]*4 - libane.ANE_Run.restype = c_int + libane.ANE_Run.argtypes = [c_void_p] * 4 + libane.ANE_Run.restype = c_int - #libane.ANE_RegDebug.restype = c_char_p + # libane.ANE_RegDebug.restype = c_char_p + + with open(basedir / "aneregs.json") as f: + aneregs = json.load(f) - with open(basedir / "aneregs.json") as f: - aneregs = json.load(f) ANE_Struct = [ -# aneTD.Header - ("u32", 0x1C, "NextCommandOffset"), - -# KernelDMASrc @ section @ 0x2C len 0xF4 - # reloc 0x2c-0x34?? = weights - # u32[16] 0x34-0x74 = 0x80 | 1 if used - # u32[16] 0x74-0xB4 = - # u32[16] 0xB4-0xF4 = - -# Common @ section @ 0x128 len 0x3C (conv) - ("u16", 0x128, "InputWidth"), - ("u16", 0x12A, "InputHeight"), - ("u16", 0x12C, "InputDepth"), - - ("u32", 0x130, "InputOutputType"), # (OutputType * 0x10) | InputType - # UInt8 = 0, Int8 = 1, Float16 = 2 - - ("u32", 0x134, "InputChannels"), - ("u32", 0x138, "OutputChannels"), - - ("u16", 0x13C, "OutputWidth"), - ("u16", 0x13E, "OutputHeight"), - ("u16", 0x140, "OutputDepth"), - - ("u16", 0x144, "KernelSize"), # 0xa000 | (KernelHeight * 0x20) | KernelWidth - ("u16", 0x146, "Padding"), # 0x5000 | (PadTop * 0x40) | (PadLeft * 2) - - ("u16", 0x14C, "BatchSize"), - -# TileDMASrc @ section @ 0x16C len 0x6C (input) - # reloc 0x16c-0x174 = image - ("u32", 0x178, "InputRowStride"), - ("u32", 0x17C, "InputPlaneStride"), - ("u32", 0x180, "InputDepthStride"), - ("u32", 0x184, "InputBatchStride"), - - ("u8", 0x1A7, "InputInterleave"), - -# L2 @ section @ 0x1E0 len 0x44 - # [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214] = number of engines - # [0x1f0, 0x1f4, 0x1f8, 0x214] = engines for inconv? - # [0x21c, 0x220, 0x224] = engines for outconv? - -# NE @ section @ 0x22c len 0xC (scaling) - ("u16", 0x230, "BiasScalar"), - ("u16", 0x232, "ScaleScalar"), - -# section @ 0x240 len 0x10 - ("u16", 0x246, "NeuronType"), # 0x10 = copy, 0x11 = ReLU, 0x12 = custom - ("u32", 0x250, "PostScale"), - -# TileDMADst @ section @ 0x258 len 0x18 - -# HandleTileDmaDstConfig - # 0x258 -- *(uint *)(this + 0x334) = *(uint *)(this + 0x334) & 0xfffffc3f | 0xc0; - # (GetCacheHintRegisterValue & 0xf) << 6; - ("u32", 0x25C, "OutputOffset"), # offset into output buffer to write at? - - # 0x260 -- *(uint *)(this + 0x33c) = *(uint *)(this + 0x33c) & 0x3f | (int)uVar10 << 6; - ("u32", 0x260, "OutputRowStride"), - ("u32", 0x264, "OutputPlaneStride"), - ("u32", 0x268, "OutputDepthStride"), - ("u32", 0x26C, "OutputBatchStride"), - - # 0x270 -- *(uint *)(this + 0x34c) = *(uint *)(this + 0x34c) & 0xf0ffffff | 0x1000000; - # uVar6 = *(uint *)(this + 0x34c) & 0xffffcfcc | 0x2031; - # (ZinTensorDescriptorDmaInterleave & 0xf) << 0x18; - ("u8", 0x273, "OutputInterleave"), # i also have this at 0x211? + # aneTD.Header + ("u32", 0x1C, "NextCommandOffset"), + # KernelDMASrc @ section @ 0x2C len 0xF4 + # reloc 0x2c-0x34?? = weights + # u32[16] 0x34-0x74 = 0x80 | 1 if used + # u32[16] 0x74-0xB4 = + # u32[16] 0xB4-0xF4 = + # Common @ section @ 0x128 len 0x3C (conv) + ("u16", 0x128, "InputWidth"), + ("u16", 0x12A, "InputHeight"), + ("u16", 0x12C, "InputDepth"), + ("u32", 0x130, "InputOutputType"), # (OutputType * 0x10) | InputType + # UInt8 = 0, Int8 = 1, Float16 = 2 + ("u32", 0x134, "InputChannels"), + ("u32", 0x138, "OutputChannels"), + ("u16", 0x13C, "OutputWidth"), + ("u16", 0x13E, "OutputHeight"), + ("u16", 0x140, "OutputDepth"), + ("u16", 0x144, "KernelSize"), # 0xa000 | (KernelHeight * 0x20) | KernelWidth + ("u16", 0x146, "Padding"), # 0x5000 | (PadTop * 0x40) | (PadLeft * 2) + ("u16", 0x14C, "BatchSize"), + # TileDMASrc @ section @ 0x16C len 0x6C (input) + # reloc 0x16c-0x174 = image + ("u32", 0x178, "InputRowStride"), + ("u32", 0x17C, "InputPlaneStride"), + ("u32", 0x180, "InputDepthStride"), + ("u32", 0x184, "InputBatchStride"), + ("u8", 0x1A7, "InputInterleave"), + # L2 @ section @ 0x1E0 len 0x44 + # [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214] = number of engines + # [0x1f0, 0x1f4, 0x1f8, 0x214] = engines for inconv? + # [0x21c, 0x220, 0x224] = engines for outconv? + # NE @ section @ 0x22c len 0xC (scaling) + ("u16", 0x230, "BiasScalar"), + ("u16", 0x232, "ScaleScalar"), + # section @ 0x240 len 0x10 + ("u16", 0x246, "NeuronType"), # 0x10 = copy, 0x11 = ReLU, 0x12 = custom + ("u32", 0x250, "PostScale"), + # TileDMADst @ section @ 0x258 len 0x18 + # HandleTileDmaDstConfig + # 0x258 -- *(uint *)(this + 0x334) = *(uint *)(this + 0x334) & 0xfffffc3f | 0xc0; + # (GetCacheHintRegisterValue & 0xf) << 6; + ("u32", 0x25C, "OutputOffset"), # offset into output buffer to write at? + # 0x260 -- *(uint *)(this + 0x33c) = *(uint *)(this + 0x33c) & 0x3f | (int)uVar10 << 6; + ("u32", 0x260, "OutputRowStride"), + ("u32", 0x264, "OutputPlaneStride"), + ("u32", 0x268, "OutputDepthStride"), + ("u32", 0x26C, "OutputBatchStride"), + # 0x270 -- *(uint *)(this + 0x34c) = *(uint *)(this + 0x34c) & 0xf0ffffff | 0x1000000; + # uVar6 = *(uint *)(this + 0x34c) & 0xffffcfcc | 0x2031; + # (ZinTensorDescriptorDmaInterleave & 0xf) << 0x18; + ("u8", 0x273, "OutputInterleave"), # i also have this at 0x211? ] ANE_Struct_Dict = {} for typ, num, nam in ANE_Struct: - styp = {"u32": "I", "u16": "H", "u8": "B"}[typ] - ANE_Struct_Dict[nam] = (styp, num) + styp = {"u32": "I", "u16": "H", "u8": "B"}[typ] + ANE_Struct_Dict[nam] = (styp, num) + class ANETensor: - def __init__(self, *shape): - self.shape = shape - self.dtype = np.float16 - self.sz = int(np.prod(shape)) - assert(self.sz <= 0x4000) - self.tt = libane.ANE_TensorCreate(self.sz, 1) - assert(self.tt is not None) + def __init__(self, *shape): + self.shape = shape + self.dtype = np.float16 + self.sz = int(np.prod(shape)) + assert self.sz <= 0x4000 + self.tt = libane.ANE_TensorCreate(self.sz, 1) + assert self.tt is not None + + def data(self): + data = libane.ANE_TensorData(self.tt) + assert data is not None + # print(hex(addressof(data.contents))) + buf = np.ctypeslib.as_array(data, shape=(self.sz,)) + ret = np.frombuffer(buf, dtype=self.dtype) + # print(ret.data) + return ret - def data(self): - data = libane.ANE_TensorData(self.tt) - assert(data is not None) - #print(hex(addressof(data.contents))) - buf = np.ctypeslib.as_array(data, shape=(self.sz,)) - ret = np.frombuffer(buf, dtype=self.dtype) - #print(ret.data) - return ret class ANE: - def __init__(self): - init_libane() - libane.ANE_Open() + def __init__(self): + init_libane() + libane.ANE_Open() - def compile(self, dat): - ret = libane.ANE_Compile(create_string_buffer(dat), len(dat)) - assert(ret is not None) - return ret + def compile(self, dat): + ret = libane.ANE_Compile(create_string_buffer(dat), len(dat)) + assert ret is not None + return ret - def run(self, prog, tin, tout, tweights=None): - libane.ANE_Run(prog, tin.tt, tout.tt, tweights.tt if tweights is not None else 0) + def run(self, prog, tin, tout, tweights=None): + libane.ANE_Run( + prog, tin.tt, tout.tt, tweights.tt if tweights is not None else 0 + ) - def tensor(self, shape): - return ANETensor(shape) + def tensor(self, shape): + return ANETensor(shape) - def unpack(self, dat): - dat = struct.unpack("Q"*(len(dat)//8), dat) - ret = {} - for k,v in aneregs: - by,bi,sz = v - bi += (by%8)*8 - by //= 8 - rv = (dat[by] >> bi) & ((1 << sz)-1) - ret[k] = rv - return ret + def unpack(self, dat): + dat = struct.unpack("Q" * (len(dat) // 8), dat) + ret = {} + for k, v in aneregs: + by, bi, sz = v + bi += (by % 8) * 8 + by //= 8 + rv = (dat[by] >> bi) & ((1 << sz) - 1) + ret[k] = rv + return ret - def pack(self, pk, dat): - dat = list(struct.unpack("Q"*(len(dat)//8), dat)) - for k,v in aneregs: - by,bi,sz = v - bi += (by%8)*8 - by //= 8 - dat[by] &= ~(((1 << sz)-1) << bi) - dat[by] |= pk[k] << bi - dat = struct.pack("Q"*len(dat), *dat) - return dat + def pack(self, pk, dat): + dat = list(struct.unpack("Q" * (len(dat) // 8), dat)) + for k, v in aneregs: + by, bi, sz = v + bi += (by % 8) * 8 + by //= 8 + dat[by] &= ~(((1 << sz) - 1) << bi) + dat[by] |= pk[k] << bi + dat = struct.pack("Q" * len(dat), *dat) + return dat - def debug(self, dat, mems=0): - add = [0x30, 0x1d4, 0x220, 0x29c, 0x2f0, 0x30c, 0x32c] - lens = [244, 60, 108, 68, 12, 16, 24] - ptr = 0x2b - ddat = dat[0:0x28] - for a, pm in zip(add, lens): - #assert pm == dat[ptr] - ddat += b"\x00" * (a-len(ddat)) - ddat += dat[ptr+1:ptr+1+pm+4] - ptr += pm+8 - ddat += b"\x00" * 0x100 - ret = collections.OrderedDict() - for ln in libane.ANE_RegDebug(0, create_string_buffer(ddat), mems).decode('utf-8').strip().split("\n"): - lnn = ln.split(" = ") - if len(lnn) == 2: - ret[lnn[0]] = int(lnn[1]) - return ret + def debug(self, dat, mems=0): + add = [0x30, 0x1D4, 0x220, 0x29C, 0x2F0, 0x30C, 0x32C] + lens = [244, 60, 108, 68, 12, 16, 24] + ptr = 0x2B + ddat = dat[0:0x28] + for a, pm in zip(add, lens): + # assert pm == dat[ptr] + ddat += b"\x00" * (a - len(ddat)) + ddat += dat[ptr + 1 : ptr + 1 + pm + 4] + ptr += pm + 8 + ddat += b"\x00" * 0x100 + ret = collections.OrderedDict() + for ln in ( + libane.ANE_RegDebug(0, create_string_buffer(ddat), mems) + .decode("utf-8") + .strip() + .split("\n") + ): + lnn = ln.split(" = ") + if len(lnn) == 2: + ret[lnn[0]] = int(lnn[1]) + return ret - def filln(self, dat, nvdict, base=0x4000): - for n,v in nvdict.items(): - styp, num = ANE_Struct_Dict[n] - dat = self.fill(dat, [num], styp, v) - return dat + def filln(self, dat, nvdict, base=0x4000): + for n, v in nvdict.items(): + styp, num = ANE_Struct_Dict[n] + dat = self.fill(dat, [num], styp, v) + return dat + + def fill(self, dat, addrs, type, val, base=0x4000): + x = struct.pack(type, val) + for a in addrs: + dat[base + a : base + a + len(x)] = x + return dat - def fill(self, dat, addrs, type, val, base=0x4000): - x = struct.pack(type, val) - for a in addrs: - dat[base+a:base+a+len(x)] = x - return dat if __name__ == "__main__": - ane = ANE() + ane = ANE() - tin = ANETensor(16) - tout = ANETensor(16) + tin = ANETensor(16) + tout = ANETensor(16) - tind = tin.data() - toutd = tout.data() + tind = tin.data() + toutd = tout.data() - tind[0:4] = [-1,1,-2,2] - print("** before **") - print(tind) - print(toutd) + tind[0:4] = [-1, 1, -2, 2] + print("** before **") + print(tind) + print(toutd) - dat = open("../ops/relu.hwx", "rb").read() - md = dat[0x4000:0x4300] - dd = ane.unpack(md) - mdf = ane.pack(dd, md) - assert(md == mdf) - - comp = ane.compile(dat) - ret = ane.run(comp, tin, tout) - print("** after **") - print(tind) - print(toutd) + dat = open("../ops/relu.hwx", "rb").read() + md = dat[0x4000:0x4300] + dd = ane.unpack(md) + mdf = ane.pack(dd, md) + assert md == mdf + comp = ane.compile(dat) + ret = ane.run(comp, tin, tout) + print("** after **") + print(tind) + print(toutd) diff --git a/extra/accel/ane/lib/testconv.py b/extra/accel/ane/lib/testconv.py index 3b8542d58..bef5808a9 100755 --- a/extra/accel/ane/lib/testconv.py +++ b/extra/accel/ane/lib/testconv.py @@ -2,63 +2,64 @@ import time from ane import ANE, ANETensor + def benchmark(ane): - tin = ANETensor(512*0x20) - tout = ANETensor(512*0x20) - dat = open("../ops/gemm.hwx", "rb").read() - for k,v in ane.debug(dat[0x4000:0x4300], 16).items(): - print(k,v) - comp = ane.compile(dat) + tin = ANETensor(512 * 0x20) + tout = ANETensor(512 * 0x20) + dat = open("../ops/gemm.hwx", "rb").read() + for k, v in ane.debug(dat[0x4000:0x4300], 16).items(): + print(k, v) + comp = ane.compile(dat) - st = time.time() - for i in range(1000): - ret = ane.run(comp, tin, tout) - et = time.time() - ts = (et-st) - ops = 1000*512*512*2 + st = time.time() + for i in range(1000): + ret = ane.run(comp, tin, tout) + et = time.time() + ts = et - st + ops = 1000 * 512 * 512 * 2 - print("%.2f ms, %.2f gigaops/sec" % (ts*1000, ops*1e-9/ts)) + print("%.2f ms, %.2f gigaops/sec" % (ts * 1000, ops * 1e-9 / ts)) if __name__ == "__main__": - ane = ANE() + ane = ANE() - # 0x20 per row - tin = ANETensor(0x60) - tout = ANETensor(0x60) - tw = ANETensor(0x60) + # 0x20 per row + tin = ANETensor(0x60) + tout = ANETensor(0x60) + tw = ANETensor(0x60) - tind = tin.data() - toutd = tout.data() - twd = tw.data() + tind = tin.data() + toutd = tout.data() + twd = tw.data() - #tind[0:4] = [-1,1,-2,2] - tind[0] = 1 - tind[0x20] = -2 - tind[0x40] = 3 + # tind[0:4] = [-1,1,-2,2] + tind[0] = 1 + tind[0x20] = -2 + tind[0x40] = 3 - # toutd[0] = \ - # tind[0] * twd[0] + \ - # tind[0x20] + twd[1] + \ - # tind[0x40] + twd[2] + # toutd[0] = \ + # tind[0] * twd[0] + \ + # tind[0x20] + twd[1] + \ + # tind[0x40] + twd[2] - twd[0] = 4 - twd[1] = 0x100 + twd[0] = 4 + twd[1] = 0x100 - twd[0x20] = 5 - twd[0x21] = 5 - twd[0x22] = 5 + twd[0x20] = 5 + twd[0x21] = 5 + twd[0x22] = 5 - twd[0x40] = 12 + twd[0x40] = 12 - print("** before **") - print(tind) - print(toutd) + print("** before **") + print(tind) + print(toutd) - #benchmark(ane) - #exit(0) + # benchmark(ane) + # exit(0) - """ + """ dat = list(open("../ops/sum.hwx", "rb").read()) dat = bytes(dat) for k,v in ane.debug(dat[0x4000:0x4300], 16).items(): @@ -67,25 +68,25 @@ if __name__ == "__main__": ret = ane.run(comp, tin, tout, tw) """ - datb = open("../ops/sum.hwx", "rb").read() - dat = open("../ops/conv.hwx", "rb").read() - dd = ane.unpack(dat[0x4000:0x4300]) - # use the 3rd arg as the weights - dd["aneTD.Header[9].KBase0"] = 6 - dd["aneRegs.NE.PostScale.PostScale"] = 0x3c00 - #dd["aneRegs.L2.L2Cfg.InputReLU"] = 1 - #dd["aneRegs.NE.MACCfg.NonlinearMode"] = 1 - #dd["aneRegs.TileDMADst.Fmt.MemFmt"] = 0 - #dd["aneRegs.L2.ResultBase.Addr"] = 0 - #dd["aneRegs.Common.ChCfg.InFmt"] = 1 - #dd["aneRegs.TileDMADst.Fmt.ZeroPadFirst"] = 0 - #dd["aneRegs.TileDMADst.DMAConfig.En"] = 0 - for k,v in dd.items(): - print(k,v) - dat = datb[:0x4000] + ane.pack(dd, dat[0x4000:0x4300]) + datb[0x4300:] - comp = ane.compile(dat) - ret = ane.run(comp, tin, tout, tw) + datb = open("../ops/sum.hwx", "rb").read() + dat = open("../ops/conv.hwx", "rb").read() + dd = ane.unpack(dat[0x4000:0x4300]) + # use the 3rd arg as the weights + dd["aneTD.Header[9].KBase0"] = 6 + dd["aneRegs.NE.PostScale.PostScale"] = 0x3C00 + # dd["aneRegs.L2.L2Cfg.InputReLU"] = 1 + # dd["aneRegs.NE.MACCfg.NonlinearMode"] = 1 + # dd["aneRegs.TileDMADst.Fmt.MemFmt"] = 0 + # dd["aneRegs.L2.ResultBase.Addr"] = 0 + # dd["aneRegs.Common.ChCfg.InFmt"] = 1 + # dd["aneRegs.TileDMADst.Fmt.ZeroPadFirst"] = 0 + # dd["aneRegs.TileDMADst.DMAConfig.En"] = 0 + for k, v in dd.items(): + print(k, v) + dat = datb[:0x4000] + ane.pack(dd, dat[0x4000:0x4300]) + datb[0x4300:] + comp = ane.compile(dat) + ret = ane.run(comp, tin, tout, tw) - print("** after **") - print(tind) - print(toutd) + print("** after **") + print(tind) + print(toutd) diff --git a/extra/accel/ane/tinygrad/ops_ane.py b/extra/accel/ane/tinygrad/ops_ane.py index b9b792d9b..f5d3210f0 100644 --- a/extra/accel/ane/tinygrad/ops_ane.py +++ b/extra/accel/ane/tinygrad/ops_ane.py @@ -1,39 +1,52 @@ from functools import lru_cache from .tensor import Device, Function, register + @lru_cache def compile_wrapper(ane, dat): - return ane.compile(dat) + return ane.compile(dat) + def roundup(x, v): - return x + (v-x)%v + return x + (v - x) % v + @lru_cache def compile_relu(ane, sz): - dat = list(open("accel/ane/ops/relu.hwx", "rb").read()) - # TODO: make this all nice and once - # number of engines? (max 0x100) - l2_stride = max(0x100, roundup(sz*2, 0x10)) - # 0x1ec = L2.SourceChannelStride.Stride, 0x1f0 = L2.SourceRowStride.Stride - # 0x1f4, 0x1f8? - # 0x214 = L2.ResultBase.Addr - dat = ane.fill(dat, [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214], "I", l2_stride) - stride = roundup(sz*2, 0x40) - dat = ane.filln(dat, { - "NeuronType": 0x11, # 0x10 makes this a copy, 0x11 = ReLU, 0x12 = crash - "InputWidth": sz, "OutputWidth": sz, - "InputRowStride": stride, "InputPlaneStride": stride, "InputDepthStride": stride, - "OutputRowStride": stride, "OutputPlaneStride": stride, "OutputDepthStride": stride, - }) - return compile_wrapper(ane, bytes(dat)) + dat = list(open("accel/ane/ops/relu.hwx", "rb").read()) + # TODO: make this all nice and once + # number of engines? (max 0x100) + l2_stride = max(0x100, roundup(sz * 2, 0x10)) + # 0x1ec = L2.SourceChannelStride.Stride, 0x1f0 = L2.SourceRowStride.Stride + # 0x1f4, 0x1f8? + # 0x214 = L2.ResultBase.Addr + dat = ane.fill(dat, [0x1EC, 0x1F0, 0x1F4, 0x1F8, 0x214], "I", l2_stride) + stride = roundup(sz * 2, 0x40) + dat = ane.filln( + dat, + { + "NeuronType": 0x11, # 0x10 makes this a copy, 0x11 = ReLU, 0x12 = crash + "InputWidth": sz, + "OutputWidth": sz, + "InputRowStride": stride, + "InputPlaneStride": stride, + "InputDepthStride": stride, + "OutputRowStride": stride, + "OutputPlaneStride": stride, + "OutputDepthStride": stride, + }, + ) + return compile_wrapper(ane, bytes(dat)) + class ReLU(Function): - def forward(ctx, input): - ret = ctx.ane.tensor(input.shape) - ctx.ane.run(compile_relu(ctx.ane, input.sz), input, ret) - return ret + def forward(ctx, input): + ret = ctx.ane.tensor(input.shape) + ctx.ane.run(compile_relu(ctx.ane, input.sz), input, ret) + return ret - def backward(ctx, grad_output): - return 0 + def backward(ctx, grad_output): + return 0 -register('relu', ReLU, device=Device.ANE) + +register("relu", ReLU, device=Device.ANE) diff --git a/extra/accel/intel/benchmark_matmul.py b/extra/accel/intel/benchmark_matmul.py index 5999039de..169939671 100644 --- a/extra/accel/intel/benchmark_matmul.py +++ b/extra/accel/intel/benchmark_matmul.py @@ -31,19 +31,20 @@ for x in out.values(): x.realize() """ from openvino.runtime import Core + core = Core() devices = core.available_devices for device in devices: - device_name = core.get_property(device, "FULL_DEVICE_NAME") - print(f"{device}: {device_name}") + device_name = core.get_property(device, "FULL_DEVICE_NAME") + print(f"{device}: {device_name}") model = core.read_model(onnx_path) -compiled_model = core.compile_model(model, device_name='GPU.0') +compiled_model = core.compile_model(model, device_name="GPU.0") print(compiled_model) ireq = compiled_model.create_infer_request() for model_input in compiled_model.inputs: - tensor = ireq.get_tensor(model_input) - tensor.data[:] = 2 - print(tensor) + tensor = ireq.get_tensor(model_input) + tensor.data[:] = 2 + print(tensor) print("request") ireq.infer() ireq.infer() @@ -51,7 +52,7 @@ print("did one") REPS = 20 st = time.perf_counter() -for i in range(REPS): ireq.infer() +for i in range(REPS): + ireq.infer() et = time.perf_counter() - st print(f"{et*1000:.2f} ms {(CNT*N*N*N*REPS*2/et)*1e-9:.2f} GFLOPS") - diff --git a/extra/archprobe.py b/extra/archprobe.py index fc2d3d207..cda2ee0e7 100644 --- a/extra/archprobe.py +++ b/extra/archprobe.py @@ -7,11 +7,14 @@ from tqdm import trange, tqdm from matplotlib import pyplot as plt tests = {} + + def register_test(fxn): - tests[fxn.__name__] = fxn + tests[fxn.__name__] = fxn + def warp_size2(nthread): - prg = """__kernel void warp_size2( + prg = """__kernel void warp_size2( __global float* src, __global int* dst, const int niter, @@ -24,20 +27,40 @@ def warp_size2(nthread): } dst[get_local_id(0)] = drain; }""" - src_buf = CLBuffer(1, dtypes.float32) - dst_buf = CLBuffer(1, dtypes.int32) - cl = CLProgram("warp_size2", prg, argdtypes=[None, None, np.int32, np.int32]) - return min([cl([nthread, 1024, 1], [nthread, 1, 1], src_buf, dst_buf, 10, 3, wait=True) for _ in range(5)])*1e9 + src_buf = CLBuffer(1, dtypes.float32) + dst_buf = CLBuffer(1, dtypes.int32) + cl = CLProgram("warp_size2", prg, argdtypes=[None, None, np.int32, np.int32]) + return ( + min( + [ + cl( + [nthread, 1024, 1], + [nthread, 1, 1], + src_buf, + dst_buf, + 10, + 3, + wait=True, + ) + for _ in range(5) + ] + ) + * 1e9 + ) + @register_test def test_warp_size(): - return [(nthread, warp_size2(nthread)) for nthread in trange(1,256)] + return [(nthread, warp_size2(nthread)) for nthread in trange(1, 256)] + def reg_count(nthread, ngrp, nreg): - reg_declr = ''.join([f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)]) - reg_comp = ''.join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)]) - reg_reduce = ''.join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)]) - prg = f"""__kernel void reg_count( + reg_declr = "".join( + [f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)] + ) + reg_comp = "".join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)]) + reg_reduce = "".join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)]) + prg = f"""__kernel void reg_count( __global float* out_buf, __private const int niter ) {{ @@ -49,18 +72,31 @@ def reg_count(nthread, ngrp, nreg): i = i >> 31; {reg_reduce} }}""" - out_buf = CLBuffer(1, dtypes.float32) - cl = CLProgram("reg_count", prg, argdtypes=[None, np.int32]) - return min([cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 20, wait=True) for _ in range(10)])*1e9 + out_buf = CLBuffer(1, dtypes.float32) + cl = CLProgram("reg_count", prg, argdtypes=[None, np.int32]) + return ( + min( + [ + cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 20, wait=True) + for _ in range(10) + ] + ) + * 1e9 + ) + @register_test def test_reg_count(nthread=1, ngrp=1): - base = reg_count(nthread, ngrp, 1) - return [(nreg, (reg_count(nthread, ngrp, nreg)-base)/nreg) for nreg in trange(4, 513, 4)] + base = reg_count(nthread, ngrp, 1) + return [ + (nreg, (reg_count(nthread, ngrp, nreg) - base) / nreg) + for nreg in trange(4, 513, 4) + ] + def buf_cache_hierarchy_pchase(ndata, stride=1, NCOMP=1, steps=65536): - ndata //= NCOMP*4 # ptr size - prg = f"""__kernel void buf_cache_hierarchy_pchase( + ndata //= NCOMP * 4 # ptr size + prg = f"""__kernel void buf_cache_hierarchy_pchase( __global int{str(NCOMP) if NCOMP > 1 else ''}* src, __global int* dst, const int niter @@ -71,49 +107,76 @@ def buf_cache_hierarchy_pchase(ndata, stride=1, NCOMP=1, steps=65536): }} *dst = idx; }}""" - idx_buf = np.zeros(ndata*NCOMP, dtype=np.int32) - for i in range(ndata): idx_buf[i*NCOMP] = (i + stride) % ndata - in_buf = CLBuffer.fromCPU(idx_buf) - out_buf = CLBuffer(1, dtypes.int32) - cl = CLProgram("buf_cache_hierarchy_pchase", prg, argdtypes=[None, None, np.int32]) - return min([cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, steps, wait=True)/steps for _ in range(5)])*1e9 + idx_buf = np.zeros(ndata * NCOMP, dtype=np.int32) + for i in range(ndata): + idx_buf[i * NCOMP] = (i + stride) % ndata + in_buf = CLBuffer.fromCPU(idx_buf) + out_buf = CLBuffer(1, dtypes.int32) + cl = CLProgram("buf_cache_hierarchy_pchase", prg, argdtypes=[None, None, np.int32]) + return ( + min( + [ + cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, steps, wait=True) / steps + for _ in range(5) + ] + ) + * 1e9 + ) + @register_test def test_memory_latency(): - # requires cacheline < 16 - szs = [int(1.3**x) for x in range(20, 70)] - return [(ndata, buf_cache_hierarchy_pchase(ndata, NCOMP=16, steps=128*1024)) for ndata in tqdm(szs)] + # requires cacheline < 16 + szs = [int(1.3**x) for x in range(20, 70)] + return [ + (ndata, buf_cache_hierarchy_pchase(ndata, NCOMP=16, steps=128 * 1024)) + for ndata in tqdm(szs) + ] + @register_test def test_cacheline_size(): - # TODO: this buffer must be at least 2x the L1 cache for this test to work - return [(stride, buf_cache_hierarchy_pchase(4*65536, stride, steps=65536)) for stride in trange(1,64)] + # TODO: this buffer must be at least 2x the L1 cache for this test to work + return [ + (stride, buf_cache_hierarchy_pchase(4 * 65536, stride, steps=65536)) + for stride in trange(1, 64) + ] + def cl_read(sz, niter=1): - prg = f"""__kernel void copy( + prg = f"""__kernel void copy( __global float4* src, __global float* dst) {{ int gid = get_global_id(0); if (src[gid].x == 99+get_global_id(1)) *dst = 1; }}""" - in_buf = CLBuffer(sz//4, dtypes.float32) - out_buf = CLBuffer(1, dtypes.float32) - cl = CLProgram("copy", prg) - # NOTE: if nay of the niters form a local group, this is wrong - return min([cl([sz//16, niter, 1], [1, 1, 1], in_buf, out_buf, wait=True) for _ in range(10)])*1e9 + in_buf = CLBuffer(sz // 4, dtypes.float32) + out_buf = CLBuffer(1, dtypes.float32) + cl = CLProgram("copy", prg) + # NOTE: if nay of the niters form a local group, this is wrong + return ( + min( + [ + cl([sz // 16, niter, 1], [1, 1, 1], in_buf, out_buf, wait=True) + for _ in range(10) + ] + ) + * 1e9 + ) + @register_test def test_read_bandwidth(): - szs = list(range(128*1024, 20*1024*1024, 128*1024)) - NITER = 8 - base = cl_read(16, niter=NITER) - return [(sz, (sz*NITER)/(cl_read(sz, niter=NITER)-base)) for sz in tqdm(szs)] + szs = list(range(128 * 1024, 20 * 1024 * 1024, 128 * 1024)) + NITER = 8 + base = cl_read(16, niter=NITER) + return [(sz, (sz * NITER) / (cl_read(sz, niter=NITER) - base)) for sz in tqdm(szs)] def gflops(niter=4, nroll=4, ngroups=4096): - NCOMP = 8 - prg = f"""__kernel void gflops( + NCOMP = 8 + prg = f"""__kernel void gflops( __global float* out_buf ) {{ float{NCOMP} x = (float{NCOMP})({",".join(f"get_local_id(0)+{i}" for i in range(NCOMP))}); @@ -125,30 +188,37 @@ def gflops(niter=4, nroll=4, ngroups=4096): out_buf[get_global_id(0) >> 31] = {'+'.join(f"y.s{'0123456789abcdef'[i]}" for i in range(NCOMP))}; }}""" - out_buf = CLBuffer(1, dtypes.float32) - cl = CLProgram("gflops", prg, options="-cl-mad-enable -cl-fast-relaxed-math") - FLOPS = NCOMP*2*2 * niter * nroll * ngroups * 32 - # NOTE: if nay of the niters form a local group, this is wrong - return FLOPS/(min([cl([32, ngroups, 1], [32, 1, 1], out_buf, wait=True) for _ in range(10)])*1e9) + out_buf = CLBuffer(1, dtypes.float32) + cl = CLProgram("gflops", prg, options="-cl-mad-enable -cl-fast-relaxed-math") + FLOPS = NCOMP * 2 * 2 * niter * nroll * ngroups * 32 + # NOTE: if nay of the niters form a local group, this is wrong + return FLOPS / ( + min([cl([32, ngroups, 1], [32, 1, 1], out_buf, wait=True) for _ in range(10)]) + * 1e9 + ) + @register_test def test_gflops(): - return [(niter, gflops(niter=niter, nroll=32)) for niter in trange(1, 32, 1)] + return [(niter, gflops(niter=niter, nroll=32)) for niter in trange(1, 32, 1)] + if __name__ == "__main__": - cache = {} - #cache = pickle.load(open("/tmp/cache.pkl", "rb")) - #tests = {"test_cacheline_size": tests["test_cacheline_size"]} - plt.figure(figsize=(16, 9)) - for i,(k,test) in enumerate(tests.items()): - print(f"running {k}") - plt.subplot(2, (len(tests)+1)//2, i+1) - plt.title(k) - if k == "test_memory_latency": plt.xscale('log') - if k not in cache: cache[k] = test() - plt.plot(*zip(*cache[k])) - #pickle.dump(cache, open("/tmp/cache.pkl", "wb")) + cache = {} + # cache = pickle.load(open("/tmp/cache.pkl", "rb")) + # tests = {"test_cacheline_size": tests["test_cacheline_size"]} + plt.figure(figsize=(16, 9)) + for i, (k, test) in enumerate(tests.items()): + print(f"running {k}") + plt.subplot(2, (len(tests) + 1) // 2, i + 1) + plt.title(k) + if k == "test_memory_latency": + plt.xscale("log") + if k not in cache: + cache[k] = test() + plt.plot(*zip(*cache[k])) + # pickle.dump(cache, open("/tmp/cache.pkl", "wb")) - plt.tight_layout(pad=0.5) - plt.savefig("/tmp/results.png") - plt.show() + plt.tight_layout(pad=0.5) + plt.savefig("/tmp/results.png") + plt.show() diff --git a/extra/assembly/assembly.py b/extra/assembly/assembly.py index f6f0289a3..1aa400e78 100644 --- a/extra/assembly/assembly.py +++ b/extra/assembly/assembly.py @@ -1,188 +1,427 @@ -from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast +from typing import ( + Tuple, + List, + NamedTuple, + Any, + Dict, + Optional, + Union, + DefaultDict, + cast, +) from tinygrad.codegen.linearizer import UOps, MemOp, UOp from tinygrad.ops import BinaryOps, UnaryOps from tinygrad.helpers import DType, dtypes, DEBUG -from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode +from tinygrad.shape.symbolic import ( + Variable, + NumNode, + MulNode, + DivNode, + ModNode, + LtNode, + SumNode, + AndNode, +) import functools import math from collections import defaultdict -_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h', - dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'} +_type_to_letter = { + dtypes.float32: "f", + dtypes.bool: "p", + dtypes.int32: "i", + dtypes.int64: "a", + dtypes.uint32: "u", + dtypes.uint64: "b", + dtypes.float.vec(4): "x", + dtypes.uint8: "uc", + dtypes.float16: "h", + dtypes.int8: "c", + dtypes.uint16: "us", + dtypes.float64: "d", +} + class Register(NamedTuple): - nm:str - dtype:DType - scalar:bool - off:Optional[int] = None - def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}" - def subregs(self): - if self.dtype == dtypes.float.vec(4): - return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)] - return [] + nm: str + dtype: DType + scalar: bool + off: Optional[int] = None + + def __repr__(self): + return self.nm if self.off is None else f"{self.nm}:{self.off}" + + def subregs(self): + if self.dtype == dtypes.float.vec(4): + return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)] + return [] + class AssemblyInstruction(NamedTuple): - op: UOps - out: Optional[Register] - vin: List[Union[Register, int, float]] - arg: Any = None + op: UOps + out: Optional[Register] + vin: List[Union[Register, int, float]] + arg: Any = None + # warp size of 32, s registers are shared across the warp, v are 32-wide vectors class AssemblyLanguage: - supports_load3: bool = False - sin_is_sin2pi: bool = False - no_div: bool = False - #TODO: these should be global vars - cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int) - tor: Dict[Any, Register] = {} - ins: List[AssemblyInstruction] = [] + supports_load3: bool = False + sin_is_sin2pi: bool = False + no_div: bool = False + # TODO: these should be global vars + cnts: DefaultDict[Tuple[DType, bool], int] = defaultdict(int) + tor: Dict[Any, Register] = {} + ins: List[AssemblyInstruction] = [] - def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]] - def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register: - self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar) - if dtype == dtypes.float.vec(4): - for off in range(4): - self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off) - self.cnts[(dtype, scalar)] += 1 - return ret + def type_to_letter(self, x): + return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]] - def render_numnode(self, b) -> Register: - key = ("num", b) - if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b)) - return self.tor[key] + def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register: + self.tor[tok] = ret = Register( + f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", + dtype, + scalar, + ) + if dtype == dtypes.float.vec(4): + for off in range(4): + self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off) + self.cnts[(dtype, scalar)] += 1 + return ret - def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register: - key = (op, a, b) - if key not in self.tor: - #if not isinstance(b, Register): b = render_numnode(b) - self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op)) - return self.tor[key] + def render_numnode(self, b) -> Register: + key = ("num", b) + if key not in self.tor: + self.ins.append( + AssemblyInstruction( + UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b + ) + ) + return self.tor[key] - def render_cast(self, a:Register, new_dtype:DType) -> Register: - if a.dtype == new_dtype: return a - key = (a, new_dtype) - if key not in self.tor: - self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a])) - return self.tor[key] + def render_alu( + self, op, a: Register, b: Union[Register, int, float], dtype=dtypes.int32 + ) -> Register: + key = (op, a, b) + if key not in self.tor: + # if not isinstance(b, Register): b = render_numnode(b) + self.ins.append( + AssemblyInstruction( + UOps.ALU, + self.newreg( + key, + dtype=dtype, + scalar=a.scalar and (not isinstance(b, Register) or b.scalar), + ), + [a, b], + op, + ) + ) + return self.tor[key] - render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b), - MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b), - DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b), - ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b), - LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool), - SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)), - AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) } + def render_cast(self, a: Register, new_dtype: DType) -> Register: + if a.dtype == new_dtype: + return a + key = (a, new_dtype) + if key not in self.tor: + self.ins.append( + AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]) + ) + return self.tor[key] - def addr_w_offset(self, args): - assert isinstance(args, MemOp) - idx = args.idx*args.memory_dtype.itemsize - off = 0 # TODO: should this be None? - if isinstance(idx, SumNode): - nums = [n.b for n in idx.nodes if isinstance(n, NumNode)] - if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU? - idx -= nums[0] - off = cast(int, nums[0]) - reg = idx.render(self.render_ops, self) - if self.supports_load3: - if reg.scalar: - new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype) - self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP)) - reg = new_reg - return self.tor[args.name], reg, off - reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64) - return reg, None, off + render_ops: Any = { + Variable: lambda self, ops, ctx: ctx.tor[self], + NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b), + MulNode: lambda self, ops, ctx: ctx.render_alu( + BinaryOps.MUL, self.a.render(ops, ctx), self.b + ), + DivNode: lambda self, ops, ctx: ctx.render_alu( + BinaryOps.DIV, self.a.render(ops, ctx), self.b + ), + ModNode: lambda self, ops, ctx: ctx.render_alu( + BinaryOps.MOD, self.a.render(ops, ctx), self.b + ), + LtNode: lambda self, ops, ctx: ctx.render_alu( + BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool + ), + SumNode: lambda self, ops, ctx: functools.reduce( + lambda a, b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops, ctx)), + self.nodes[1:], + self.nodes[0].render(ops, ctx), + ), + AndNode: lambda self, ops, ctx: functools.reduce( + lambda a, b: ctx.render_alu( + BinaryOps.MUL, a, b.render(ops, ctx), dtype=dtypes.bool + ), + self.nodes[1:], + self.nodes[0].render(ops, ctx), + ), + } -def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]): - #TODO: Do not use clear() - lang.ins.clear() - lang.tor.clear() - lang.cnts.clear() - buf_to_dtype = {args[0]:args[1] for uop,_,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL} - global_size, local_size = [], [] - skipload_branch = 0 - lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype] - for u in uops: - uop,dtype,vin,args,_ = u - if uop == UOps.DEFINE_LOCAL: - lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args)) - lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP)) - elif uop == UOps.LOOP: - if args[1] == "global": - for i,var in enumerate(args[0]): - global_size.append(var.max+1) - lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}")) - elif args[1] == "local": - for i,var in enumerate(args[0]): - local_size.append(var.max+1) - lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}")) - else: - for var in args[0]: - if not isinstance(var, NumNode): # TODO: why is this coming through? - lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0)) - lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr)) - elif uop == UOps.ENDLOOP: - if args[1] not in ["global", "local", "global+local"]: - for var in reversed(args[0]): - if not isinstance(var, NumNode): # TODO: why is this coming through? - lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD)) - pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool) - lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True))) - elif args[1] == "global+local": - for i, var in enumerate(reversed(args[0])): - lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}"))) - elif args[1] == 'local': - for i, var in enumerate(reversed(args[0])): - lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}"))) - elif uop == UOps.CAST: - # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies - out = lang.newreg(u, dtype) - for i,sr in enumerate(out.subregs()): - lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP)) - elif uop == UOps.ALU: - out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u] - # this is the only thing that can violate SSA - if args in [BinaryOps.CMPLT]: - pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool) - lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args)) - lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args)) - elif args == BinaryOps.DIV and lang.no_div: - tmp = lang.newreg((u, "rcp")) - lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP)) - lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL)) - elif args == UnaryOps.SIN and lang.sin_is_sin2pi: - tmp = lang.newreg((u, "2pi")) - lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL)) - lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args)) - else: - lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args)) - elif uop == UOps.DEFINE_ACC: - reg = lang.newreg(u, dtype=dtype) - lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args)) - elif uop == UOps.SPECIAL: - lang.tor[u] = lang.tor[args] - elif uop == UOps.CONST: - lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args)) - elif uop == UOps.LOAD: - idx, treg, off = lang.addr_w_offset(args) - reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) - if args.valid.min == 0: - lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0)) - if args.valid.max == 1: - pred = args.valid.render(lang.render_ops, lang) - lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False))) - if args.valid.max == 1: - # NOTE: you can't compute the index in here, because it assumes it's all available later - lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None))) - if args.valid.min == 0 and args.valid.max == 1: - lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}")) - skipload_branch += 1 - elif uop == UOps.STORE: - if args is None: - lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP)) - else: - idx, treg, off = lang.addr_w_offset(args) - lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None))) + def addr_w_offset(self, args): + assert isinstance(args, MemOp) + idx = args.idx * args.memory_dtype.itemsize + off = 0 # TODO: should this be None? + if isinstance(idx, SumNode): + nums = [n.b for n in idx.nodes if isinstance(n, NumNode)] + if ( + nums and nums[0] < 4096 and (idx - nums[0]).min >= 0 + ): # TODO: different for each GPU? + idx -= nums[0] + off = cast(int, nums[0]) + reg = idx.render(self.render_ops, self) + if self.supports_load3: + if reg.scalar: + new_reg = self.newreg((reg.nm, "vec"), dtype=reg.dtype) + self.ins.append( + AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP) + ) + reg = new_reg + return self.tor[args.name], reg, off + reg = self.render_alu( + BinaryOps.ADD, + self.render_cast(reg, dtypes.uint64), + self.tor[args.name], + dtype=dtypes.uint64, + ) + return reg, None, off - if DEBUG >= 4: - for tins in lang.ins: print(tins) - return global_size, local_size + +def uops_to_asmstyle(lang, function_name: str, uops: List[UOp]): + # TODO: Do not use clear() + lang.ins.clear() + lang.tor.clear() + lang.cnts.clear() + buf_to_dtype = { + args[0]: args[1] for uop, _, _, args, _ in uops if uop == UOps.DEFINE_GLOBAL + } + global_size, local_size = [], [] + skipload_branch = 0 + lang.ins += [ + AssemblyInstruction( + UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf + ) + for buf in buf_to_dtype + ] + for u in uops: + uop, dtype, vin, args, _ = u + if uop == UOps.DEFINE_LOCAL: + lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args)) + lang.ins.append( + AssemblyInstruction( + UOps.ALU, + lang.newreg(args[0], dtype=dtypes.uint64), + [args[0]], + UnaryOps.NOOP, + ) + ) + elif uop == UOps.LOOP: + if args[1] == "global": + for i, var in enumerate(args[0]): + global_size.append(var.max + 1) + lang.ins.append( + AssemblyInstruction( + UOps.SPECIAL, + lang.newreg(var, dtype=dtypes.int32), + [], + f"gid{len(args[0])-1-i}", + ) + ) + elif args[1] == "local": + for i, var in enumerate(args[0]): + local_size.append(var.max + 1) + lang.ins.append( + AssemblyInstruction( + UOps.SPECIAL, + lang.newreg(var, dtype=dtypes.int32), + [], + f"lid{len(args[0])-1-i}", + ) + ) + else: + for var in args[0]: + if not isinstance( + var, NumNode + ): # TODO: why is this coming through? + lang.ins.append( + AssemblyInstruction( + UOps.LOAD, + lang.newreg(var, dtype=dtypes.int32, scalar=True), + [], + 0, + ) + ) + lang.ins.append( + AssemblyInstruction( + UOps.LABEL, None, [], "$loop_" + var.expr + ) + ) + elif uop == UOps.ENDLOOP: + if args[1] not in ["global", "local", "global+local"]: + for var in reversed(args[0]): + if not isinstance( + var, NumNode + ): # TODO: why is this coming through? + lang.ins.append( + AssemblyInstruction( + UOps.ALU, + lang.tor[var], + [lang.tor[var], 1], + BinaryOps.ADD, + ) + ) + pred = lang.render_alu( + BinaryOps.CMPLT, lang.tor[var], var.max + 1, dtypes.bool + ) + lang.ins.append( + AssemblyInstruction( + UOps.COND_BRANCH, + None, + [pred], + ("$loop_" + var.expr, True), + ) + ) + elif args[1] == "global+local": + for i, var in enumerate(reversed(args[0])): + lang.ins.append( + AssemblyInstruction( + UOps.ENDLOOP, + None, + [lang.tor[var]], + (var.max + 1, f"gid{i}"), + ) + ) + elif args[1] == "local": + for i, var in enumerate(reversed(args[0])): + lang.ins.append( + AssemblyInstruction( + UOps.ENDLOOP, + None, + [lang.tor[var]], + (var.max + 1, f"lid{i}"), + ) + ) + elif uop == UOps.CAST: + # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies + out = lang.newreg(u, dtype) + for i, sr in enumerate(out.subregs()): + lang.ins.append( + AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP) + ) + elif uop == UOps.ALU: + out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u] + # this is the only thing that can violate SSA + if args in [BinaryOps.CMPLT]: + pred_reg = lang.newreg((u, "pred"), dtype=dtypes.bool) + lang.ins.append( + AssemblyInstruction( + UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args + ) + ) + lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args)) + elif args == BinaryOps.DIV and lang.no_div: + tmp = lang.newreg((u, "rcp")) + lang.ins.append( + AssemblyInstruction( + UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP + ) + ) + lang.ins.append( + AssemblyInstruction( + UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL + ) + ) + elif args == UnaryOps.SIN and lang.sin_is_sin2pi: + tmp = lang.newreg((u, "2pi")) + lang.ins.append( + AssemblyInstruction( + UOps.ALU, + tmp, + [lang.tor[vin[0]], 1 / (math.pi * 2)], + BinaryOps.MUL, + ) + ) + lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args)) + else: + lang.ins.append( + AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args) + ) + elif uop == UOps.DEFINE_ACC: + reg = lang.newreg(u, dtype=dtype) + lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args)) + elif uop == UOps.SPECIAL: + lang.tor[u] = lang.tor[args] + elif uop == UOps.CONST: + lang.ins.append( + AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args) + ) + elif uop == UOps.LOAD: + idx, treg, off = lang.addr_w_offset(args) + reg = lang.newreg( + u, + dtype=dtype, + scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)), + ) + if args.valid.min == 0: + lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0)) + if args.valid.max == 1: + pred = args.valid.render(lang.render_ops, lang) + lang.ins.append( + AssemblyInstruction( + UOps.COND_BRANCH, + None, + [pred], + (f"$skipload_{skipload_branch}", False), + ) + ) + if args.valid.max == 1: + # NOTE: you can't compute the index in here, because it assumes it's all available later + lang.ins.append( + AssemblyInstruction( + UOps.LOAD, + reg, + [idx] + ([treg] if treg is not None else []), + ( + off, + "global" if not args.local else "shared", + args.memory_dtype + if args.memory_dtype != dtypes.float + else None, + ), + ) + ) + if args.valid.min == 0 and args.valid.max == 1: + lang.ins.append( + AssemblyInstruction( + UOps.LABEL, None, [], f"$skipload_{skipload_branch}" + ) + ) + skipload_branch += 1 + elif uop == UOps.STORE: + if args is None: + lang.ins.append( + AssemblyInstruction( + UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP + ) + ) + else: + idx, treg, off = lang.addr_w_offset(args) + lang.ins.append( + AssemblyInstruction( + UOps.STORE, + None, + [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), + ( + off, + "global" if not args.local else "shared", + args.memory_dtype + if args.memory_dtype != dtypes.float + else None, + ), + ) + ) + + if DEBUG >= 4: + for tins in lang.ins: + print(tins) + return global_size, local_size diff --git a/extra/assembly/assembly_arm64.py b/extra/assembly/assembly_arm64.py index 6a4dc48fb..44b38d63c 100644 --- a/extra/assembly/assembly_arm64.py +++ b/extra/assembly/assembly_arm64.py @@ -6,171 +6,268 @@ from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.helpers import dtypes, CI from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage -def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1]) + +def float_to_hex(x): + return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1]) + + def compute_offsets(total): - quotient, remainder = divmod(total, 4096) - return [4096]*quotient + [remainder] if remainder else [4096]*quotient + quotient, remainder = divmod(total, 4096) + return [4096] * quotient + [remainder] if remainder else [4096] * quotient -#NOTE: Darwin needs names to start with a "_" -def get_name(name): return ('_' if system() == 'Darwin' else '') + name -class ARM64Language(AssemblyLanguage): pass +# NOTE: Darwin needs names to start with a "_" +def get_name(name): + return ("_" if system() == "Darwin" else "") + name + + +class ARM64Language(AssemblyLanguage): + pass + def specialize_to_arm64(fn_nm, asm): - var_size = 16 - prev_uop:Optional[UOps] = None - ins = [] - x_regs = ['x' + str(i) for i in reversed(range(12))] - s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16] - type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'} - alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max", - BinaryOps.MOD: "", BinaryOps.CMPLT: "subs", - UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg", - UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"), - TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"} + var_size = 16 + prev_uop: Optional[UOps] = None + ins = [] + x_regs = ["x" + str(i) for i in reversed(range(12))] + s_regs = ["s" + str(i) for i in reversed(range(3, 32)) if i <= 7 or i >= 16] + type_to_reg = { + dtypes.double: "d", + dtypes.half: "h", + dtypes.float32: "s", + dtypes.bool: "w", + dtypes.int8: "w", + dtypes.int32: "w", + dtypes.int64: "x", + dtypes.uint8: "w", + dtypes.uint32: "w", + dtypes.uint64: "x", + } + alu = { + BinaryOps.ADD: "add", + BinaryOps.SUB: "sub", + BinaryOps.MUL: "mul", + BinaryOps.DIV: "div", + BinaryOps.MAX: "max", + BinaryOps.MOD: "", + BinaryOps.CMPLT: "subs", + UnaryOps.NOOP: "mov", + UnaryOps.NEG: "neg", + UnaryOps.SIN: "bl " + get_name("sinf"), + UnaryOps.LOG2: "bl " + get_name("log2f"), + UnaryOps.EXP2: "bl " + get_name("exp2f"), + UnaryOps.SQRT: "bl " + get_name("sqrtf"), + TernaryOps.MULACC: "madd", + TernaryOps.WHERE: "fcsel", + } - def mov_imm(value, reg): - # Manually move value into reg if value can't fit - if value.__class__ is not float and abs(value) > abs(65535): - ins.append(f"movz w15, #{value & 0xffff}") - ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16") - ins.append(f"sxtw {reg}, w15") - elif reg[0] == 's': - ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}") - ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16") - ins.append("str x15, [sp, 16]") - ins.append(f"ldr {reg}, [sp, 16]") - else: - ins.append(f"mov {reg}, #{value}") - - # Get variables intervals - live_range:Dict[str, List[int]] = {} - for i, (uop, out, vin, arg) in enumerate(asm): - for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]): - live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i] - - mem_vars:Dict[str, int] = {} - rtor:Dict[str, str] = {} - def allocate_regs(mvars): - nonlocal var_size - for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]: - available_regs = s_regs if dtypes.is_float(v[1]) else x_regs - #NOTE: Very simple spill, everything that don't fit in regs goes to mem - if not available_regs: - # ARM needs the stack 16-byte aligned - var_size += 16 - available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12') - mem_vars[v.nm] = var_size - rtor[v.nm] = available_regs.pop() - - temp_floats = ['s0', 's1', 's2'] - temp_ints = ['x12', 'x13', 'x16'] - for i, (uop, out, vin, arg) in enumerate(asm): - # Clear regs out of interval - for var, reg in list(rtor.items()): - available_regs = s_regs if reg[0] == 's' else x_regs - if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]: - available_regs.append(rtor.pop(var)) - # Assign a registers to the variables using live ranges. - allocate_regs([out] + vin) - # Assign temp regs to vin and load them before direct use - for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]): - rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i] - # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912 - ins.append(f"mov x15, {mem_vars[v.nm]}") - ins.append(f"ldr {rtor[v.nm]}, [sp, x15]") - - if uop == UOps.SPECIAL: - if arg.startswith('data'): - # data 8 to n into the stack - if int(arg[4:]) >= 8: - ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]") - ins.append(f"mov {rtor[out.nm]}, x15") - else: - ins.append(f"mov {rtor[out.nm]}, #0") - ins.append(f"loop_{arg}:") - elif uop == UOps.CAST: - if arg == BinaryOps.CMPLT: - if rtor[out.nm][0] == 's': - mov_imm(0.0, 's0') - mov_imm(1.0, 's1') - ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt") - if rtor[out.nm][0] == 'x': - mov_imm(0, 'x14') - mov_imm(1, 'x15') - ins.append(f"csel {rtor[out.nm]}, x15, x14, lt") - else: - ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}") - elif uop == UOps.ALU: - if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15') - if arg == BinaryOps.MUL and out.dtype == dtypes.bool: - ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}") - elif arg == TernaryOps.WHERE: - ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0") - ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne") - elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]: - #NOTE: Not a real instruction, use to emulate a ext call in unicorn - if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}") + def mov_imm(value, reg): + # Manually move value into reg if value can't fit + if value.__class__ is not float and abs(value) > abs(65535): + ins.append(f"movz w15, #{value & 0xffff}") + ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16") + ins.append(f"sxtw {reg}, w15") + elif reg[0] == "s": + ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}") + ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16") + ins.append("str x15, [sp, 16]") + ins.append(f"ldr {reg}, [sp, 16]") else: - save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars] - ins.append(f"sub sp, sp, #{(len(save_regs))*16}") - # Save the registers before they are cleared by func call - for i,k in enumerate(save_regs,1): - ins.append(f"str {rtor[k]}, [sp, #{16*i}]") - ins.append("stp x29, x30, [sp, #0]!") - ins.append("mov x29, sp") - ins.append(f"fmov s0, {rtor[vin[0].nm]}") - ins.append(alu[arg]) - ins.append(f"fmov {rtor[out.nm]}, s0") - ins.append("mov sp, x29") - ins.append("ldp x29, x30, [sp], #0") - for i,k in enumerate(save_regs,1): - ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]") - ins.append(f"add sp, sp, #{len(save_regs)*16}") - elif arg == BinaryOps.CMPLT: - ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}") - elif arg == BinaryOps.MOD: - rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm] - ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}") - ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}") - else: - ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}") - elif uop == UOps.LOAD: - if arg.__class__ in (int, float): - mov_imm(arg, rtor[out.nm]) - else: - #NOTE: if need casting load var in s/h0 or x/w12 temp regs - reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm] - mov_imm(arg[0], "x15") - ins.append(f"add x15, {rtor[vin[0].nm]}, x15") - ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]") - if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}") - elif uop == UOps.STORE: - #NOTE: if need casting load var in s/h0 or x/w12 temp regs - reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm]) - if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}") - ins.append(f"mov x15, #{arg[0]}") - ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]") - elif uop == UOps.COND_BRANCH: - #TODO: this is a hack it shouldn't always be a cmp before a cond branch? - if prev_uop == UOps.LOAD: - ins.append(f"cmp {rtor[vin[0].nm]}, #0") - ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}") - elif uop == UOps.LABEL: - ins.append(f"{arg[1:]}:") - elif uop == UOps.ENDLOOP: - mov_imm(arg[0], "x15") - ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1") - ins.append(f"cmp {rtor[vin[0].nm]}, x15") - ins.append(f"b.lt loop_{arg[1]}") - prev_uop = uop - # store regs into memory if needed - if out is not None and out.nm in mem_vars: - ins.append(f"mov x15, {mem_vars[out.nm]}") - ins.append(f"str {rtor[out.nm]}, [sp, x15]") - return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"]) + ins.append(f"mov {reg}, #{value}") -def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]: - lang = ARM64Language() - global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops) - return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True \ No newline at end of file + # Get variables intervals + live_range: Dict[str, List[int]] = {} + for i, (uop, out, vin, arg) in enumerate(asm): + for var in [v for v in [out] + vin if v is not None and v.__class__ is not int]: + live_range[var.nm] = ( + [i, i] if var.nm not in live_range else [live_range[var.nm][0], i] + ) + + mem_vars: Dict[str, int] = {} + rtor: Dict[str, str] = {} + + def allocate_regs(mvars): + nonlocal var_size + for v in [ + v + for v in mvars + if v is not None and v.__class__ is not int and v.nm not in rtor + ]: + available_regs = s_regs if dtypes.is_float(v[1]) else x_regs + # NOTE: Very simple spill, everything that don't fit in regs goes to mem + if not available_regs: + # ARM needs the stack 16-byte aligned + var_size += 16 + available_regs.append("s0" if dtypes.is_float(out[1]) else "x12") + mem_vars[v.nm] = var_size + rtor[v.nm] = available_regs.pop() + + temp_floats = ["s0", "s1", "s2"] + temp_ints = ["x12", "x13", "x16"] + for i, (uop, out, vin, arg) in enumerate(asm): + # Clear regs out of interval + for var, reg in list(rtor.items()): + available_regs = s_regs if reg[0] == "s" else x_regs + if var[1] not in "B" and var not in mem_vars and i > live_range[var][1]: + available_regs.append(rtor.pop(var)) + # Assign a registers to the variables using live ranges. + allocate_regs([out] + vin) + # Assign temp regs to vin and load them before direct use + for i, v in enumerate( + [v for v in vin if v.__class__ is not int and v.nm in mem_vars] + ): + rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i] + # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912 + ins.append(f"mov x15, {mem_vars[v.nm]}") + ins.append(f"ldr {rtor[v.nm]}, [sp, x15]") + + if uop == UOps.SPECIAL: + if arg.startswith("data"): + # data 8 to n into the stack + if int(arg[4:]) >= 8: + ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]") + ins.append(f"mov {rtor[out.nm]}, x15") + else: + ins.append(f"mov {rtor[out.nm]}, #0") + ins.append(f"loop_{arg}:") + elif uop == UOps.CAST: + if arg == BinaryOps.CMPLT: + if rtor[out.nm][0] == "s": + mov_imm(0.0, "s0") + mov_imm(1.0, "s1") + ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt") + if rtor[out.nm][0] == "x": + mov_imm(0, "x14") + mov_imm(1, "x15") + ins.append(f"csel {rtor[out.nm]}, x15, x14, lt") + else: + ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}") + elif uop == UOps.ALU: + if len(vin) == 2 and vin[1].__class__ is int: + mov_imm(vin[1], "x15") + if arg == BinaryOps.MUL and out.dtype == dtypes.bool: + ins.append( + f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" + ) + elif arg == TernaryOps.WHERE: + ins.append( + f"fcmp {rtor[vin[0].nm]}, #0.0" + if rtor[vin[0].nm][0] == "s" + else f"cmp {rtor[vin[0].nm]}, #0" + ) + ins.append( + f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne" + ) + elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]: + # NOTE: Not a real instruction, use to emulate a ext call in unicorn + if CI: + ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}") + else: + save_regs = [ + k for k in rtor.keys() if k != out.nm and k not in mem_vars + ] + ins.append(f"sub sp, sp, #{(len(save_regs))*16}") + # Save the registers before they are cleared by func call + for i, k in enumerate(save_regs, 1): + ins.append(f"str {rtor[k]}, [sp, #{16*i}]") + ins.append("stp x29, x30, [sp, #0]!") + ins.append("mov x29, sp") + ins.append(f"fmov s0, {rtor[vin[0].nm]}") + ins.append(alu[arg]) + ins.append(f"fmov {rtor[out.nm]}, s0") + ins.append("mov sp, x29") + ins.append("ldp x29, x30, [sp], #0") + for i, k in enumerate(save_regs, 1): + ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]") + ins.append(f"add sp, sp, #{len(save_regs)*16}") + elif arg == BinaryOps.CMPLT: + ins.append( + f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" + if not dtypes.is_float(vin[0][1]) + else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}" + ) + elif arg == BinaryOps.MOD: + rhs = "x15" if vin[1].__class__ is int else rtor[vin[1].nm] + ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}") + ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}") + else: + ins.append( + f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" + ) + elif uop == UOps.LOAD: + if arg.__class__ in (int, float): + mov_imm(arg, rtor[out.nm]) + else: + # NOTE: if need casting load var in s/h0 or x/w12 temp regs + reg_in = ( + type_to_reg[arg[2]] + ("0" if dtypes.is_float(arg[2]) else "12") + if arg[2] is not None + else rtor[out.nm] + ) + mov_imm(arg[0], "x15") + ins.append(f"add x15, {rtor[vin[0].nm]}, x15") + ins.append( + f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]" + ) + if arg[2] is not None: + ins.append( + f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}" + ) + elif uop == UOps.STORE: + # NOTE: if need casting load var in s/h0 or x/w12 temp regs + reg_out = ( + type_to_reg[arg[2]] + ("0" if dtypes.is_float(arg[2]) else "12") + if arg[2] is not None + else rtor[vin[1].nm] + ) + if arg[2] is not None: + ins.append( + f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}" + ) + ins.append(f"mov x15, #{arg[0]}") + ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]") + elif uop == UOps.COND_BRANCH: + # TODO: this is a hack it shouldn't always be a cmp before a cond branch? + if prev_uop == UOps.LOAD: + ins.append(f"cmp {rtor[vin[0].nm]}, #0") + ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}") + elif uop == UOps.LABEL: + ins.append(f"{arg[1:]}:") + elif uop == UOps.ENDLOOP: + mov_imm(arg[0], "x15") + ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1") + ins.append(f"cmp {rtor[vin[0].nm]}, x15") + ins.append(f"b.lt loop_{arg[1]}") + prev_uop = uop + # store regs into memory if needed + if out is not None and out.nm in mem_vars: + ins.append(f"mov x15, {mem_vars[out.nm]}") + ins.append(f"str {rtor[out.nm]}, [sp, x15]") + return "\n".join( + [ + f"//varsize {var_size}", + ".arch armv8-a", + ".text", + f".global {get_name(fn_nm)}", + ".p2align 2", + f"{get_name(fn_nm)}:", + "mov x17, sp", + ] + + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)] + + ins + + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] + + ["ret", "\n"] + ) + + +def uops_to_arm64_asm( + fn_nm: str, uops: List[UOp] +) -> Tuple[str, List[int], List[int], bool]: + lang = ARM64Language() + global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops) + return ( + specialize_to_arm64(fn_nm, lang.ins), + global_size[::-1], + local_size[::-1], + True, + ) diff --git a/extra/assembly/assembly_ptx.py b/extra/assembly/assembly_ptx.py index 69e610527..25ae892c4 100644 --- a/extra/assembly/assembly_ptx.py +++ b/extra/assembly/assembly_ptx.py @@ -6,100 +6,211 @@ from tinygrad.helpers import dtypes from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps from tinygrad.runtime.ops_cuda import arch -dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"} -def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1]) +dtype_to_nvtype = { + dtypes.float32: "f32", + dtypes.float16: "f16", + dtypes.int64: "s64", + dtypes.int32: "s32", + dtypes.int8: "s8", + dtypes.bool: "pred", + dtypes.uint64: "u64", + dtypes.uint32: "u32", + dtypes.uint16: "u16", + dtypes.uint8: "u8", + "bits16": "b16", + dtypes.float64: "f64", +} + + +def float_to_hex(x): + return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1]) + + +def ptx_needs_cast(dest_dtype, src_dtype): + return ( + dtypes.is_float(dest_dtype) + and dtypes.is_int(src_dtype) + or dtypes.is_int(dest_dtype) + and dtypes.is_float(src_dtype) + or ( + dtypes.is_float(src_dtype) + and dtypes.is_float(dest_dtype) + and dest_dtype.itemsize != src_dtype.itemsize + ) + ) -def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize) def render_cast(ins, inp, out): - if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)): - ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};") - elif out.dtype == dtypes.bool: - if inp.dtype == dtypes.bool: - ins.append(f"mov.pred {out}, {inp};") + if inp.dtype == dtypes.bool and ( + dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype) + ): + ins.append( + f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};" + ) + elif out.dtype == dtypes.bool: + if inp.dtype == dtypes.bool: + ins.append(f"mov.pred {out}, {inp};") + else: + ins.append( + f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};" + ) else: - ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};") - else: - round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else '' - ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};") + round_mod = ( + ".rzi" + if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) + else ".rz" + if dtypes.is_float(out.dtype) + and ( + dtypes.is_int(inp.dtype) + or dtypes.is_float(inp.dtype) + and inp.dtype.itemsize > out.dtype.itemsize + ) + else "" + ) + ins.append( + f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};" + ) + # https://docs.nvidia.com/cuda/parallel-thread-execution/# + class PTXLanguage(AssemblyLanguage): - supports_constant_folding: bool = True + supports_constant_folding: bool = True + def specialize_to_ptx(lang, function_name): - param_cnt = 0 - ins = [] - alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max", - BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx", - UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg", - UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz", - TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"} - for uop, out, vin, arg in lang.ins: - if uop == UOps.ENDLOOP: - ins.append("bar.sync 0;") - elif uop == UOps.DEFINE_LOCAL: - ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];") - elif uop == UOps.SPECIAL: - if arg.startswith('data'): - param_cnt += 1 - ins.append(f"ld.param.u64 {out}, [{arg}];") - # TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to? - # ins.append(f"cvta.to.global.u64 {out}, {out};") - elif arg.startswith('gid'): - ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};") - elif arg.startswith('lid'): - ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};") - elif uop == UOps.ALU: - if arg == BinaryOps.MUL and out.dtype == dtypes.bool: - ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};") - else: - otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype - if arg == TernaryOps.WHERE: - if vin[0].dtype == dtypes.bool: - reg = vin[0] - else: - reg = lang.newreg((vin[0], 'bool'), dtypes.bool) - ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};") - vin = vin[1:] + [reg] - ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};") - elif uop == UOps.LOAD: - if arg.__class__ in (int, float): - ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};") - elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype): - dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2]) - reg = lang.newreg((out, dt[0]), dtype=dt[1]) - ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];") - render_cast(ins, reg, out) - else: - ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];") - elif uop == UOps.STORE: - if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool: - if arg[2] == dtypes.bool != vin[1].dtype: - prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool) - render_cast(ins, vin[1], prereg) - else: prereg = vin[1] - reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]) - render_cast(ins, prereg, reg) - ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};") - else: - ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};") - elif uop == UOps.CAST: - render_cast(ins, vin[0], out) - elif uop == UOps.LABEL: - ins.append(f"{arg}:") - elif uop == UOps.COND_BRANCH: - ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};") + param_cnt = 0 + ins = [] + alu = { + BinaryOps.ADD: "add", + BinaryOps.SUB: "sub", + BinaryOps.MUL: "mul", + BinaryOps.DIV: "div", + BinaryOps.MAX: "max", + BinaryOps.MOD: "rem", + BinaryOps.CMPLT: "setp.lt", + UnaryOps.SQRT: "sqrt.approx", + UnaryOps.NOOP: "mov", + UnaryOps.NEG: "neg", + UnaryOps.SIN: "sin.approx", + UnaryOps.LOG2: "lg2.approx", + UnaryOps.EXP2: "ex2.approx.ftz", + TernaryOps.MULACC: "fma.rn", + TernaryOps.WHERE: "selp", + } + for uop, out, vin, arg in lang.ins: + if uop == UOps.ENDLOOP: + ins.append("bar.sync 0;") + elif uop == UOps.DEFINE_LOCAL: + ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];") + elif uop == UOps.SPECIAL: + if arg.startswith("data"): + param_cnt += 1 + ins.append(f"ld.param.u64 {out}, [{arg}];") + # TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to? + # ins.append(f"cvta.to.global.u64 {out}, {out};") + elif arg.startswith("gid"): + ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};") + elif arg.startswith("lid"): + ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};") + elif uop == UOps.ALU: + if arg == BinaryOps.MUL and out.dtype == dtypes.bool: + ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};") + else: + otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype + if arg == TernaryOps.WHERE: + if vin[0].dtype == dtypes.bool: + reg = vin[0] + else: + reg = lang.newreg((vin[0], "bool"), dtypes.bool) + ins.append( + f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};" + ) + vin = vin[1:] + [reg] + ins.append( + f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};" + ) + elif uop == UOps.LOAD: + if arg.__class__ in (int, float): + ins.append( + f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};" + ) + elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype): + dt = ( + ("u16", dtypes.uint16) + if arg[2] == dtypes.bool == out.dtype + else ("u8", dtypes.uint8) + if arg[2] == dtypes.bool + else ("b16", dtypes.float16) + if arg[2] == dtypes.half + else (dtype_to_nvtype[arg[2]], arg[2]) + ) + reg = lang.newreg((out, dt[0]), dtype=dt[1]) + ins.append( + f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];" + ) + render_cast(ins, reg, out) + else: + ins.append( + f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];" + ) + elif uop == UOps.STORE: + if ( + ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) + or arg[2] == dtypes.bool + ): + if arg[2] == dtypes.bool != vin[1].dtype: + prereg = lang.newreg((vin[1], "bool"), dtype=dtypes.bool) + render_cast(ins, vin[1], prereg) + else: + prereg = vin[1] + reg = lang.newreg( + (prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), + dtype=dtypes.uint16 + if arg[2] == dtypes.bool + else dtypes.float + if arg[2] is None + else arg[2], + ) + render_cast(ins, prereg, reg) + ins.append( + f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};" + ) + else: + ins.append( + f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};" + ) + elif uop == UOps.CAST: + render_cast(ins, vin[0], out) + elif uop == UOps.LABEL: + ins.append(f"{arg}:") + elif uop == UOps.COND_BRANCH: + ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};") - ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64", - f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"] - for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",) - ins = ins_prefix + ins - ins += ["ret;", "}"] - return '\n'.join(ins) + ins_prefix = [ + ".version 7.8", + ".target " + arch(), + ".address_size 64", + f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{", + ] + for arg in [ + (dtype, lang.type_to_letter(dtype), c) for dtype, c in lang.cnts.items() + ]: + ins_prefix.append( + f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;", + ) + ins = ins_prefix + ins + ins += ["ret;", "}"] + return "\n".join(ins) -def uops_to_ptx_asm(function_name:str, uops:List[UOp]): - lang = PTXLanguage() - global_size, local_size = uops_to_asmstyle(lang, function_name, uops) - return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True + +def uops_to_ptx_asm(function_name: str, uops: List[UOp]): + lang = PTXLanguage() + global_size, local_size = uops_to_asmstyle(lang, function_name, uops) + return ( + specialize_to_ptx(lang, function_name), + global_size[::-1], + local_size[::-1], + True, + ) diff --git a/extra/assembly/assembly_rdna.py b/extra/assembly/assembly_rdna.py index ef3adff57..d12610ce0 100644 --- a/extra/assembly/assembly_rdna.py +++ b/extra/assembly/assembly_rdna.py @@ -8,6 +8,7 @@ from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH # ugh, is this really needed? from extra.helpers import enable_early_exec + early_exec = enable_early_exec() boilerplate_start = """ @@ -24,180 +25,359 @@ code_start = """.end_amdhsa_kernel code: """ + # https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst # https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state # RDNA3 is actually a SIMD machine! class RDNACodegen(AssemblyCodegen): - supports_float4: bool = True - supports_float4_alu: bool = True - supports_load3: bool = True - sin_is_sin2pi: bool = True - no_div: bool = True + supports_float4: bool = True + supports_float4_alu: bool = True + supports_load3: bool = True + sin_is_sin2pi: bool = True + no_div: bool = True - def specialize(self, asm) -> Tuple[str, str]: - args = [] - for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'}) - ins = [] + def specialize(self, asm) -> Tuple[str, str]: + args = [] + for i, b in enumerate(self.bufs): + args.append( + { + ".address_space": "global", + ".name": f"buf_{i}", + ".offset": i * 8, + ".size": 8, + ".type_name": b.dtype.name + "*", + ".value_kind": "global_buffer", + } + ) + ins = [] - v_cnt = 3 # v[0:2] is local_xyz - s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz + v_cnt = 3 # v[0:2] is local_xyz + s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz - dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"} - alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma", - BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp", - UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp", - BinaryOps.CMPLT: "cmp_lt"} + dtype_to_rdnatype = { + dtypes.float32: "f32", + dtypes.int64: "i64", + dtypes.int32: "i32", + dtypes.uint64: "u64", + dtypes.bool: "i32", + } + alu = { + BinaryOps.ADD: "add", + BinaryOps.SUB: "sub", + BinaryOps.MUL: "mul", + TernaryOps.MULACC: "fma", + BinaryOps.MAX: "max", + UnaryOps.RECIP: "rcp", + UnaryOps.NOOP: "mov", + UnaryOps.SIN: "sin", + UnaryOps.LOG2: "log", + UnaryOps.EXP2: "exp", + BinaryOps.CMPLT: "cmp_lt", + } - pend_regs:Set[Register] = set() - rtor:Dict[Register, str] = {} - def reg_in(x): - nonlocal pend_regs - #print("reg_in", x, rtor[x], pend_regs) - if x in pend_regs: - #print("clear") - ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)') - pend_regs.clear() - return rtor[x] - def reg_out(x): - return rtor[x] - for uop, out, vin, arg in asm: - if uop == UOps.DEFINE_REGISTER: - if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]: - for i in range(arg[2]): - # TODO: Re-use gaps created by this to avoid wasting registers - align = int(arg[0][0].itemsize / 4) - if arg[0][1]: - s_cnt += s_cnt % align - reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}" - s_cnt += align + pend_regs: Set[Register] = set() + rtor: Dict[Register, str] = {} + + def reg_in(x): + nonlocal pend_regs + # print("reg_in", x, rtor[x], pend_regs) + if x in pend_regs: + # print("clear") + ins.append("s_waitcnt lgkmcnt(0), vmcnt(0)") + pend_regs.clear() + return rtor[x] + + def reg_out(x): + return rtor[x] + + for uop, out, vin, arg in asm: + if uop == UOps.DEFINE_REGISTER: + if arg[0][0] in [ + dtypes.uint32, + dtypes.uint64, + dtypes.int64, + dtypes.int32, + dtypes.float32, + dtypes.float.vec(4), + ]: + for i in range(arg[2]): + # TODO: Re-use gaps created by this to avoid wasting registers + align = int(arg[0][0].itemsize / 4) + if arg[0][1]: + s_cnt += s_cnt % align + reg_name = ( + f"s[{s_cnt}:{s_cnt + align - 1}]" + if align > 1 + else f"s{s_cnt}" + ) + s_cnt += align + else: + v_cnt += v_cnt % align + reg_name = ( + f"v[{v_cnt}:{v_cnt + align - 1}]" + if align > 1 + else f"v{v_cnt}" + ) + v_cnt += align + rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name + + if arg[0][0] == dtypes.float.vec(4): + for off in range(4): + reg_name = ( + f"s{s_cnt-align+off}" + if arg[0][1] + else f"v{v_cnt-align+off}" + ) + rtor[ + Register( + f"%{arg[1]}{i}", dtypes.float, False, off=off + ) + ] = reg_name + elif arg[0][0] == dtypes.bool: + for i in range(arg[2]): + reg_name = ( + "scc" if arg[0][1] else "vcc_lo" + ) # `_lo` suffix since we're running wavefront_size=32 + rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name + else: + raise NotImplementedError( + "DEFINE_REGISTER not implemented for arg: ", arg + ) + elif uop == UOps.SPECIAL: + if arg.startswith("buf"): + i = int(arg[3:]) + ins.append(f"s_load_b64 {reg_out(out)}, s[0:1], {i*8}") + pend_regs.add(out) + for r in out.subregs(): + pend_regs.add(r) + elif arg.startswith("gid"): + ins.append(f"v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}") + # the docs lied, this is actually y + if int(arg[3]) == 2: + ins.append("v_bfe_u32 v2, v0, 20, 10") # untested + if int(arg[3]) == 1: + ins.append("v_bfe_u32 v1, v0, 10, 10") + elif int(arg[3]) == 0: + ins.append("v_and_b32_e32 v0, 0x3ff, v0") + # get local size + offset = len(args) * 8 + args.append( + { + ".offset": offset, + ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", + ".size": 8, + } + ) + ins.append(f"s_load_b32 s{2+int(arg[3])}, s[0:1], {offset}") + ins.append("s_waitcnt vmcnt(0) lgkmcnt(0)") + pend_regs.clear() + ins.append( + f"v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}" + ) + ins.append( + f"v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}" + ) + elif uop == UOps.CONST: + if arg == float("inf"): + arg = "0x7f800000" + elif arg == float("-inf"): + arg = "0xff800000" + if out.dtype == dtypes.float.vec(4): + for off in range(4): + ins.append( + f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}" + ) + else: + ins.append( + f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}" + ) + elif uop == UOps.ALU: + if arg in [BinaryOps.CMPLT]: + ins.append( + f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}" + ) + else: + alu_arg = alu[arg] + if arg == TernaryOps.MULACC and out == vin[2]: + alu_arg = "fmac" + vin = vin[0:2] + if out.dtype == dtypes.float.vec(4): + for rr in zip( + *[ + x.subregs() + if x.dtype == dtypes.float.vec(4) + else [x, x, x, x] + for x in [out] + vin + ] + ): + ins.append( + f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}" + ) + else: + ins.append( + f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}" + ) + elif uop == UOps.LOAD: + if out.scalar: + # swap arg order + ins.append( + f"s_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}" + ) + else: + ins.append( + f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}' + ) + pend_regs.add(out) + for r in out.subregs(): + pend_regs.add(r) + elif uop == UOps.STORE: + ins.append( + f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}' + ) + elif uop == UOps.LABEL: + ins.append(f"{arg}:") + elif uop == UOps.COND_BRANCH: + ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}") + elif uop == UOps.CAST: + if vin[0].dtype == dtypes.bool: + if out.dtype == dtypes.float32: + ins.append( + f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}" + ) + else: + raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}") else: - v_cnt += v_cnt % align - reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}" - v_cnt += align - rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name + raise NotImplementedError(uop) - if arg[0][0] == dtypes.float.vec(4): - for off in range(4): - reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}" - rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name - elif arg[0][0] == dtypes.bool: - for i in range(arg[2]): - reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32 - rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name - else: - raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg) - elif uop == UOps.SPECIAL: - if arg.startswith('buf'): - i = int(arg[3:]) - ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}') - pend_regs.add(out) - for r in out.subregs(): pend_regs.add(r) - elif arg.startswith('gid'): - ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}') - # the docs lied, this is actually y - if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested - if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10") - elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0") - # get local size - offset = len(args)*8 - args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8}) - ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}') - ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)') - pend_regs.clear() - ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}') - ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}') - elif uop == UOps.CONST: - if arg == float('inf'): arg = "0x7f800000" - elif arg == float('-inf'): arg = "0xff800000" - if out.dtype == dtypes.float.vec(4): - for off in range(4): - ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}") - else: - ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}") - elif uop == UOps.ALU: - if arg in [BinaryOps.CMPLT]: - ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}") - else: - alu_arg = alu[arg] - if arg == TernaryOps.MULACC and out == vin[2]: - alu_arg = "fmac" - vin = vin[0:2] - if out.dtype == dtypes.float.vec(4): - for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]): - ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}") - else: - ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}") - elif uop == UOps.LOAD: - if out.scalar: - # swap arg order - ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}') - else: - ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}') - pend_regs.add(out) - for r in out.subregs(): pend_regs.add(r) - elif uop == UOps.STORE: - ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}') - elif uop == UOps.LABEL: - ins.append(f"{arg}:") - elif uop == UOps.COND_BRANCH: - ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}") - elif uop == UOps.CAST: - if vin[0].dtype == dtypes.bool: - if out.dtype == dtypes.float32: - ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}") - else: - raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}") - else: - raise NotImplementedError(uop) + ins += ["s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)", "s_endpgm", "s_code_end"] - ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end'] + # dual alu group + seen = set() + new_ins = [] + for i, tins in enumerate(ins): + if tins in seen: + continue + if tins.startswith("v_fmac_f32"): + for gins in reversed(ins[i + 1 :]): + if gins in seen: + continue + if gins.startswith("v_fmac_f32"): + r0 = [int(x[1:].strip(",")) for x in tins.split(" ")[1:]] + r1 = [int(x[1:].strip(",")) for x in gins.split(" ")[1:]] + if r0[0] % 2 == r1[0] % 2: + continue + if r0[1] % 2 == r1[1] % 2: + continue + if r0[2] % 2 == r1[2] % 2: + continue + new_ins.append( + tins.replace("v_", "v_dual_") + + " :: " + + gins.replace("v_", "v_dual_") + ) + seen.add(tins) + seen.add(gins) + break + if tins not in seen: + new_ins.append(tins) + ins = new_ins - # dual alu group - seen = set() - new_ins = [] - for i,tins in enumerate(ins): - if tins in seen: continue - if tins.startswith("v_fmac_f32"): - for gins in reversed(ins[i+1:]): - if gins in seen: continue - if gins.startswith("v_fmac_f32"): - r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]] - r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]] - if r0[0]%2 == r1[0]%2: continue - if r0[1]%2 == r1[1]%2: continue - if r0[2]%2 == r1[2]%2: continue - new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_")) - seen.add(tins) - seen.add(gins) - break - if tins not in seen: - new_ins.append(tins) - ins = new_ins + return "code", self.assemble(args, ins, v_cnt, s_cnt) - return 'code', self.assemble(args, ins, v_cnt, s_cnt) + def assemble(self, args, ins, v_cnt, s_cnt): + kernel_desc = { + ".amdhsa_group_segment_fixed_size": 0, + ".amdhsa_private_segment_fixed_size": 0, + ".amdhsa_kernarg_size": 0, + ".amdhsa_next_free_vgpr": v_cnt, # this matters! + ".amdhsa_reserve_vcc": 0, + ".amdhsa_reserve_xnack_mask": 0, + ".amdhsa_next_free_sgpr": s_cnt, + ".amdhsa_float_round_mode_32": 0, + ".amdhsa_float_round_mode_16_64": 0, + ".amdhsa_float_denorm_mode_32": 3, + ".amdhsa_float_denorm_mode_16_64": 3, + ".amdhsa_dx10_clamp": 1, + ".amdhsa_ieee_mode": 1, + ".amdhsa_fp16_overflow": 0, + ".amdhsa_workgroup_processor_mode": 1, + ".amdhsa_memory_ordered": 1, + ".amdhsa_forward_progress": 0, + ".amdhsa_enable_private_segment": 0, + ".amdhsa_system_sgpr_workgroup_id_x": 1, + ".amdhsa_system_sgpr_workgroup_id_y": 1, + ".amdhsa_system_sgpr_workgroup_id_z": 1, + ".amdhsa_system_sgpr_workgroup_info": 0, + ".amdhsa_system_vgpr_workitem_id": 2, # is amdhsa_system_vgpr_workitem_id real? + ".amdhsa_exception_fp_ieee_invalid_op": 0, + ".amdhsa_exception_fp_denorm_src": 0, + ".amdhsa_exception_fp_ieee_div_zero": 0, + ".amdhsa_exception_fp_ieee_overflow": 0, + ".amdhsa_exception_fp_ieee_underflow": 0, + ".amdhsa_exception_fp_ieee_inexact": 0, + ".amdhsa_exception_int_div_zero": 0, + ".amdhsa_user_sgpr_dispatch_ptr": 0, + ".amdhsa_user_sgpr_queue_ptr": 0, + ".amdhsa_user_sgpr_kernarg_segment_ptr": 1, + ".amdhsa_user_sgpr_dispatch_id": 0, + ".amdhsa_user_sgpr_private_segment_size": 0, + ".amdhsa_wavefront_size32": 1, + ".amdhsa_uses_dynamic_stack": 0, + } - def assemble(self, args, ins, v_cnt, s_cnt): - kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0, - '.amdhsa_next_free_vgpr': v_cnt, # this matters! - '.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0, - '.amdhsa_next_free_sgpr': s_cnt, - '.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1, - '.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0, - '.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1, - '.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real? - '.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0, - '.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1, - '.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0} + metadata = { + "amdhsa.kernels": [ + { + ".args": args, + ".group_segment_fixed_size": 0, + ".kernarg_segment_align": 8, + ".kernarg_segment_size": args[-1][".offset"] + args[-1][".size"], + ".language": "OpenCL C", + ".language_version": [1, 2], + ".max_flat_workgroup_size": 256, + ".name": "code", + ".private_segment_fixed_size": 0, + ".sgpr_count": s_cnt, + ".sgpr_spill_count": 0, + ".symbol": "code.kd", + ".uses_dynamic_stack": False, + ".vgpr_count": v_cnt, + ".vgpr_spill_count": 0, + ".wavefront_size": 32, + } + ], + "amdhsa.target": "amdgcn-amd-amdhsa--gfx1100", + "amdhsa.version": [1, 2], + } - metadata = {'amdhsa.kernels': [{'.args': args, - '.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"], - '.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256, - '.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0, - '.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0, - '.wavefront_size': 32}], - 'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]} - - code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata" - obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8"))) - asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj)) - return asm + code = ( + boilerplate_start + + "\n" + + "\n".join("%s %d" % x for x in kernel_desc.items()) + + "\n" + + code_start + + "\n".join(ins) + + "\n.amdgpu_metadata\n" + + yaml.dump(metadata) + + ".end_amdgpu_metadata" + ) + obj = early_exec( + ( + [ + ROCM_LLVM_PATH / "llvm-mc", + "--arch=amdgcn", + "--mcpu=gfx1100", + "--triple=amdgcn-amd-amdhsa", + "--filetype=obj", + "-", + ], + code.encode("utf-8"), + ) + ) + asm = early_exec( + ( + [ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], + obj, + ) + ) + return asm diff --git a/extra/assembly/ptx/test.py b/extra/assembly/ptx/test.py index f30348b8c..03a0c7e45 100644 --- a/extra/assembly/ptx/test.py +++ b/extra/assembly/ptx/test.py @@ -3,8 +3,10 @@ import numpy as np from tinygrad.runtime.ops_cuda import CUDAProgram, RawCUDABuffer if __name__ == "__main__": - test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32)) - prg = CUDAProgram("test", """ + test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32)) + prg = CUDAProgram( + "test", + """ .version 7.8 .target sm_86 .address_size 64 @@ -17,7 +19,8 @@ if __name__ == "__main__": mov.u32 %r1, 0x40000000; // 2.0 in float st.global.u32 [%rd2], %r1; ret; - }""", binary=True) - prg([1], [1], test) - print(test.toCPU()) - + }""", + binary=True, + ) + prg([1], [1], test) + print(test.toCPU()) diff --git a/extra/assembly/rocm/rdna3/asm.py b/extra/assembly/rocm/rdna3/asm.py index 2f6ad1326..fcc7d97a9 100644 --- a/extra/assembly/rocm/rdna3/asm.py +++ b/extra/assembly/rocm/rdna3/asm.py @@ -3,6 +3,7 @@ import pathlib from hexdump import hexdump from tinygrad.helpers import colored from extra.helpers import enable_early_exec + early_exec = enable_early_exec() from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer, ROCM_LLVM_PATH @@ -14,13 +15,13 @@ DUAL_ALU = True F32 = True if ENABLE_NON_ASM: - buf = CLBuffer.fromCPU(np.zeros(10, np.float32)) - prg_empty = CLProgram("code", "__kernel void code(__global float *a) { a[0] = 1; }") - asm_real = prg_empty.binary() - with open("/tmp/cc.elf", "wb") as f: - f.write(asm_real) - prg_empty([1], [1], buf, wait=True) - print(buf.toCPU()) + buf = CLBuffer.fromCPU(np.zeros(10, np.float32)) + prg_empty = CLProgram("code", "__kernel void code(__global float *a) { a[0] = 1; }") + asm_real = prg_empty.binary() + with open("/tmp/cc.elf", "wb") as f: + f.write(asm_real) + prg_empty([1], [1], buf, wait=True) + print(buf.toCPU()) print(colored("creating CLBuffer", "green")) buf = CLBuffer.fromCPU(np.zeros(10, np.float32)) @@ -30,51 +31,71 @@ gen = [] FLOPS = 0 MAX_REG = 251 for j in range(1): - if WMMA: - KY, KX = 4, 4 - for y in range(KY): - for x in range(KX): - c = (y*KX+x)*8 - a = (KY*KX*8) + y*8 - b = (KY*KX*8) + (KY*8) + x*8 - gen.append(f"v_wmma_f32_16x16x16_f16 v[{c}:{c+7}], v[{a}:{a+7}], v[{b}:{b+7}], v[{c}:{c+7}]") - FLOPS += 16*8*2 - else: - for i in range(0, MAX_REG, 6): - if DUAL_ALU: - if F32: - gen.append(f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}") - FLOPS += 4 - else: - gen.append(f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}") - FLOPS += 8 - else: - assert F32 - gen.append(f"v_fmac_f32 v{i+0}, v{i+1}, v{i+2}") - gen.append(f"v_fmac_f32 v{i+3}, v{i+4}, v{i+5}") -code = code.replace("// FLOPS", '\n'.join(gen)) + if WMMA: + KY, KX = 4, 4 + for y in range(KY): + for x in range(KX): + c = (y * KX + x) * 8 + a = (KY * KX * 8) + y * 8 + b = (KY * KX * 8) + (KY * 8) + x * 8 + gen.append( + f"v_wmma_f32_16x16x16_f16 v[{c}:{c+7}], v[{a}:{a+7}], v[{b}:{b+7}], v[{c}:{c+7}]" + ) + FLOPS += 16 * 8 * 2 + else: + for i in range(0, MAX_REG, 6): + if DUAL_ALU: + if F32: + gen.append( + f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}" + ) + FLOPS += 4 + else: + gen.append( + f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}" + ) + FLOPS += 8 + else: + assert F32 + gen.append(f"v_fmac_f32 v{i+0}, v{i+1}, v{i+2}") + gen.append(f"v_fmac_f32 v{i+3}, v{i+4}, v{i+5}") +code = code.replace("// FLOPS", "\n".join(gen)) print(code) # fix: COMGR failed to get code object ISA name. set triple to 'amdgcn-amd-amdhsa' -object = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8"))) -asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], object)) +object = early_exec( + ( + [ + ROCM_LLVM_PATH / "llvm-mc", + "--arch=amdgcn", + "--mcpu=gfx1100", + "--triple=amdgcn-amd-amdhsa", + "--filetype=obj", + "-", + ], + code.encode("utf-8"), + ) +) +asm = early_exec( + ([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], object) +) with open("/tmp/cc2.o", "wb") as f: - f.write(object) + f.write(object) with open("/tmp/cc2.elf", "wb") as f: - f.write(asm) + f.write(asm) print(colored("creating CLProgram", "green")) prg = CLProgram("code", asm) print(colored("running program", "green")) G = 512 -FLOPS *= 100000*G*G # loop * global_size +FLOPS *= 100000 * G * G # loop * global_size for i in range(3): - tm = prg(buf, global_size=[G//256, G, 1], local_size=[256, 1, 1], wait=True) - print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS") + tm = prg(buf, global_size=[G // 256, G, 1], local_size=[256, 1, 1], wait=True) + print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS") print(colored("transferring buffer", "green")) print(buf.toCPU()) diff --git a/extra/augment.py b/extra/augment.py index 06e7906c7..72e9f1386 100644 --- a/extra/augment.py +++ b/extra/augment.py @@ -2,41 +2,49 @@ import numpy as np from PIL import Image from pathlib import Path import sys + cwd = Path.cwd() sys.path.append(cwd.as_posix()) -sys.path.append((cwd / 'test').as_posix()) +sys.path.append((cwd / "test").as_posix()) from extra.datasets import fetch_mnist from tqdm import trange + def augment_img(X, rotate=10, px=3): - Xaug = np.zeros_like(X) - for i in trange(len(X)): - im = Image.fromarray(X[i]) - im = im.rotate(np.random.randint(-rotate,rotate), resample=Image.BICUBIC) - w, h = X.shape[1:] - #upper left, lower left, lower right, upper right - quad = np.random.randint(-px,px,size=(8)) + np.array([0,0,0,h,w,h,w,0]) - im = im.transform((w, h), Image.QUAD, quad, resample=Image.BICUBIC) - Xaug[i] = im - return Xaug + Xaug = np.zeros_like(X) + for i in trange(len(X)): + im = Image.fromarray(X[i]) + im = im.rotate(np.random.randint(-rotate, rotate), resample=Image.BICUBIC) + w, h = X.shape[1:] + # upper left, lower left, lower right, upper right + quad = np.random.randint(-px, px, size=(8)) + np.array([0, 0, 0, h, w, h, w, 0]) + im = im.transform((w, h), Image.QUAD, quad, resample=Image.BICUBIC) + Xaug[i] = im + return Xaug + if __name__ == "__main__": - import matplotlib.pyplot as plt - X_train, Y_train, X_test, Y_test = fetch_mnist() - X_train = X_train.reshape(-1, 28, 28).astype(np.uint8) - X_test = X_test.reshape(-1, 28, 28).astype(np.uint8) - X = np.vstack([X_train[:1]]*10+[X_train[1:2]]*10) - fig, a = plt.subplots(2,len(X)) - Xaug = augment_img(X) - for i in range(len(X)): - a[0][i].imshow(X[i], cmap='gray') - a[1][i].imshow(Xaug[i],cmap='gray') - a[0][i].axis('off') - a[1][i].axis('off') - plt.show() + import matplotlib.pyplot as plt - #create some nice gifs for doc?! - for i in range(10): - im = Image.fromarray(X_train[7353+i]) - im_aug = [Image.fromarray(x) for x in augment_img(np.array([X_train[7353+i]]*100))] - im.save(f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0) + X_train, Y_train, X_test, Y_test = fetch_mnist() + X_train = X_train.reshape(-1, 28, 28).astype(np.uint8) + X_test = X_test.reshape(-1, 28, 28).astype(np.uint8) + X = np.vstack([X_train[:1]] * 10 + [X_train[1:2]] * 10) + fig, a = plt.subplots(2, len(X)) + Xaug = augment_img(X) + for i in range(len(X)): + a[0][i].imshow(X[i], cmap="gray") + a[1][i].imshow(Xaug[i], cmap="gray") + a[0][i].axis("off") + a[1][i].axis("off") + plt.show() + + # create some nice gifs for doc?! + for i in range(10): + im = Image.fromarray(X_train[7353 + i]) + im_aug = [ + Image.fromarray(x) for x in augment_img(np.array([X_train[7353 + i]] * 100)) + ] + im.save( + f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0 + ) diff --git a/extra/autopad.py b/extra/autopad.py index d23a2e973..eddb2f317 100644 --- a/extra/autopad.py +++ b/extra/autopad.py @@ -37,4 +37,4 @@ lin.apply_opt(Opt(op=OptOps.PADTO, axis=1, amt=32)) lin.hand_coded_optimizations() lin.linearize() -run_linearizer(lin) \ No newline at end of file +run_linearizer(lin) diff --git a/extra/datasets/__init__.py b/extra/datasets/__init__.py index 92398bf23..844a71bca 100644 --- a/extra/datasets/__init__.py +++ b/extra/datasets/__init__.py @@ -3,41 +3,82 @@ import numpy as np from tinygrad.tensor import Tensor from tinygrad.helpers import dtypes, fetch + def fetch_mnist(tensors=False): - parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy() - BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/" # http://yann.lecun.com/exdb/mnist/ lacks https - X_train = parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32) - Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:] - X_test = parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32) - Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:] - if tensors: return Tensor(X_train).reshape(-1, 1, 28, 28), Tensor(Y_train), Tensor(X_test).reshape(-1, 1, 28, 28), Tensor(Y_test) - else: return X_train, Y_train, X_test, Y_test + parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy() + BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/" # http://yann.lecun.com/exdb/mnist/ lacks https + X_train = ( + parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:] + .reshape((-1, 28 * 28)) + .astype(np.float32) + ) + Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:] + X_test = ( + parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:] + .reshape((-1, 28 * 28)) + .astype(np.float32) + ) + Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:] + if tensors: + return ( + Tensor(X_train).reshape(-1, 1, 28, 28), + Tensor(Y_train), + Tensor(X_test).reshape(-1, 1, 28, 28), + Tensor(Y_test), + ) + else: + return X_train, Y_train, X_test, Y_test + cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618] cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628] + def fetch_cifar(): - X_train = Tensor.empty(50000, 3*32*32, device=f'disk:/tmp/cifar_train_x', dtype=dtypes.uint8) - Y_train = Tensor.empty(50000, device=f'disk:/tmp/cifar_train_y', dtype=dtypes.int64) - X_test = Tensor.empty(10000, 3*32*32, device=f'disk:/tmp/cifar_test_x', dtype=dtypes.uint8) - Y_test = Tensor.empty(10000, device=f'disk:/tmp/cifar_test_y', dtype=dtypes.int64) + X_train = Tensor.empty( + 50000, 3 * 32 * 32, device=f"disk:/tmp/cifar_train_x", dtype=dtypes.uint8 + ) + Y_train = Tensor.empty(50000, device=f"disk:/tmp/cifar_train_y", dtype=dtypes.int64) + X_test = Tensor.empty( + 10000, 3 * 32 * 32, device=f"disk:/tmp/cifar_test_x", dtype=dtypes.uint8 + ) + Y_test = Tensor.empty(10000, device=f"disk:/tmp/cifar_test_y", dtype=dtypes.int64) - if not os.path.isfile("/tmp/cifar_extracted"): - def _load_disk_tensor(X, Y, db_list): - idx = 0 - for db in db_list: - x, y = db[b'data'], np.array(db[b'labels']) - assert x.shape[0] == y.shape[0] - X[idx:idx+x.shape[0]].assign(x) - Y[idx:idx+x.shape[0]].assign(y) - idx += x.shape[0] - assert idx == X.shape[0] and X.shape[0] == Y.shape[0] + if not os.path.isfile("/tmp/cifar_extracted"): - print("downloading and extracting CIFAR...") - fn = fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz') - tt = tarfile.open(fn, mode='r:gz') - _load_disk_tensor(X_train, Y_train, [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)]) - _load_disk_tensor(X_test, Y_test, [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")]) - open("/tmp/cifar_extracted", "wb").close() + def _load_disk_tensor(X, Y, db_list): + idx = 0 + for db in db_list: + x, y = db[b"data"], np.array(db[b"labels"]) + assert x.shape[0] == y.shape[0] + X[idx : idx + x.shape[0]].assign(x) + Y[idx : idx + x.shape[0]].assign(y) + idx += x.shape[0] + assert idx == X.shape[0] and X.shape[0] == Y.shape[0] - return X_train, Y_train, X_test, Y_test + print("downloading and extracting CIFAR...") + fn = fetch("https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz") + tt = tarfile.open(fn, mode="r:gz") + _load_disk_tensor( + X_train, + Y_train, + [ + pickle.load( + tt.extractfile(f"cifar-10-batches-py/data_batch_{i}"), + encoding="bytes", + ) + for i in range(1, 6) + ], + ) + _load_disk_tensor( + X_test, + Y_test, + [ + pickle.load( + tt.extractfile("cifar-10-batches-py/test_batch"), encoding="bytes" + ) + ], + ) + open("/tmp/cifar_extracted", "wb").close() + + return X_train, Y_train, X_test, Y_test diff --git a/extra/datasets/coco.py b/extra/datasets/coco.py index 0952e3770..ffb5f732c 100644 --- a/extra/datasets/coco.py +++ b/extra/datasets/coco.py @@ -8,192 +8,207 @@ from examples.mask_rcnn import Masker from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval -iou = _mask.iou -merge = _mask.merge +iou = _mask.iou +merge = _mask.merge frPyObjects = _mask.frPyObjects BASEDIR = pathlib.Path(__file__).parent / "COCO" BASEDIR.mkdir(exist_ok=True) -def create_dict(key_row, val_row, rows): return {row[key_row]:row[val_row] for row in rows} + +def create_dict(key_row, val_row, rows): + return {row[key_row]: row[val_row] for row in rows} -if not pathlib.Path(BASEDIR/'val2017').is_dir(): - fn = fetch('http://images.cocodataset.org/zips/val2017.zip') - with zipfile.ZipFile(fn, 'r') as zip_ref: - zip_ref.extractall(BASEDIR) - fn.unlink() +if not pathlib.Path(BASEDIR / "val2017").is_dir(): + fn = fetch("http://images.cocodataset.org/zips/val2017.zip") + with zipfile.ZipFile(fn, "r") as zip_ref: + zip_ref.extractall(BASEDIR) + fn.unlink() -if not pathlib.Path(BASEDIR/'annotations').is_dir(): - fn = fetch('http://images.cocodataset.org/annotations/annotations_trainval2017.zip') - with zipfile.ZipFile(fn, 'r') as zip_ref: - zip_ref.extractall(BASEDIR) - fn.unlink() +if not pathlib.Path(BASEDIR / "annotations").is_dir(): + fn = fetch("http://images.cocodataset.org/annotations/annotations_trainval2017.zip") + with zipfile.ZipFile(fn, "r") as zip_ref: + zip_ref.extractall(BASEDIR) + fn.unlink() -with open(BASEDIR/'annotations/instances_val2017.json', 'r') as f: - annotations_raw = json.loads(f.read()) -images = annotations_raw['images'] -categories = annotations_raw['categories'] -annotations = annotations_raw['annotations'] -file_name_to_id = create_dict('file_name', 'id', images) -id_to_width = create_dict('id', 'width', images) -id_to_height = create_dict('id', 'height', images) -json_category_id_to_contiguous_id = {v['id']: i + 1 for i, v in enumerate(categories)} -contiguous_category_id_to_json_id = {v:k for k,v in json_category_id_to_contiguous_id.items()} +with open(BASEDIR / "annotations/instances_val2017.json", "r") as f: + annotations_raw = json.loads(f.read()) +images = annotations_raw["images"] +categories = annotations_raw["categories"] +annotations = annotations_raw["annotations"] +file_name_to_id = create_dict("file_name", "id", images) +id_to_width = create_dict("id", "width", images) +id_to_height = create_dict("id", "height", images) +json_category_id_to_contiguous_id = {v["id"]: i + 1 for i, v in enumerate(categories)} +contiguous_category_id_to_json_id = { + v: k for k, v in json_category_id_to_contiguous_id.items() +} def encode(bimask): - if len(bimask.shape) == 3: - return _mask.encode(bimask) - elif len(bimask.shape) == 2: - h, w = bimask.shape - return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0] + if len(bimask.shape) == 3: + return _mask.encode(bimask) + elif len(bimask.shape) == 2: + h, w = bimask.shape + return _mask.encode(bimask.reshape((h, w, 1), order="F"))[0] + def decode(rleObjs): - if type(rleObjs) == list: - return _mask.decode(rleObjs) - else: - return _mask.decode([rleObjs])[:,:,0] + if type(rleObjs) == list: + return _mask.decode(rleObjs) + else: + return _mask.decode([rleObjs])[:, :, 0] + def area(rleObjs): - if type(rleObjs) == list: - return _mask.area(rleObjs) - else: - return _mask.area([rleObjs])[0] + if type(rleObjs) == list: + return _mask.area(rleObjs) + else: + return _mask.area([rleObjs])[0] + def toBbox(rleObjs): - if type(rleObjs) == list: - return _mask.toBbox(rleObjs) - else: - return _mask.toBbox([rleObjs])[0] + if type(rleObjs) == list: + return _mask.toBbox(rleObjs) + else: + return _mask.toBbox([rleObjs])[0] def convert_prediction_to_coco_bbox(file_name, prediction): - coco_results = [] - try: - original_id = file_name_to_id[file_name] - if len(prediction) == 0: - return coco_results + coco_results = [] + try: + original_id = file_name_to_id[file_name] + if len(prediction) == 0: + return coco_results - image_width = id_to_width[original_id] - image_height = id_to_height[original_id] - prediction = prediction.resize((image_width, image_height)) - prediction = prediction.convert("xywh") + image_width = id_to_width[original_id] + image_height = id_to_height[original_id] + prediction = prediction.resize((image_width, image_height)) + prediction = prediction.convert("xywh") - boxes = prediction.bbox.numpy().tolist() - scores = prediction.get_field("scores").numpy().tolist() - labels = prediction.get_field("labels").numpy().tolist() + boxes = prediction.bbox.numpy().tolist() + scores = prediction.get_field("scores").numpy().tolist() + labels = prediction.get_field("labels").numpy().tolist() - mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels] + mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels] + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": mapped_labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + except Exception as e: + print(file_name, e) + return coco_results - coco_results.extend( - [ - { - "image_id": original_id, - "category_id": mapped_labels[k], - "bbox": box, - "score": scores[k], - } - for k, box in enumerate(boxes) - ] - ) - except Exception as e: - print(file_name, e) - return coco_results masker = Masker(threshold=0.5, padding=1) + def convert_prediction_to_coco_mask(file_name, prediction): - coco_results = [] - try: - original_id = file_name_to_id[file_name] - if len(prediction) == 0: - return coco_results + coco_results = [] + try: + original_id = file_name_to_id[file_name] + if len(prediction) == 0: + return coco_results - image_width = id_to_width[original_id] - image_height = id_to_height[original_id] - prediction = prediction.resize((image_width, image_height)) - masks = prediction.get_field("mask") + image_width = id_to_width[original_id] + image_height = id_to_height[original_id] + prediction = prediction.resize((image_width, image_height)) + masks = prediction.get_field("mask") - scores = prediction.get_field("scores").numpy().tolist() - labels = prediction.get_field("labels").numpy().tolist() + scores = prediction.get_field("scores").numpy().tolist() + labels = prediction.get_field("labels").numpy().tolist() - masks = masker([masks], [prediction])[0].numpy() + masks = masker([masks], [prediction])[0].numpy() - rles = [ - encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] - for mask in masks - ] - for rle in rles: - rle["counts"] = rle["counts"].decode("utf-8") + rles = [ + encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") - mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels] - - coco_results.extend( - [ - { - "image_id": original_id, - "category_id": mapped_labels[k], - "segmentation": rle, - "score": scores[k], - } - for k, rle in enumerate(rles) - ] - ) - except Exception as e: - print(file_name, e) - return coco_results + mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels] + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": mapped_labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + except Exception as e: + print(file_name, e) + return coco_results def accumulate_predictions_for_coco(coco_results, json_result_file, rm=False): - path = pathlib.Path(json_result_file) - if rm and path.exists(): path.unlink() - with open(path, "a") as f: - for s in coco_results: - f.write(json.dumps(s)) - f.write('\n') + path = pathlib.Path(json_result_file) + if rm and path.exists(): + path.unlink() + with open(path, "a") as f: + for s in coco_results: + f.write(json.dumps(s)) + f.write("\n") + def remove_dup(l): - seen = set() - seen_add = seen.add - return [x for x in l if not (x in seen or seen_add(x))] + seen = set() + seen_add = seen.add + return [x for x in l if not (x in seen or seen_add(x))] + class NpEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, np.integer): - return int(obj) - if isinstance(obj, np.floating): - return float(obj) - if isinstance(obj, np.ndarray): - return obj.tolist() - return super(NpEncoder, self).default(obj) + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return super(NpEncoder, self).default(obj) def evaluate_predictions_on_coco(json_result_file, iou_type="bbox"): - coco_results = [] - with open(json_result_file, "r") as f: - for line in f: - coco_results.append(json.loads(line)) + coco_results = [] + with open(json_result_file, "r") as f: + for line in f: + coco_results.append(json.loads(line)) - coco_gt = COCO(str(BASEDIR/'annotations/instances_val2017.json')) - set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results]) - unique_list = [json.loads(s) for s in set_of_json] + coco_gt = COCO(str(BASEDIR / "annotations/instances_val2017.json")) + set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results]) + unique_list = [json.loads(s) for s in set_of_json] - with open(f'{json_result_file}.flattend', "w") as f: - json.dump(unique_list, f) + with open(f"{json_result_file}.flattend", "w") as f: + json.dump(unique_list, f) + + coco_dt = coco_gt.loadRes(str(f"{json_result_file}.flattend")) + coco_eval = COCOeval(coco_gt, coco_dt, iou_type) + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + return coco_eval - coco_dt = coco_gt.loadRes(str(f'{json_result_file}.flattend')) - coco_eval = COCOeval(coco_gt, coco_dt, iou_type) - coco_eval.evaluate() - coco_eval.accumulate() - coco_eval.summarize() - return coco_eval def iterate(files, bs=1): - batch = [] - for file in files: - batch.append(file) - if len(batch) >= bs: yield batch; batch = [] - if len(batch) > 0: yield batch; batch = [] + batch = [] + for file in files: + batch.append(file) + if len(batch) >= bs: + yield batch + batch = [] + if len(batch) > 0: + yield batch + batch = [] diff --git a/extra/datasets/imagenet.py b/extra/datasets/imagenet.py index dde32a5e4..9a0a3ce31 100644 --- a/extra/datasets/imagenet.py +++ b/extra/datasets/imagenet.py @@ -7,47 +7,56 @@ import functools, pathlib BASEDIR = pathlib.Path(__file__).parent / "imagenet" ci = json.load(open(BASEDIR / "imagenet_class_index.json")) -cir = {v[0]: int(k) for k,v in ci.items()} +cir = {v[0]: int(k) for k, v in ci.items()} + @functools.lru_cache(None) def get_train_files(): - train_files = open(BASEDIR / "train_files").read().strip().split("\n") - return [(BASEDIR / "train" / x) for x in train_files] + train_files = open(BASEDIR / "train_files").read().strip().split("\n") + return [(BASEDIR / "train" / x) for x in train_files] + @functools.lru_cache(None) def get_val_files(): - val_files = glob.glob(str(BASEDIR / "val/*/*")) - return val_files + val_files = glob.glob(str(BASEDIR / "val/*/*")) + return val_files -#rrc = transforms.RandomResizedCrop(224) + +# rrc = transforms.RandomResizedCrop(224) import torchvision.transforms.functional as F + + def image_load(fn): - img = Image.open(fn).convert('RGB') - img = F.resize(img, 256, Image.BILINEAR) - img = F.center_crop(img, 224) - ret = np.array(img) - return ret + img = Image.open(fn).convert("RGB") + img = F.resize(img, 256, Image.BILINEAR) + img = F.center_crop(img, 224) + ret = np.array(img) + return ret + def iterate(bs=32, val=True, shuffle=True): - files = get_val_files() if val else get_train_files() - order = list(range(0, len(files))) - if shuffle: random.shuffle(order) - from multiprocessing import Pool - p = Pool(16) - for i in range(0, len(files), bs): - X = p.map(image_load, [files[i] for i in order[i:i+bs]]) - Y = [cir[files[i].split("/")[-2]] for i in order[i:i+bs]] - yield (np.array(X), np.array(Y)) + files = get_val_files() if val else get_train_files() + order = list(range(0, len(files))) + if shuffle: + random.shuffle(order) + from multiprocessing import Pool + + p = Pool(16) + for i in range(0, len(files), bs): + X = p.map(image_load, [files[i] for i in order[i : i + bs]]) + Y = [cir[files[i].split("/")[-2]] for i in order[i : i + bs]] + yield (np.array(X), np.array(Y)) + def fetch_batch(bs, val=False): - files = get_val_files() if val else get_train_files() - samp = np.random.randint(0, len(files), size=(bs)) - files = [files[i] for i in samp] - X = [image_load(x) for x in files] - Y = [cir[x.split("/")[0]] for x in files] - return np.array(X), np.array(Y) + files = get_val_files() if val else get_train_files() + samp = np.random.randint(0, len(files), size=(bs)) + files = [files[i] for i in samp] + X = [image_load(x) for x in files] + Y = [cir[x.split("/")[0]] for x in files] + return np.array(X), np.array(Y) + if __name__ == "__main__": - X,Y = fetch_batch(64) - print(X.shape, Y) - + X, Y = fetch_batch(64) + print(X.shape, Y) diff --git a/extra/datasets/imagenet_download.py b/extra/datasets/imagenet_download.py index 5c01c72f5..43a253a95 100644 --- a/extra/datasets/imagenet_download.py +++ b/extra/datasets/imagenet_download.py @@ -4,48 +4,92 @@ from pathlib import Path from tqdm import tqdm import tarfile, os + def imagenet_extract(file, path, small=False): - with tarfile.open(name=file) as tar: - if small: # Show progressbar only for big files - for member in tar.getmembers(): tar.extract(path=path, member=member) - else: - for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())): tar.extract(path=path, member=member) - tar.close() + with tarfile.open(name=file) as tar: + if small: # Show progressbar only for big files + for member in tar.getmembers(): + tar.extract(path=path, member=member) + else: + for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())): + tar.extract(path=path, member=member) + tar.close() + def imagenet_prepare_val(): - # Read in the labels file - with open(Path(__file__).parent / "imagenet" / "imagenet_2012_validation_synset_labels.txt", 'r') as f: - labels = f.read().splitlines() - f.close() - # Get a list of images - images = os.listdir(Path(__file__).parent / "imagenet" / "val") - images.sort() - # Create folders and move files into those - for co,dir in enumerate(labels): - os.makedirs(Path(__file__).parent / "imagenet" / "val" / dir, exist_ok=True) - os.replace(Path(__file__).parent / "imagenet" / "val" / images[co], Path(__file__).parent / "imagenet" / "val" / dir / images[co]) - os.remove(Path(__file__).parent / "imagenet" / "imagenet_2012_validation_synset_labels.txt") + # Read in the labels file + with open( + Path(__file__).parent + / "imagenet" + / "imagenet_2012_validation_synset_labels.txt", + "r", + ) as f: + labels = f.read().splitlines() + f.close() + # Get a list of images + images = os.listdir(Path(__file__).parent / "imagenet" / "val") + images.sort() + # Create folders and move files into those + for co, dir in enumerate(labels): + os.makedirs(Path(__file__).parent / "imagenet" / "val" / dir, exist_ok=True) + os.replace( + Path(__file__).parent / "imagenet" / "val" / images[co], + Path(__file__).parent / "imagenet" / "val" / dir / images[co], + ) + os.remove( + Path(__file__).parent + / "imagenet" + / "imagenet_2012_validation_synset_labels.txt" + ) + def imagenet_prepare_train(): - images = os.listdir(Path(__file__).parent / "imagenet" / "train") - for co,tarf in enumerate(images): - # for each tar file found. Create a folder with its name. Extract into that folder. Remove tar file - if Path(Path(__file__).parent / "imagenet" / "train" / images[co]).is_file(): - images[co] = tarf[:-4] # remove .tar from extracted tar files - os.makedirs(Path(__file__).parent / "imagenet" / "train" / images[co], exist_ok=True) - imagenet_extract(Path(__file__).parent / "imagenet" / "train" / tarf, Path(__file__).parent/ "imagenet" / "train" / images[co], small=True) - os.remove(Path(__file__).parent / "imagenet" / "train" / tarf) + images = os.listdir(Path(__file__).parent / "imagenet" / "train") + for co, tarf in enumerate(images): + # for each tar file found. Create a folder with its name. Extract into that folder. Remove tar file + if Path(Path(__file__).parent / "imagenet" / "train" / images[co]).is_file(): + images[co] = tarf[:-4] # remove .tar from extracted tar files + os.makedirs( + Path(__file__).parent / "imagenet" / "train" / images[co], exist_ok=True + ) + imagenet_extract( + Path(__file__).parent / "imagenet" / "train" / tarf, + Path(__file__).parent / "imagenet" / "train" / images[co], + small=True, + ) + os.remove(Path(__file__).parent / "imagenet" / "train" / tarf) + if __name__ == "__main__": - os.makedirs(Path(__file__).parent / "imagenet", exist_ok=True) - os.makedirs(Path(__file__).parent / "imagenet" / "val", exist_ok=True) - os.makedirs(Path(__file__).parent / "imagenet" / "train", exist_ok=True) - fetch("https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json", Path(__file__).parent / "imagenet" / "imagenet_class_index.json") - fetch("https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt", Path(__file__).parent / "imagenet"/ "imagenet_2012_validation_synset_labels.txt") - fetch("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar") # 7GB - imagenet_extract(Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "val") - imagenet_prepare_val() - if os.getenv('IMGNET_TRAIN', None) is not None: - fetch("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar") #138GB! - imagenet_extract(Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "train") - imagenet_prepare_train() + os.makedirs(Path(__file__).parent / "imagenet", exist_ok=True) + os.makedirs(Path(__file__).parent / "imagenet" / "val", exist_ok=True) + os.makedirs(Path(__file__).parent / "imagenet" / "train", exist_ok=True) + fetch( + "https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json", + Path(__file__).parent / "imagenet" / "imagenet_class_index.json", + ) + fetch( + "https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt", + Path(__file__).parent + / "imagenet" + / "imagenet_2012_validation_synset_labels.txt", + ) + fetch( + "https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", + Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar", + ) # 7GB + imagenet_extract( + Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar", + Path(__file__).parent / "imagenet" / "val", + ) + imagenet_prepare_val() + if os.getenv("IMGNET_TRAIN", None) is not None: + fetch( + "https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", + Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar", + ) # 138GB! + imagenet_extract( + Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar", + Path(__file__).parent / "imagenet" / "train", + ) + imagenet_prepare_train() diff --git a/extra/datasets/kits19.py b/extra/datasets/kits19.py index 70204fbbf..6e23fdf92 100644 --- a/extra/datasets/kits19.py +++ b/extra/datasets/kits19.py @@ -23,109 +23,199 @@ mv kits extra/datasets ``` """ + @functools.lru_cache(None) def get_val_files(): - data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt").read_text() - return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")]) + data = fetch( + "https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt" + ).read_text() + return sorted( + [x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")] + ) + def load_pair(file_path): - image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(file_path / "segmentation.nii.gz") - image_spacings = image.header["pixdim"][1:4].tolist() - image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(np.uint8) - image, label = np.expand_dims(image, 0), np.expand_dims(label, 0) - return image, label, image_spacings + image, label = nib.load(file_path / "imaging.nii.gz"), nib.load( + file_path / "segmentation.nii.gz" + ) + image_spacings = image.header["pixdim"][1:4].tolist() + image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype( + np.uint8 + ) + image, label = np.expand_dims(image, 0), np.expand_dims(label, 0) + return image, label, image_spacings + def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)): - if image_spacings != target_spacing: - spc_arr, targ_arr, shp_arr = np.array(image_spacings), np.array(target_spacing), np.array(image.shape[1:]) - new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist() - image = F.interpolate(torch.from_numpy(np.expand_dims(image, axis=0)), size=new_shape, mode="trilinear", align_corners=True) - label = F.interpolate(torch.from_numpy(np.expand_dims(label, axis=0)), size=new_shape, mode="nearest") - image = np.squeeze(image.numpy(), axis=0) - label = np.squeeze(label.numpy(), axis=0) - return image, label + if image_spacings != target_spacing: + spc_arr, targ_arr, shp_arr = ( + np.array(image_spacings), + np.array(target_spacing), + np.array(image.shape[1:]), + ) + new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist() + image = F.interpolate( + torch.from_numpy(np.expand_dims(image, axis=0)), + size=new_shape, + mode="trilinear", + align_corners=True, + ) + label = F.interpolate( + torch.from_numpy(np.expand_dims(label, axis=0)), + size=new_shape, + mode="nearest", + ) + image = np.squeeze(image.numpy(), axis=0) + label = np.squeeze(label.numpy(), axis=0) + return image, label + def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9): - image = np.clip(image, min_clip, max_clip) - image = (image - mean) / std - return image + image = np.clip(image, min_clip, max_clip) + image = (image - mean) / std + return image + def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)): - current_shape = image.shape[1:] - bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)] - paddings = [(0, 0)] + [(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)] - image = np.pad(image, paddings, mode="edge") - label = np.pad(label, paddings, mode="edge") - return image, label + current_shape = image.shape[1:] + bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)] + paddings = [(0, 0)] + [ + (bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3) + ] + image = np.pad(image, paddings, mode="edge") + label = np.pad(label, paddings, mode="edge") + return image, label + def preprocess(file_path): - image, label, image_spacings = load_pair(file_path) - image, label = resample3d(image, label, image_spacings) - image = normal_intensity(image.copy()) - image, label = pad_to_min_shape(image, label) - return image, label + image, label, image_spacings = load_pair(file_path) + image, label = resample3d(image, label, image_spacings) + image = normal_intensity(image.copy()) + image, label = pad_to_min_shape(image, label) + return image, label + def iterate(val=True, shuffle=False): - if not val: raise NotImplementedError - files = get_val_files() - order = list(range(0, len(files))) - if shuffle: random.shuffle(order) - for file in files: - X, Y = preprocess(file) - X = np.expand_dims(X, axis=0) - yield (X, Y) + if not val: + raise NotImplementedError + files = get_val_files() + order = list(range(0, len(files))) + if shuffle: + random.shuffle(order) + for file in files: + X, Y = preprocess(file) + X = np.expand_dims(X, axis=0) + yield (X, Y) + def gaussian_kernel(n, std): - gaussian_1d = signal.gaussian(n, std) - gaussian_2d = np.outer(gaussian_1d, gaussian_1d) - gaussian_3d = np.outer(gaussian_2d, gaussian_1d) - gaussian_3d = gaussian_3d.reshape(n, n, n) - gaussian_3d = np.cbrt(gaussian_3d) - gaussian_3d /= gaussian_3d.max() - return gaussian_3d + gaussian_1d = signal.gaussian(n, std) + gaussian_2d = np.outer(gaussian_1d, gaussian_1d) + gaussian_3d = np.outer(gaussian_2d, gaussian_1d) + gaussian_3d = gaussian_3d.reshape(n, n, n) + gaussian_3d = np.cbrt(gaussian_3d) + gaussian_3d /= gaussian_3d.max() + return gaussian_3d -def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3): - bounds = [(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)] - bounds = [bounds[i] if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] for i in range(dim)] - paddings = [bounds[2]//2, bounds[2]-bounds[2]//2, bounds[1]//2, bounds[1]-bounds[1]//2, bounds[0]//2, bounds[0]-bounds[0]//2, 0, 0, 0, 0] - return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings -def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5): - from tinygrad.jit import TinyJit - mdl_run = TinyJit(lambda x: model(x).realize()) - image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:]) - strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)] - bounds = [image_shape[i] % strides[i] for i in range(dim)] - bounds = [bounds[i] if bounds[i] < strides[i] // 2 else 0 for i in range(dim)] - inputs = inputs[ - ..., - bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2), - bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2), - bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2), - ] - labels = labels[ - ..., - bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2), - bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2), - bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2), - ] - inputs, paddings = pad_input(inputs, roi_shape, strides) - padded_shape = inputs.shape[2:] - size = [(inputs.shape[2:][i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)] - result = np.zeros((1, 3, *padded_shape), dtype=np.float32) - norm_map = np.zeros((1, 3, *padded_shape), dtype=np.float32) - norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]) - norm_patch = np.expand_dims(norm_patch, axis=0) - for i in range(0, strides[0] * size[0], strides[0]): - for j in range(0, strides[1] * size[1], strides[1]): - for k in range(0, strides[2] * size[2], strides[2]): - out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i,j:roi_shape[1]+j, k:roi_shape[2]+k])).numpy() - result[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += out * norm_patch - norm_map[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += norm_patch - result /= norm_map - result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]] - return result, labels +def pad_input( + volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3 +): + bounds = [ + (strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim) + ] + bounds = [ + bounds[i] + if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] + else bounds[i] + strides[i] + for i in range(dim) + ] + paddings = [ + bounds[2] // 2, + bounds[2] - bounds[2] // 2, + bounds[1] // 2, + bounds[1] - bounds[1] // 2, + bounds[0] // 2, + bounds[0] - bounds[0] // 2, + 0, + 0, + 0, + 0, + ] + return ( + F.pad( + torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val + ).numpy(), + paddings, + ) + + +def sliding_window_inference( + model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5 +): + from tinygrad.jit import TinyJit + + mdl_run = TinyJit(lambda x: model(x).realize()) + image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:]) + strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)] + bounds = [image_shape[i] % strides[i] for i in range(dim)] + bounds = [bounds[i] if bounds[i] < strides[i] // 2 else 0 for i in range(dim)] + inputs = inputs[ + ..., + bounds[0] // 2 : image_shape[0] - (bounds[0] - bounds[0] // 2), + bounds[1] // 2 : image_shape[1] - (bounds[1] - bounds[1] // 2), + bounds[2] // 2 : image_shape[2] - (bounds[2] - bounds[2] // 2), + ] + labels = labels[ + ..., + bounds[0] // 2 : image_shape[0] - (bounds[0] - bounds[0] // 2), + bounds[1] // 2 : image_shape[1] - (bounds[1] - bounds[1] // 2), + bounds[2] // 2 : image_shape[2] - (bounds[2] - bounds[2] // 2), + ] + inputs, paddings = pad_input(inputs, roi_shape, strides) + padded_shape = inputs.shape[2:] + size = [(inputs.shape[2:][i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)] + result = np.zeros((1, 3, *padded_shape), dtype=np.float32) + norm_map = np.zeros((1, 3, *padded_shape), dtype=np.float32) + norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]) + norm_patch = np.expand_dims(norm_patch, axis=0) + for i in range(0, strides[0] * size[0], strides[0]): + for j in range(0, strides[1] * size[1], strides[1]): + for k in range(0, strides[2] * size[2], strides[2]): + out = mdl_run( + Tensor( + inputs[ + ..., + i : roi_shape[0] + i, + j : roi_shape[1] + j, + k : roi_shape[2] + k, + ] + ) + ).numpy() + result[ + ..., + i : roi_shape[0] + i, + j : roi_shape[1] + j, + k : roi_shape[2] + k, + ] += ( + out * norm_patch + ) + norm_map[ + ..., + i : roi_shape[0] + i, + j : roi_shape[1] + j, + k : roi_shape[2] + k, + ] += norm_patch + result /= norm_map + result = result[ + ..., + paddings[4] : image_shape[0] + paddings[4], + paddings[2] : image_shape[1] + paddings[2], + paddings[0] : image_shape[2] + paddings[0], + ] + return result, labels + if __name__ == "__main__": - for X, Y in iterate(): - print(X.shape, Y.shape) + for X, Y in iterate(): + print(X.shape, Y.shape) diff --git a/extra/datasets/librispeech.py b/extra/datasets/librispeech.py index 434d7a059..b441b5ca2 100644 --- a/extra/datasets/librispeech.py +++ b/extra/datasets/librispeech.py @@ -17,66 +17,88 @@ Then this [file](https://github.com/mlcommons/inference/blob/master/speech_recog """ BASEDIR = pathlib.Path(__file__).parent / "librispeech" with open(BASEDIR / "dev-clean-wav.json") as f: - ci = json.load(f) + ci = json.load(f) -FILTER_BANK = np.expand_dims(librosa.filters.mel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000), 0) +FILTER_BANK = np.expand_dims( + librosa.filters.mel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000), 0 +) WINDOW = librosa.filters.get_window("hann", 320) + def feature_extract(x, x_lens): - x_lens = np.ceil((x_lens / 160) / 3).astype(np.int32) + x_lens = np.ceil((x_lens / 160) / 3).astype(np.int32) - # pre-emphasis - x = np.concatenate((np.expand_dims(x[:, 0], 1), x[:, 1:] - 0.97 * x[:, :-1]), axis=1) + # pre-emphasis + x = np.concatenate( + (np.expand_dims(x[:, 0], 1), x[:, 1:] - 0.97 * x[:, :-1]), axis=1 + ) - # stft - x = librosa.stft(x, n_fft=512, window=WINDOW, hop_length=160, win_length=320, center=True, pad_mode="reflect") - x = np.stack((x.real, x.imag), axis=-1) + # stft + x = librosa.stft( + x, + n_fft=512, + window=WINDOW, + hop_length=160, + win_length=320, + center=True, + pad_mode="reflect", + ) + x = np.stack((x.real, x.imag), axis=-1) - # power spectrum - x = (x**2).sum(-1) + # power spectrum + x = (x**2).sum(-1) - # mel filter bank - x = np.matmul(FILTER_BANK, x) + # mel filter bank + x = np.matmul(FILTER_BANK, x) - # log - x = np.log(x + 1e-20) + # log + x = np.log(x + 1e-20) - # feature splice - seq = [x] - for i in range(1, 3): - tmp = np.zeros_like(x) - tmp[:, :, :-i] = x[:, :, i:] - seq.append(tmp) - features = np.concatenate(seq, axis=1)[:, :, ::3] + # feature splice + seq = [x] + for i in range(1, 3): + tmp = np.zeros_like(x) + tmp[:, :, :-i] = x[:, :, i:] + seq.append(tmp) + features = np.concatenate(seq, axis=1)[:, :, ::3] - # normalize - features_mean = np.zeros((features.shape[0], features.shape[1]), dtype=np.float32) - features_std = np.zeros((features.shape[0], features.shape[1]), dtype=np.float32) - for i in range(features.shape[0]): - features_mean[i, :] = features[i, :, :x_lens[i]].mean(axis=1) - features_std[i, :] = features[i, :, :x_lens[i]].std(axis=1, ddof=1) - features_std += 1e-5 - features = (features - np.expand_dims(features_mean, 2)) / np.expand_dims(features_std, 2) + # normalize + features_mean = np.zeros((features.shape[0], features.shape[1]), dtype=np.float32) + features_std = np.zeros((features.shape[0], features.shape[1]), dtype=np.float32) + for i in range(features.shape[0]): + features_mean[i, :] = features[i, :, : x_lens[i]].mean(axis=1) + features_std[i, :] = features[i, :, : x_lens[i]].std(axis=1, ddof=1) + features_std += 1e-5 + features = (features - np.expand_dims(features_mean, 2)) / np.expand_dims( + features_std, 2 + ) + + return features.transpose(2, 0, 1), x_lens.astype(np.float32) - return features.transpose(2, 0, 1), x_lens.astype(np.float32) def load_wav(file): - sample = soundfile.read(file)[0].astype(np.float32) - return sample, sample.shape[0] + sample = soundfile.read(file)[0].astype(np.float32) + return sample, sample.shape[0] + def iterate(bs=1, start=0): - print(f"there are {len(ci)} samples in the dataset") - for i in range(start, len(ci), bs): - samples, sample_lens = zip(*[load_wav(BASEDIR / v["files"][0]["fname"]) for v in ci[i : i + bs]]) - samples = list(samples) - # pad to same length - max_len = max(sample_lens) - for j in range(len(samples)): - samples[j] = np.pad(samples[j], (0, max_len - sample_lens[j]), "constant") - samples, sample_lens = np.array(samples), np.array(sample_lens) + print(f"there are {len(ci)} samples in the dataset") + for i in range(start, len(ci), bs): + samples, sample_lens = zip( + *[load_wav(BASEDIR / v["files"][0]["fname"]) for v in ci[i : i + bs]] + ) + samples = list(samples) + # pad to same length + max_len = max(sample_lens) + for j in range(len(samples)): + samples[j] = np.pad(samples[j], (0, max_len - sample_lens[j]), "constant") + samples, sample_lens = np.array(samples), np.array(sample_lens) + + yield feature_extract(samples, sample_lens), np.array( + [v["transcript"] for v in ci[i : i + bs]] + ) - yield feature_extract(samples, sample_lens), np.array([v["transcript"] for v in ci[i : i + bs]]) if __name__ == "__main__": - X, Y = next(iterate()) - print(X[0].shape, Y.shape) + X, Y = next(iterate()) + print(X[0].shape, Y.shape) diff --git a/extra/datasets/openimages.py b/extra/datasets/openimages.py index 97bd2e846..a4d724eaa 100644 --- a/extra/datasets/openimages.py +++ b/extra/datasets/openimages.py @@ -12,153 +12,467 @@ import concurrent.futures BASEDIR = pathlib.Path(__file__).parent / "open-images-v6-mlperf" BUCKET_NAME = "open-images-dataset" -BBOX_ANNOTATIONS_URL = "https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv" -MAP_CLASSES_URL = "https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv" -MLPERF_CLASSES = ['Airplane', 'Antelope', 'Apple', 'Backpack', 'Balloon', 'Banana', - 'Barrel', 'Baseball bat', 'Baseball glove', 'Bee', 'Beer', 'Bench', 'Bicycle', - 'Bicycle helmet', 'Bicycle wheel', 'Billboard', 'Book', 'Bookcase', 'Boot', - 'Bottle', 'Bowl', 'Bowling equipment', 'Box', 'Boy', 'Brassiere', 'Bread', - 'Broccoli', 'Bronze sculpture', 'Bull', 'Bus', 'Bust', 'Butterfly', 'Cabinetry', - 'Cake', 'Camel', 'Camera', 'Candle', 'Candy', 'Cannon', 'Canoe', 'Carrot', 'Cart', - 'Castle', 'Cat', 'Cattle', 'Cello', 'Chair', 'Cheese', 'Chest of drawers', 'Chicken', - 'Christmas tree', 'Coat', 'Cocktail', 'Coffee', 'Coffee cup', 'Coffee table', 'Coin', - 'Common sunflower', 'Computer keyboard', 'Computer monitor', 'Convenience store', - 'Cookie', 'Countertop', 'Cowboy hat', 'Crab', 'Crocodile', 'Cucumber', 'Cupboard', - 'Curtain', 'Deer', 'Desk', 'Dinosaur', 'Dog', 'Doll', 'Dolphin', 'Door', 'Dragonfly', - 'Drawer', 'Dress', 'Drum', 'Duck', 'Eagle', 'Earrings', 'Egg (Food)', 'Elephant', - 'Falcon', 'Fedora', 'Flag', 'Flowerpot', 'Football', 'Football helmet', 'Fork', - 'Fountain', 'French fries', 'French horn', 'Frog', 'Giraffe', 'Girl', 'Glasses', - 'Goat', 'Goggles', 'Goldfish', 'Gondola', 'Goose', 'Grape', 'Grapefruit', 'Guitar', - 'Hamburger', 'Handbag', 'Harbor seal', 'Headphones', 'Helicopter', 'High heels', - 'Hiking equipment', 'Horse', 'House', 'Houseplant', 'Human arm', 'Human beard', - 'Human body', 'Human ear', 'Human eye', 'Human face', 'Human foot', 'Human hair', - 'Human hand', 'Human head', 'Human leg', 'Human mouth', 'Human nose', 'Ice cream', - 'Jacket', 'Jeans', 'Jellyfish', 'Juice', 'Kitchen & dining room table', 'Kite', - 'Lamp', 'Lantern', 'Laptop', 'Lavender (Plant)', 'Lemon', 'Light bulb', 'Lighthouse', - 'Lily', 'Lion', 'Lipstick', 'Lizard', 'Man', 'Maple', 'Microphone', 'Mirror', - 'Mixing bowl', 'Mobile phone', 'Monkey', 'Motorcycle', 'Muffin', 'Mug', 'Mule', - 'Mushroom', 'Musical keyboard', 'Necklace', 'Nightstand', 'Office building', - 'Orange', 'Owl', 'Oyster', 'Paddle', 'Palm tree', 'Parachute', 'Parrot', 'Pen', - 'Penguin', 'Personal flotation device', 'Piano', 'Picture frame', 'Pig', 'Pillow', - 'Pizza', 'Plate', 'Platter', 'Porch', 'Poster', 'Pumpkin', 'Rabbit', 'Rifle', - 'Roller skates', 'Rose', 'Salad', 'Sandal', 'Saucer', 'Saxophone', 'Scarf', 'Sea lion', - 'Sea turtle', 'Sheep', 'Shelf', 'Shirt', 'Shorts', 'Shrimp', 'Sink', 'Skateboard', - 'Ski', 'Skull', 'Skyscraper', 'Snake', 'Sock', 'Sofa bed', 'Sparrow', 'Spider', 'Spoon', - 'Sports uniform', 'Squirrel', 'Stairs', 'Stool', 'Strawberry', 'Street light', - 'Studio couch', 'Suit', 'Sun hat', 'Sunglasses', 'Surfboard', 'Sushi', 'Swan', - 'Swimming pool', 'Swimwear', 'Tank', 'Tap', 'Taxi', 'Tea', 'Teddy bear', 'Television', - 'Tent', 'Tie', 'Tiger', 'Tin can', 'Tire', 'Toilet', 'Tomato', 'Tortoise', 'Tower', - 'Traffic light', 'Train', 'Tripod', 'Truck', 'Trumpet', 'Umbrella', 'Van', 'Vase', - 'Vehicle registration plate', 'Violin', 'Wall clock', 'Waste container', 'Watch', - 'Whale', 'Wheel', 'Wheelchair', 'Whiteboard', 'Window', 'Wine', 'Wine glass', 'Woman', - 'Zebra', 'Zucchini', +BBOX_ANNOTATIONS_URL = ( + "https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv" +) +MAP_CLASSES_URL = ( + "https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv" +) +MLPERF_CLASSES = [ + "Airplane", + "Antelope", + "Apple", + "Backpack", + "Balloon", + "Banana", + "Barrel", + "Baseball bat", + "Baseball glove", + "Bee", + "Beer", + "Bench", + "Bicycle", + "Bicycle helmet", + "Bicycle wheel", + "Billboard", + "Book", + "Bookcase", + "Boot", + "Bottle", + "Bowl", + "Bowling equipment", + "Box", + "Boy", + "Brassiere", + "Bread", + "Broccoli", + "Bronze sculpture", + "Bull", + "Bus", + "Bust", + "Butterfly", + "Cabinetry", + "Cake", + "Camel", + "Camera", + "Candle", + "Candy", + "Cannon", + "Canoe", + "Carrot", + "Cart", + "Castle", + "Cat", + "Cattle", + "Cello", + "Chair", + "Cheese", + "Chest of drawers", + "Chicken", + "Christmas tree", + "Coat", + "Cocktail", + "Coffee", + "Coffee cup", + "Coffee table", + "Coin", + "Common sunflower", + "Computer keyboard", + "Computer monitor", + "Convenience store", + "Cookie", + "Countertop", + "Cowboy hat", + "Crab", + "Crocodile", + "Cucumber", + "Cupboard", + "Curtain", + "Deer", + "Desk", + "Dinosaur", + "Dog", + "Doll", + "Dolphin", + "Door", + "Dragonfly", + "Drawer", + "Dress", + "Drum", + "Duck", + "Eagle", + "Earrings", + "Egg (Food)", + "Elephant", + "Falcon", + "Fedora", + "Flag", + "Flowerpot", + "Football", + "Football helmet", + "Fork", + "Fountain", + "French fries", + "French horn", + "Frog", + "Giraffe", + "Girl", + "Glasses", + "Goat", + "Goggles", + "Goldfish", + "Gondola", + "Goose", + "Grape", + "Grapefruit", + "Guitar", + "Hamburger", + "Handbag", + "Harbor seal", + "Headphones", + "Helicopter", + "High heels", + "Hiking equipment", + "Horse", + "House", + "Houseplant", + "Human arm", + "Human beard", + "Human body", + "Human ear", + "Human eye", + "Human face", + "Human foot", + "Human hair", + "Human hand", + "Human head", + "Human leg", + "Human mouth", + "Human nose", + "Ice cream", + "Jacket", + "Jeans", + "Jellyfish", + "Juice", + "Kitchen & dining room table", + "Kite", + "Lamp", + "Lantern", + "Laptop", + "Lavender (Plant)", + "Lemon", + "Light bulb", + "Lighthouse", + "Lily", + "Lion", + "Lipstick", + "Lizard", + "Man", + "Maple", + "Microphone", + "Mirror", + "Mixing bowl", + "Mobile phone", + "Monkey", + "Motorcycle", + "Muffin", + "Mug", + "Mule", + "Mushroom", + "Musical keyboard", + "Necklace", + "Nightstand", + "Office building", + "Orange", + "Owl", + "Oyster", + "Paddle", + "Palm tree", + "Parachute", + "Parrot", + "Pen", + "Penguin", + "Personal flotation device", + "Piano", + "Picture frame", + "Pig", + "Pillow", + "Pizza", + "Plate", + "Platter", + "Porch", + "Poster", + "Pumpkin", + "Rabbit", + "Rifle", + "Roller skates", + "Rose", + "Salad", + "Sandal", + "Saucer", + "Saxophone", + "Scarf", + "Sea lion", + "Sea turtle", + "Sheep", + "Shelf", + "Shirt", + "Shorts", + "Shrimp", + "Sink", + "Skateboard", + "Ski", + "Skull", + "Skyscraper", + "Snake", + "Sock", + "Sofa bed", + "Sparrow", + "Spider", + "Spoon", + "Sports uniform", + "Squirrel", + "Stairs", + "Stool", + "Strawberry", + "Street light", + "Studio couch", + "Suit", + "Sun hat", + "Sunglasses", + "Surfboard", + "Sushi", + "Swan", + "Swimming pool", + "Swimwear", + "Tank", + "Tap", + "Taxi", + "Tea", + "Teddy bear", + "Television", + "Tent", + "Tie", + "Tiger", + "Tin can", + "Tire", + "Toilet", + "Tomato", + "Tortoise", + "Tower", + "Traffic light", + "Train", + "Tripod", + "Truck", + "Trumpet", + "Umbrella", + "Van", + "Vase", + "Vehicle registration plate", + "Violin", + "Wall clock", + "Waste container", + "Watch", + "Whale", + "Wheel", + "Wheelchair", + "Whiteboard", + "Window", + "Wine", + "Wine glass", + "Woman", + "Zebra", + "Zucchini", ] + def openimages(): - ann_file = BASEDIR / "validation/labels/openimages-mlperf.json" - if not ann_file.is_file(): - fetch_openimages(ann_file) - return ann_file + ann_file = BASEDIR / "validation/labels/openimages-mlperf.json" + if not ann_file.is_file(): + fetch_openimages(ann_file) + return ann_file + # this slows down the conversion a lot! # maybe use https://raw.githubusercontent.com/scardine/image_size/master/get_image_size.py -def extract_dims(path): return Image.open(path).size[::-1] +def extract_dims(path): + return Image.open(path).size[::-1] -def export_to_coco(class_map, annotations, image_list, dataset_path, output_path, classes=MLPERF_CLASSES): - output_path.parent.mkdir(parents=True, exist_ok=True) - cats = [{"id": i, "name": c, "supercategory": None} for i, c in enumerate(classes)] - categories_map = pd.DataFrame([(i, c) for i, c in enumerate(classes)], columns=["category_id", "category_name"]) - class_map = class_map.merge(categories_map, left_on="DisplayName", right_on="category_name", how="inner") - annotations = annotations[np.isin(annotations["ImageID"], image_list)] - annotations = annotations.merge(class_map, on="LabelName", how="inner") - annotations["image_id"] = pd.factorize(annotations["ImageID"].tolist())[0] - annotations[["height", "width"]] = annotations.apply(lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"), axis=1, result_type="expand") - # Images - imgs = [{"id": int(id + 1), "file_name": f"{image_id}.jpg", "height": row["height"], "width": row["width"], "license": None, "coco_url": None} - for (id, image_id), row in (annotations.groupby(["image_id", "ImageID"]).first().iterrows()) - ] +def export_to_coco( + class_map, + annotations, + image_list, + dataset_path, + output_path, + classes=MLPERF_CLASSES, +): + output_path.parent.mkdir(parents=True, exist_ok=True) + cats = [{"id": i, "name": c, "supercategory": None} for i, c in enumerate(classes)] + categories_map = pd.DataFrame( + [(i, c) for i, c in enumerate(classes)], + columns=["category_id", "category_name"], + ) + class_map = class_map.merge( + categories_map, left_on="DisplayName", right_on="category_name", how="inner" + ) + annotations = annotations[np.isin(annotations["ImageID"], image_list)] + annotations = annotations.merge(class_map, on="LabelName", how="inner") + annotations["image_id"] = pd.factorize(annotations["ImageID"].tolist())[0] + annotations[["height", "width"]] = annotations.apply( + lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"), + axis=1, + result_type="expand", + ) - # Annotations - annots = [] - for i, row in annotations.iterrows(): - xmin, ymin, xmax, ymax, img_w, img_h = [row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"]] - x, y, w, h = xmin * img_w, ymin * img_h, (xmax - xmin) * img_w, (ymax - ymin) * img_h - coco_annot = {"id": int(i) + 1, "image_id": int(row["image_id"] + 1), "category_id": int(row["category_id"]), "bbox": [x, y, w, h], "area": w * h} - coco_annot.update({k: row[k] for k in ["IsOccluded", "IsInside", "IsDepiction", "IsTruncated", "IsGroupOf"]}) - coco_annot["iscrowd"] = int(row["IsGroupOf"]) - annots.append(coco_annot) + # Images + imgs = [ + { + "id": int(id + 1), + "file_name": f"{image_id}.jpg", + "height": row["height"], + "width": row["width"], + "license": None, + "coco_url": None, + } + for (id, image_id), row in ( + annotations.groupby(["image_id", "ImageID"]).first().iterrows() + ) + ] + + # Annotations + annots = [] + for i, row in annotations.iterrows(): + xmin, ymin, xmax, ymax, img_w, img_h = [ + row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"] + ] + x, y, w, h = ( + xmin * img_w, + ymin * img_h, + (xmax - xmin) * img_w, + (ymax - ymin) * img_h, + ) + coco_annot = { + "id": int(i) + 1, + "image_id": int(row["image_id"] + 1), + "category_id": int(row["category_id"]), + "bbox": [x, y, w, h], + "area": w * h, + } + coco_annot.update( + { + k: row[k] + for k in [ + "IsOccluded", + "IsInside", + "IsDepiction", + "IsTruncated", + "IsGroupOf", + ] + } + ) + coco_annot["iscrowd"] = int(row["IsGroupOf"]) + annots.append(coco_annot) + + info = {"dataset": "openimages_mlperf", "version": "v6"} + coco_annotations = { + "info": info, + "licenses": [], + "categories": cats, + "images": imgs, + "annotations": annots, + } + with open(output_path, "w") as fp: + json.dump(coco_annotations, fp) - info = {"dataset": "openimages_mlperf", "version": "v6"} - coco_annotations = {"info": info, "licenses": [], "categories": cats, "images": imgs, "annotations": annots} - with open(output_path, "w") as fp: - json.dump(coco_annotations, fp) def get_image_list(class_map, annotations, classes=MLPERF_CLASSES): - labels = class_map[np.isin(class_map["DisplayName"], classes)]["LabelName"] - image_ids = annotations[np.isin(annotations["LabelName"], labels)]["ImageID"].unique() - return image_ids + labels = class_map[np.isin(class_map["DisplayName"], classes)]["LabelName"] + image_ids = annotations[np.isin(annotations["LabelName"], labels)][ + "ImageID" + ].unique() + return image_ids + def download_image(bucket, image_id, data_dir): - try: - bucket.download_file(f"validation/{image_id}.jpg", f"{data_dir}/{image_id}.jpg") - except botocore.exceptions.ClientError as exception: - sys.exit(f"ERROR when downloading image `validation/{image_id}`: {str(exception)}") + try: + bucket.download_file(f"validation/{image_id}.jpg", f"{data_dir}/{image_id}.jpg") + except botocore.exceptions.ClientError as exception: + sys.exit( + f"ERROR when downloading image `validation/{image_id}`: {str(exception)}" + ) + def fetch_openimages(output_fn): - bucket = boto3.resource("s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)).Bucket(BUCKET_NAME) + bucket = boto3.resource( + "s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED) + ).Bucket(BUCKET_NAME) - annotations_dir, data_dir = BASEDIR / "annotations", BASEDIR / "validation/data" - annotations_dir.mkdir(parents=True, exist_ok=True) - data_dir.mkdir(parents=True, exist_ok=True) + annotations_dir, data_dir = BASEDIR / "annotations", BASEDIR / "validation/data" + annotations_dir.mkdir(parents=True, exist_ok=True) + data_dir.mkdir(parents=True, exist_ok=True) - annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split('/')[-1] - fetch(BBOX_ANNOTATIONS_URL, annotations_fn) - annotations = pd.read_csv(annotations_fn) + annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split("/")[-1] + fetch(BBOX_ANNOTATIONS_URL, annotations_fn) + annotations = pd.read_csv(annotations_fn) - classmap_fn = annotations_dir / MAP_CLASSES_URL.split('/')[-1] - fetch(MAP_CLASSES_URL, classmap_fn) - class_map = pd.read_csv(classmap_fn, names=["LabelName", "DisplayName"]) + classmap_fn = annotations_dir / MAP_CLASSES_URL.split("/")[-1] + fetch(MAP_CLASSES_URL, classmap_fn) + class_map = pd.read_csv(classmap_fn, names=["LabelName", "DisplayName"]) - image_list = get_image_list(class_map, annotations) + image_list = get_image_list(class_map, annotations) - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(download_image, bucket, image_id, data_dir) for image_id in image_list] - for future in (t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list))): - t.set_description(f"Downloading images") - future.result() + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(download_image, bucket, image_id, data_dir) + for image_id in image_list + ] + for future in ( + t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list)) + ): + t.set_description(f"Downloading images") + future.result() + + print("Converting annotations to COCO format...") + export_to_coco(class_map, annotations, image_list, data_dir, output_fn) - print("Converting annotations to COCO format...") - export_to_coco(class_map, annotations, image_list, data_dir, output_fn) def image_load(fn): - img_folder = BASEDIR / "validation/data" - img = Image.open(img_folder / fn).convert('RGB') - import torchvision.transforms.functional as F - ret = F.resize(img, size=(800, 800)) - ret = np.array(ret) - return ret, img.size[::-1] + img_folder = BASEDIR / "validation/data" + img = Image.open(img_folder / fn).convert("RGB") + import torchvision.transforms.functional as F + + ret = F.resize(img, size=(800, 800)) + ret = np.array(ret) + return ret, img.size[::-1] + def prepare_target(annotations, img_id, img_size): - boxes = [annot["bbox"] for annot in annotations] - boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) - boxes[:, 2:] += boxes[:, :2] - boxes[:, 0::2] = boxes[:, 0::2].clip(0, img_size[1]) - boxes[:, 1::2] = boxes[:, 1::2].clip(0, img_size[0]) - keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) - boxes = boxes[keep] - classes = [annot["category_id"] for annot in annotations] - classes = np.array(classes, dtype=np.int64) - classes = classes[keep] - return {"boxes": boxes, "labels": classes, "image_id": img_id, "image_size": img_size} + boxes = [annot["bbox"] for annot in annotations] + boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(0, img_size[1]) + boxes[:, 1::2] = boxes[:, 1::2].clip(0, img_size[0]) + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = [annot["category_id"] for annot in annotations] + classes = np.array(classes, dtype=np.int64) + classes = classes[keep] + return { + "boxes": boxes, + "labels": classes, + "image_id": img_id, + "image_size": img_size, + } + def iterate(coco, bs=8): - image_ids = sorted(coco.imgs.keys()) - for i in range(0, len(image_ids), bs): - X, targets = [], [] - for img_id in image_ids[i:i+bs]: - x, original_size = image_load(coco.loadImgs(img_id)[0]["file_name"]) - X.append(x) - annotations = coco.loadAnns(coco.getAnnIds(img_id)) - targets.append(prepare_target(annotations, img_id, original_size)) - yield np.array(X), targets + image_ids = sorted(coco.imgs.keys()) + for i in range(0, len(image_ids), bs): + X, targets = [], [] + for img_id in image_ids[i : i + bs]: + x, original_size = image_load(coco.loadImgs(img_id)[0]["file_name"]) + X.append(x) + annotations = coco.loadAnns(coco.getAnnIds(img_id)) + targets.append(prepare_target(annotations, img_id, original_size)) + yield np.array(X), targets diff --git a/extra/datasets/preprocess_imagenet.py b/extra/datasets/preprocess_imagenet.py index 692539794..735eaf311 100644 --- a/extra/datasets/preprocess_imagenet.py +++ b/extra/datasets/preprocess_imagenet.py @@ -3,20 +3,25 @@ from tinygrad.tensor import Tensor from extra.datasets.imagenet import iterate, get_val_files if __name__ == "__main__": - #sz = len(get_val_files()) - sz = 32*100 - X,Y = None, None + # sz = len(get_val_files()) + sz = 32 * 100 + X, Y = None, None - idx = 0 - for x,y in iterate(shuffle=False): - print(x.shape, y.shape, x.dtype, y.dtype) - assert x.shape[0] == y.shape[0] - bs = x.shape[0] - if X is None: - X = Tensor.empty(sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8) - Y = Tensor.empty(sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64) - print(X.shape, Y.shape) - X[idx:idx+bs].assign(x) - Y[idx:idx+bs].assign(y) - idx += bs - if idx >= sz: break + idx = 0 + for x, y in iterate(shuffle=False): + print(x.shape, y.shape, x.dtype, y.dtype) + assert x.shape[0] == y.shape[0] + bs = x.shape[0] + if X is None: + X = Tensor.empty( + sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8 + ) + Y = Tensor.empty( + sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64 + ) + print(X.shape, Y.shape) + X[idx : idx + bs].assign(x) + Y[idx : idx + bs].assign(y) + idx += bs + if idx >= sz: + break diff --git a/extra/datasets/squad.py b/extra/datasets/squad.py index 7be38bff4..86e236011 100644 --- a/extra/datasets/squad.py +++ b/extra/datasets/squad.py @@ -6,143 +6,164 @@ import numpy as np from tinygrad.helpers import fetch BASEDIR = Path(__file__).parent / "squad" + + def init_dataset(): - os.makedirs(BASEDIR, exist_ok=True) - fetch("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", BASEDIR / "dev-v1.1.json") - with open(BASEDIR / "dev-v1.1.json") as f: - data = json.load(f)["data"] + os.makedirs(BASEDIR, exist_ok=True) + fetch( + "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", + BASEDIR / "dev-v1.1.json", + ) + with open(BASEDIR / "dev-v1.1.json") as f: + data = json.load(f)["data"] - examples = [] - for article in data: - for paragraph in article["paragraphs"]: - text = paragraph["context"] - doc_tokens = [] - prev_is_whitespace = True - for c in text: - if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: - prev_is_whitespace = True - else: - if prev_is_whitespace: - doc_tokens.append(c) - else: - doc_tokens[-1] += c - prev_is_whitespace = False + examples = [] + for article in data: + for paragraph in article["paragraphs"]: + text = paragraph["context"] + doc_tokens = [] + prev_is_whitespace = True + for c in text: + if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: + prev_is_whitespace = True + else: + if prev_is_whitespace: + doc_tokens.append(c) + else: + doc_tokens[-1] += c + prev_is_whitespace = False - for qa in paragraph["qas"]: - qa_id = qa["id"] - q_text = qa["question"] + for qa in paragraph["qas"]: + qa_id = qa["id"] + q_text = qa["question"] + + examples.append( + { + "id": qa_id, + "question": q_text, + "context": doc_tokens, + "answers": list(map(lambda x: x["text"], qa["answers"])), + } + ) + return examples - examples.append({ - "id": qa_id, - "question": q_text, - "context": doc_tokens, - "answers": list(map(lambda x: x["text"], qa["answers"])) - }) - return examples def _check_is_max_context(doc_spans, cur_span_index, position): - best_score, best_span_index = None, None - for di, (doc_start, doc_length) in enumerate(doc_spans): - end = doc_start + doc_length - 1 - if position < doc_start: - continue - if position > end: - continue - num_left_context = position - doc_start - num_right_context = end - position - score = min(num_left_context, num_right_context) + 0.01 * doc_length - if best_score is None or score > best_score: - best_score = score - best_span_index = di - return cur_span_index == best_span_index + best_score, best_span_index = None, None + for di, (doc_start, doc_length) in enumerate(doc_spans): + end = doc_start + doc_length - 1 + if position < doc_start: + continue + if position > end: + continue + num_left_context = position - doc_start + num_right_context = end - position + score = min(num_left_context, num_right_context) + 0.01 * doc_length + if best_score is None or score > best_score: + best_score = score + best_span_index = di + return cur_span_index == best_span_index + def convert_example_to_features(example, tokenizer): - query_tokens = tokenizer.tokenize(example["question"]) + query_tokens = tokenizer.tokenize(example["question"]) - if len(query_tokens) > 64: - query_tokens = query_tokens[:64] + if len(query_tokens) > 64: + query_tokens = query_tokens[:64] - tok_to_orig_index = [] - orig_to_tok_index = [] - all_doc_tokens = [] - for i, token in enumerate(example["context"]): - orig_to_tok_index.append(len(all_doc_tokens)) - sub_tokens = tokenizer.tokenize(token) - for sub_token in sub_tokens: - tok_to_orig_index.append(i) - all_doc_tokens.append(sub_token) + tok_to_orig_index = [] + orig_to_tok_index = [] + all_doc_tokens = [] + for i, token in enumerate(example["context"]): + orig_to_tok_index.append(len(all_doc_tokens)) + sub_tokens = tokenizer.tokenize(token) + for sub_token in sub_tokens: + tok_to_orig_index.append(i) + all_doc_tokens.append(sub_token) - max_tokens_for_doc = 384 - len(query_tokens) - 3 + max_tokens_for_doc = 384 - len(query_tokens) - 3 - doc_spans = [] - start_offset = 0 - while start_offset < len(all_doc_tokens): - length = len(all_doc_tokens) - start_offset - length = min(length, max_tokens_for_doc) - doc_spans.append((start_offset, length)) - if start_offset + length == len(all_doc_tokens): - break - start_offset += min(length, 128) + doc_spans = [] + start_offset = 0 + while start_offset < len(all_doc_tokens): + length = len(all_doc_tokens) - start_offset + length = min(length, max_tokens_for_doc) + doc_spans.append((start_offset, length)) + if start_offset + length == len(all_doc_tokens): + break + start_offset += min(length, 128) - outputs = [] - for di, (doc_start, doc_length) in enumerate(doc_spans): - tokens = [] - token_to_orig_map = {} - token_is_max_context = {} - segment_ids = [] - tokens.append("[CLS]") - segment_ids.append(0) - for token in query_tokens: - tokens.append(token) - segment_ids.append(0) - tokens.append("[SEP]") - segment_ids.append(0) + outputs = [] + for di, (doc_start, doc_length) in enumerate(doc_spans): + tokens = [] + token_to_orig_map = {} + token_is_max_context = {} + segment_ids = [] + tokens.append("[CLS]") + segment_ids.append(0) + for token in query_tokens: + tokens.append(token) + segment_ids.append(0) + tokens.append("[SEP]") + segment_ids.append(0) - for i in range(doc_length): - split_token_index = doc_start + i - token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] - token_is_max_context[len(tokens)] = _check_is_max_context(doc_spans, di, split_token_index) - tokens.append(all_doc_tokens[split_token_index]) - segment_ids.append(1) - tokens.append("[SEP]") - segment_ids.append(1) + for i in range(doc_length): + split_token_index = doc_start + i + token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] + token_is_max_context[len(tokens)] = _check_is_max_context( + doc_spans, di, split_token_index + ) + tokens.append(all_doc_tokens[split_token_index]) + segment_ids.append(1) + tokens.append("[SEP]") + segment_ids.append(1) - input_ids = tokenizer.convert_tokens_to_ids(tokens) - input_mask = [1] * len(input_ids) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + input_mask = [1] * len(input_ids) - while len(input_ids) < 384: - input_ids.append(0) - input_mask.append(0) - segment_ids.append(0) + while len(input_ids) < 384: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) - assert len(input_ids) == 384 - assert len(input_mask) == 384 - assert len(segment_ids) == 384 + assert len(input_ids) == 384 + assert len(input_mask) == 384 + assert len(segment_ids) == 384 - outputs.append({ - "input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32), - "input_mask": np.expand_dims(np.array(input_mask), 0).astype(np.float32), - "segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(np.float32), - "token_to_orig_map": token_to_orig_map, - "token_is_max_context": token_is_max_context, - "tokens": tokens, - }) + outputs.append( + { + "input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32), + "input_mask": np.expand_dims(np.array(input_mask), 0).astype( + np.float32 + ), + "segment_ids": np.expand_dims(np.array(segment_ids), 0).astype( + np.float32 + ), + "token_to_orig_map": token_to_orig_map, + "token_is_max_context": token_is_max_context, + "tokens": tokens, + } + ) + + return outputs - return outputs def iterate(tokenizer, start=0): - examples = init_dataset() - print(f"there are {len(examples)} pairs in the dataset") + examples = init_dataset() + print(f"there are {len(examples)} pairs in the dataset") + + for i in range(start, len(examples)): + example = examples[i] + features = convert_example_to_features(example, tokenizer) + # we need to yield all features here as the f1 score is the maximum over all features + yield features, example - for i in range(start, len(examples)): - example = examples[i] - features = convert_example_to_features(example, tokenizer) - # we need to yield all features here as the f1 score is the maximum over all features - yield features, example if __name__ == "__main__": - tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt")) + tokenizer = BertTokenizer( + str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt") + ) - X, Y = next(iterate(tokenizer)) - print(" ".join(X[0]["tokens"])) - print(X[0]["input_ids"].shape, Y) + X, Y = next(iterate(tokenizer)) + print(" ".join(X[0]["tokens"])) + print(X[0]["input_ids"].shape, Y) diff --git a/extra/dist/__init__.py b/extra/dist/__init__.py index e25e58c1e..949c5f7f9 100644 --- a/extra/dist/__init__.py +++ b/extra/dist/__init__.py @@ -5,56 +5,70 @@ from tinygrad.helpers import DEBUG, getenv import multiprocessing as mp import os + # this needs to be called before everything else if you are using distributed def preinit(): - os.environ["DELAYED_RUNTIME_INIT"] = "1" - mp.set_start_method("spawn") + os.environ["DELAYED_RUNTIME_INIT"] = "1" + mp.set_start_method("spawn") + # out-of-band communication/synchronization class _OOB: - def __init__(self, pipes:List[Tuple[Connection, Connection]]): - self.pipes = pipes + def __init__(self, pipes: List[Tuple[Connection, Connection]]): + self.pipes = pipes + + # send some data to a target rank, blocks until data is received + def send(self, data: Any, target_rank: int): + self.pipes[getenv("RANK") * getenv("WORLD_SIZE") + target_rank][1].send(data) + + # receive some data from a target rank, blocks until data is received + def recv(self, target_rank: int) -> Any: + return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv() - # send some data to a target rank, blocks until data is received - def send(self, data:Any, target_rank:int): - self.pipes[getenv("RANK") * getenv("WORLD_SIZE") + target_rank][1].send(data) - # receive some data from a target rank, blocks until data is received - def recv(self, target_rank:int) -> Any: - return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv() OOB: Optional[_OOB] = None -def init_oob(world_size:int): - os.environ["WORLD_SIZE"] = str(world_size) - global OOB - OOB = _OOB([mp.Pipe(False) for _ in range(world_size * world_size)]) +def init_oob(world_size: int): + os.environ["WORLD_SIZE"] = str(world_size) + + global OOB + OOB = _OOB([mp.Pipe(False) for _ in range(world_size * world_size)]) + # this runs in the spawned process so we can do all the delayed runtime initialization -def _process_wrap(rank:int, device:str, oob:_OOB, fn:Callable, args=()): - # setup the rank - os.environ["RANK"] = str(rank) +def _process_wrap(rank: int, device: str, oob: _OOB, fn: Callable, args=()): + # setup the rank + os.environ["RANK"] = str(rank) - # setup out of band communication - global OOB - OOB = oob + # setup out of band communication + global OOB + OOB = oob - # do specific runtime initialization for distributed - from tinygrad import Device - device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(device.split(":")[-1]) - if "GPU" in device: - from tinygrad.runtime.ops_gpu import CL - CL.post_init(device_num) - elif "HIP" in device: - os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str(device_num) - if DEBUG >= 1: print(f"distributed process {rank} initialized runtime for device {device}") + # do specific runtime initialization for distributed + from tinygrad import Device - # convert device to be process specific - Device.DEFAULT = device.split(":")[0] if "GPU" in device else device + device, device_num = Device.canonicalize(device), 0 if ":" not in device else int( + device.split(":")[-1] + ) + if "GPU" in device: + from tinygrad.runtime.ops_gpu import CL + + CL.post_init(device_num) + elif "HIP" in device: + os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str( + device_num + ) + if DEBUG >= 1: + print(f"distributed process {rank} initialized runtime for device {device}") + + # convert device to be process specific + Device.DEFAULT = device.split(":")[0] if "GPU" in device else device + + fn(*args) - fn(*args) # wrapper around mp.Process that initializes the runtime -def spawn(rank:int, device:str, fn:Callable, args=()) -> mp.Process: - (p := mp.Process(target=_process_wrap, args=(rank, device, OOB, fn, args))).start() - return p +def spawn(rank: int, device: str, fn: Callable, args=()) -> mp.Process: + (p := mp.Process(target=_process_wrap, args=(rank, device, OOB, fn, args))).start() + return p diff --git a/extra/dist/collectives.py b/extra/dist/collectives.py index f30da34b3..0139d56c6 100644 --- a/extra/dist/collectives.py +++ b/extra/dist/collectives.py @@ -3,38 +3,41 @@ from tinygrad.helpers import getenv from extra.dist import world -def allreduce(t:Tensor) -> Tensor: - RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE") - # flatten - flattened = t.flatten() +def allreduce(t: Tensor) -> Tensor: + RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE") - # pad to evenly divide - if flattened.shape[0] % WORLD_SIZE != 0: - flattened = Tensor.cat(flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE))) + # flatten + flattened = t.flatten() - # chunk - chunks = flattened.chunk(WORLD_SIZE, dim=0) + # pad to evenly divide + if flattened.shape[0] % WORLD_SIZE != 0: + flattened = Tensor.cat( + flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)) + ) - next_rank = (RANK + 1) % WORLD_SIZE - prev_rank = ((RANK - 1) + WORLD_SIZE) % WORLD_SIZE + # chunk + chunks = flattened.chunk(WORLD_SIZE, dim=0) - # scatter reduce - current_chunk_index = RANK - for _ in range(WORLD_SIZE - 1): - world.send(chunks[current_chunk_index], next_rank) - current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE - recv_buf = Tensor.empty(*chunks[current_chunk_index].shape) - world.recv(recv_buf, prev_rank) - chunks[current_chunk_index] += recv_buf + next_rank = (RANK + 1) % WORLD_SIZE + prev_rank = ((RANK - 1) + WORLD_SIZE) % WORLD_SIZE - # gather - current_chunk_index = (RANK + 1) % WORLD_SIZE - for _ in range(WORLD_SIZE - 1): - world.send(chunks[current_chunk_index], next_rank) - current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE - recv_buf = Tensor.empty(*chunks[current_chunk_index].shape) - world.recv(recv_buf, prev_rank) - chunks[current_chunk_index].assign(recv_buf) + # scatter reduce + current_chunk_index = RANK + for _ in range(WORLD_SIZE - 1): + world.send(chunks[current_chunk_index], next_rank) + current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE + recv_buf = Tensor.empty(*chunks[current_chunk_index].shape) + world.recv(recv_buf, prev_rank) + chunks[current_chunk_index] += recv_buf - return Tensor.cat(*chunks, dim=0).shrink(((0, t.numel()),)).reshape(*t.shape) + # gather + current_chunk_index = (RANK + 1) % WORLD_SIZE + for _ in range(WORLD_SIZE - 1): + world.send(chunks[current_chunk_index], next_rank) + current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE + recv_buf = Tensor.empty(*chunks[current_chunk_index].shape) + world.recv(recv_buf, prev_rank) + chunks[current_chunk_index].assign(recv_buf) + + return Tensor.cat(*chunks, dim=0).shrink(((0, t.numel()),)).reshape(*t.shape) diff --git a/extra/dist/world.py b/extra/dist/world.py index 476cf4b29..08e87df06 100644 --- a/extra/dist/world.py +++ b/extra/dist/world.py @@ -4,111 +4,154 @@ from multiprocessing import shared_memory from tinygrad.helpers import DEBUG, colored, getenv from tinygrad.lazy import LazyBuffer from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut + try: - import gpuctypes.hip as hip - from tinygrad.runtime.ops_hip import RawHIPBuffer, check -except: RawHIPBuffer = None + import gpuctypes.hip as hip + from tinygrad.runtime.ops_hip import RawHIPBuffer, check +except: + RawHIPBuffer = None from tinygrad.runtime.ops_disk import RawDiskBuffer from tinygrad.jit import CacheCollector from tinygrad.tensor import Tensor, Function import numpy as np + # match the function signature of JITRunner so we can put it in the cache def __send_rb(args, variables=None, wait=False, jit=False): - x, target_rank, y = args[:3] - if RawHIPBuffer and x.__class__ is RawHIPBuffer: - check(hip.hipSetDevice(x._device)) - check(hip.hipDeviceSynchronize()) - else: - if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np)) - else: y.fromCPU(x.toCPU()) - dist.OOB.send(None, target_rank) - if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}") + x, target_rank, y = args[:3] + if RawHIPBuffer and x.__class__ is RawHIPBuffer: + check(hip.hipSetDevice(x._device)) + check(hip.hipDeviceSynchronize()) + else: + if isinstance(x, RawBufferCopyInOut): + x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np)) + else: + y.fromCPU(x.toCPU()) + dist.OOB.send(None, target_rank) + if DEBUG >= 2: + print( + f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}" + ) + def __recv_rb(args, variables=None, wait=False, jit=False): - x, target_rank, y = args[:3] - dist.OOB.recv(target_rank) - if RawHIPBuffer and x.__class__ is RawHIPBuffer: - x._transfer(y) - elif isinstance(x, RawBuffer): x._copyin(y.toCPU()) - else: x.fromCPU(y.toCPU()) - if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} recv {x} from rank {target_rank}") + x, target_rank, y = args[:3] + dist.OOB.recv(target_rank) + if RawHIPBuffer and x.__class__ is RawHIPBuffer: + x._transfer(y) + elif isinstance(x, RawBuffer): + x._copyin(y.toCPU()) + else: + x.fromCPU(y.toCPU()) + if DEBUG >= 2: + print( + f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} recv {x} from rank {target_rank}" + ) + # send a rawbuffer from out rank to the target rank -def _send_rb(x:RawBuffer, target_rank:int): - if RawHIPBuffer and x.__class__ is RawHIPBuffer: - # send ipc handle - check(hip.hipSetDevice(x._device)) - check(hip.hipDeviceSynchronize()) - check(hip.hipIpcGetMemHandle(ctypes.byval(handle := hip.hipIpcMemHandle_t()), x._buf)) - dist.OOB.send((handle, x._device), target_rank) +def _send_rb(x: RawBuffer, target_rank: int): + if RawHIPBuffer and x.__class__ is RawHIPBuffer: + # send ipc handle + check(hip.hipSetDevice(x._device)) + check(hip.hipDeviceSynchronize()) + check( + hip.hipIpcGetMemHandle( + ctypes.byval(handle := hip.hipIpcMemHandle_t()), x._buf + ) + ) + dist.OOB.send((handle, x._device), target_rank) - # jit support - x._allocator = None # need to disconnect allocator for sent buffers - CacheCollector.add(__send_rb, [x, target_rank, None], {}) - else: - # create shared memory - shm_name = (s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)).name - s.close() + # jit support + x._allocator = None # need to disconnect allocator for sent buffers + CacheCollector.add(__send_rb, [x, target_rank, None], {}) + else: + # create shared memory + shm_name = ( + s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize) + ).name + s.close() - # copy the buffer into shared memory - y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:"+shm_name) - # fast path when we can directly copyout - if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np)) - else: y.fromCPU(x.toCPU()) + # copy the buffer into shared memory + y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:" + shm_name) + # fast path when we can directly copyout + if isinstance(x, RawBufferCopyInOut): + x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np)) + else: + y.fromCPU(x.toCPU()) - dist.OOB.send(shm_name, target_rank) + dist.OOB.send(shm_name, target_rank) + + # jit support + CacheCollector.add(__send_rb, [x, target_rank, y], {}) + if DEBUG >= 2: + print(f"**** rank {getenv('RANK')} sent {x} to rank {target_rank}") - # jit support - CacheCollector.add(__send_rb, [x, target_rank, y], {}) - if DEBUG >= 2: print(f"**** rank {getenv('RANK')} sent {x} to rank {target_rank}") # receive a rawbuffer from the target rank -def _recv_rb(x:RawBuffer, target_rank:int): - if RawHIPBuffer and isinstance(x, RawHIPBuffer): - # open ipc handle - handle, y_device = dist.OOB.recv(target_rank) - check(hip.hipSetDevice(y_device)) - check(hip.hipIpcOpenMemHandle(ctypes.byval(ptr := ctypes.c_void_p()), handle, 0)) +def _recv_rb(x: RawBuffer, target_rank: int): + if RawHIPBuffer and isinstance(x, RawHIPBuffer): + # open ipc handle + handle, y_device = dist.OOB.recv(target_rank) + check(hip.hipSetDevice(y_device)) + check( + hip.hipIpcOpenMemHandle(ctypes.byval(ptr := ctypes.c_void_p()), handle, 0) + ) - # build a new buffer - y = RawHIPBuffer(x.size, x.dtype, device=str(y_device), buf=ptr, allocator=None) - x._transfer(y) + # build a new buffer + y = RawHIPBuffer(x.size, x.dtype, device=str(y_device), buf=ptr, allocator=None) + x._transfer(y) - CacheCollector.add(__recv_rb, [x, target_rank, y], {}) - else: - shm_name = dist.OOB.recv(target_rank) - y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:"+shm_name) + CacheCollector.add(__recv_rb, [x, target_rank, y], {}) + else: + shm_name = dist.OOB.recv(target_rank) + y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:" + shm_name) - # fast path when we can directly copyin - if isinstance(x, RawBuffer): x._copyin(y.toCPU()) - else: x.fromCPU(y.toCPU()) + # fast path when we can directly copyin + if isinstance(x, RawBuffer): + x._copyin(y.toCPU()) + else: + x.fromCPU(y.toCPU()) + + # jit support + CacheCollector.add(__recv_rb, [x, target_rank, y], {}) + if DEBUG >= 2: + print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}") - # jit support - CacheCollector.add(__recv_rb, [x, target_rank, y], {}) - if DEBUG >= 2: print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}") # sends a lazybuffer from our rank to the target rank -def _send_lb(x:LazyBuffer, target_rank:int) -> None: - assert x.st.contiguous and x.realized, "sending buffer must be contiguous and realized" - _send_rb(x.realized, target_rank) +def _send_lb(x: LazyBuffer, target_rank: int) -> None: + assert ( + x.st.contiguous and x.realized + ), "sending buffer must be contiguous and realized" + _send_rb(x.realized, target_rank) + # receive a lazybuffer from the target rank -def _recv_lb(x:LazyBuffer, target_rank:int) -> LazyBuffer: - assert x.st.contiguous and x.realized, "receiving buffer must be contiguous and realized" - _recv_rb(x.realized, target_rank) - return x - -class Send(Function): - def forward(self, x:LazyBuffer, target_rank:int) -> LazyBuffer: - self.target_rank, self.shape, self.dtype = target_rank, x.shape, x.dtype - _send_lb(x, target_rank) +def _recv_lb(x: LazyBuffer, target_rank: int) -> LazyBuffer: + assert ( + x.st.contiguous and x.realized + ), "receiving buffer must be contiguous and realized" + _recv_rb(x.realized, target_rank) return x -class Recv(Function): - def forward(self, x:LazyBuffer, target_rank:int) -> LazyBuffer: - self.target_rank = target_rank - return _recv_lb(x, target_rank) -def send(x:Tensor, target_rank:int) -> Tensor: return Send.apply(x.contiguous().realize(), target_rank=target_rank) -def recv(x:Tensor, target_rank:int) -> Tensor: return Recv.apply(x.contiguous().realize(), target_rank=target_rank) +class Send(Function): + def forward(self, x: LazyBuffer, target_rank: int) -> LazyBuffer: + self.target_rank, self.shape, self.dtype = target_rank, x.shape, x.dtype + _send_lb(x, target_rank) + return x + + +class Recv(Function): + def forward(self, x: LazyBuffer, target_rank: int) -> LazyBuffer: + self.target_rank = target_rank + return _recv_lb(x, target_rank) + + +def send(x: Tensor, target_rank: int) -> Tensor: + return Send.apply(x.contiguous().realize(), target_rank=target_rank) + + +def recv(x: Tensor, target_rank: int) -> Tensor: + return Recv.apply(x.contiguous().realize(), target_rank=target_rank) diff --git a/extra/dump_cache.py b/extra/dump_cache.py index 325d2bd22..8f77d496f 100644 --- a/extra/dump_cache.py +++ b/extra/dump_cache.py @@ -2,20 +2,25 @@ import sys, sqlite3, pickle from tinygrad.helpers import CACHEDB if __name__ == "__main__": - fn = sys.argv[1] if len(sys.argv) > 1 else CACHEDB - conn = sqlite3.connect(fn) - cur = conn.cursor() - cur.execute("SELECT name FROM sqlite_master WHERE type='table'") - for f in cur.fetchall(): - table = f[0] - cur2 = conn.cursor() - cur2.execute(f"SELECT COUNT(*) FROM {table}") - cnt = cur2.fetchone()[0] - print(f"{table:20s} : {cnt}") + fn = sys.argv[1] if len(sys.argv) > 1 else CACHEDB + conn = sqlite3.connect(fn) + cur = conn.cursor() + cur.execute("SELECT name FROM sqlite_master WHERE type='table'") + for f in cur.fetchall(): + table = f[0] + cur2 = conn.cursor() + cur2.execute(f"SELECT COUNT(*) FROM {table}") + cnt = cur2.fetchone()[0] + print(f"{table:20s} : {cnt}") - cur3 = conn.cursor() - cur3.execute(f"SELECT * FROM {table} LIMIT 10") - for f in cur3.fetchall(): - v = pickle.loads(f[-1]) - print(" ", len(f[0]) if isinstance(f[0], str) else f[0], f[1:-1], str(v)[0:50]) - #print(f"{len(k):10d}, {sk} -> {v}") + cur3 = conn.cursor() + cur3.execute(f"SELECT * FROM {table} LIMIT 10") + for f in cur3.fetchall(): + v = pickle.loads(f[-1]) + print( + " ", + len(f[0]) if isinstance(f[0], str) else f[0], + f[1:-1], + str(v)[0:50], + ) + # print(f"{len(k):10d}, {sk} -> {v}") diff --git a/extra/export_model.py b/extra/export_model.py index 448c41a91..23f26aafc 100644 --- a/extra/export_model.py +++ b/extra/export_model.py @@ -7,77 +7,190 @@ import json EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"] -def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]: - functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0 - for ji in run.jit_cache: - fxn = ji.prg - functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same - cargs = [] - for i,arg in enumerate(ji.rawbufs): - key = id(arg) - if key not in bufs: - if key in special_names: - bufs[key] = (special_names[key], arg.size*arg.dtype.itemsize, arg.dtype, key) - else: - bufs[key] = (f"buf_{bufnum}", arg.size*arg.dtype.itemsize, arg.dtype, key) - bufnum += 1 - if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name - cargs.append(bufs[key][0]) - statements.append((fxn.name, cargs, fxn.global_size, fxn.local_size)) - return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save +def compile_net( + run: TinyJit, special_names: Dict[int, str] +) -> Tuple[ + Dict[str, str], + List[Tuple[str, List[str], List[int]]], + Dict[str, Tuple[int, DType, int]], + Dict[str, Tensor], +]: + functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0 + for ji in run.jit_cache: + fxn = ji.prg + functions[ + fxn.name + ] = fxn.prg # NOTE: this assumes all with the same name are the same + cargs = [] + for i, arg in enumerate(ji.rawbufs): + key = id(arg) + if key not in bufs: + if key in special_names: + bufs[key] = ( + special_names[key], + arg.size * arg.dtype.itemsize, + arg.dtype, + key, + ) + else: + bufs[key] = ( + f"buf_{bufnum}", + arg.size * arg.dtype.itemsize, + arg.dtype, + key, + ) + bufnum += 1 + if i > 0: + bufs_to_save[ + bufs[key][0] + ] = arg # if first usage of a buffer is not an output, and it's not a special name + cargs.append(bufs[key][0]) + statements.append((fxn.name, cargs, fxn.global_size, fxn.local_size)) -def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]: - assert hasattr(model, "forward") or callable(model), "model needs a forward function" - @TinyJit - def run(*x): - out = model.forward(*x) if hasattr(model, "forward") else model(*x) - assert isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor), "model output must be a Tensor, tuple, or a list of Tensors for export" - out = [out] if isinstance(out, Tensor) else out - return [o.realize() for o in out] + return ( + functions, + statements, + {name: (size, dtype, key) for (name, size, dtype, key) in bufs.values()}, + bufs_to_save, + ) - # twice to run the JIT - for _ in range(2): the_output = run(*args) - special_names = {} - # hack to put the inputs back - for (j,i),idx in run.input_replace.items(): - realized_input = args[idx].lazydata.realized - run.jit_cache[j].rawbufs[i] = realized_input - special_names[id(realized_input)] = f'input{idx}' +def jit_model(model, *args) -> Tuple[TinyJit, Dict[int, str]]: + assert hasattr(model, "forward") or callable( + model + ), "model needs a forward function" - # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret) - for i, output in enumerate(the_output): - special_names[id(output.lazydata.realized)] = f'output{i}' - return run, special_names + @TinyJit + def run(*x): + out = model.forward(*x) if hasattr(model, "forward") else model(*x) + assert ( + isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor) + ), "model output must be a Tensor, tuple, or a list of Tensors for export" + out = [out] if isinstance(out, Tensor) else out + return [o.realize() for o in out] -def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str]) -> str: - from tinygrad.runtime.ops_clang import CLANG_PROGRAM_HEADER - cprog = [CLANG_PROGRAM_HEADER] + # twice to run the JIT + for _ in range(2): + the_output = run(*args) + special_names = {} - for name,cl in bufs_to_save.items(): - weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)]) - cprog.append(f"unsigned char {name}_data[] = \"{weight}\";") + # hack to put the inputs back + for (j, i), idx in run.input_replace.items(): + realized_input = args[idx].lazydata.realized + run.jit_cache[j].rawbufs[i] = realized_input + special_names[id(realized_input)] = f"input{idx}" - inputs = ", ".join([f'float* {input}' for input in input_names]) - outputs = ", ".join([f'float* {output}' for output in output_names]) - cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']] - cprog += list(functions.values()) - cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"] - return '\n'.join(cprog) + # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret) + for i, output in enumerate(the_output): + special_names[id(output.lazydata.realized)] = f"output{i}" + return run, special_names -def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]: - kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()]) - kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements]) - kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ]) - _bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()]) - gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)]) - input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)]) - gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)]) - outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)]) - output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))]) - output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))])) - return f""" + +def export_model_clang( + functions: Dict[str, str], + statements: Dict[str, Tuple[str, int, int]], + bufs: Dict[str, Tuple[str, int, int]], + bufs_to_save: Dict[str, Tensor], + input_names: List[str], + output_names: List[str], +) -> str: + from tinygrad.runtime.ops_clang import CLANG_PROGRAM_HEADER + + cprog = [CLANG_PROGRAM_HEADER] + + for name, cl in bufs_to_save.items(): + weight = "".join(["\\x%02X" % x for x in bytes(cl._buf)]) + cprog.append(f'unsigned char {name}_data[] = "{weight}";') + + inputs = ", ".join([f"float* {input}" for input in input_names]) + outputs = ", ".join([f"float* {output}" for output in output_names]) + cprog += [ + f"float {name}[{len}];" + if name not in bufs_to_save + else f"float *{name} = (float *){name}_data;" + for name, (len, dtype, _key) in bufs.items() + if name not in ["input", "outputs"] + ] + cprog += list(functions.values()) + cprog += ( + [f"void net({inputs}, {outputs}) {{"] + + [ + f"{name}({', '.join(args)});" + for (name, args, _global_size, _local_size) in statements + ] + + ["}"] + ) + return "\n".join(cprog) + + +def export_model_webgpu( + functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names +) -> Tuple[str, int, int]: + kernel_code = "\n\n".join( + [ + f"const {key} = `{code.replace(key, 'main')}`;" + for key, code in functions.items() + ] + ) + kernel_names = ", ".join( + [name for (name, _args, _global_size, _local_size) in statements] + ) + kernel_calls = "\n ".join( + [ + f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" + for i, (_name, args, global_size, _local_size) in enumerate(statements) + ] + ) + _bufs = "\n ".join( + [ + f"const {name} = " + + ( + f"createEmptyBuf(device, {size});" + if _key not in weight_names + else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))" + ) + + ";" + for name, (size, dtype, _key) in bufs.items() + ] + ) + gpu_write_bufs = "\n ".join( + [ + f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" + for i, input_name in enumerate(input_names) + ] + ) + input_writers = "\n ".join( + [ + f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + + f"_{inp_name});" + + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" + for i, inp_name in enumerate(input_names) + ] + ) + gpu_read_bufs = "\n ".join( + [ + f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" + for i, output_name in enumerate(output_names) + ] + ) + outbuf_copies = "\n ".join( + [ + f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" + for i, output_name in enumerate(output_names) + ] + ) + output_readers = "\n ".join( + [ + f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" + for i in range(len(output_names)) + ] + ) + output_return = "[{}]".format( + ",".join([f"resultBuffer{i}" for i in range(len(output_names))]) + ) + return ( + f""" const getTensorMetadata = (safetensorBuffer) => {{ const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true)); const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength))); @@ -134,46 +247,73 @@ const setupNet = async (device, safetensor) => {{ return {output_return}; }} }} - """ + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}" + """ + + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}" + ) -def export_model(model, target:str, *inputs): - assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CLANG, CUDA, GPU, METAL are supported" - run,special_names = jit_model(model, *inputs) - functions, statements, bufs, bufs_to_save = compile_net(run, special_names) - state = get_state_dict(model) - weight_names = {id(x.lazydata.realized): name for name, x in state.items()} - input_names = [name for _,name in special_names.items() if "input" in name] - output_names = [name for _,name in special_names.items() if "output" in name] - prg = "" - if target == "clang": - prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names) - elif target == "webgpu": - prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) - else: - prg = json.dumps({ - "backend": Device.DEFAULT, - "inputs": [{ - "size": bufs[name][0], - "dtype": bufs[name][1].name - } for name in input_names], - "outputs": [{ - "size": bufs[name][0], - "dtype": bufs[name][1].name - } for name in output_names], - "functions": functions, - "statements": [{ - "kernel": kernel, - "args": args, - "global_size": global_size, - "local_size": local_size - } for (kernel, args, global_size, local_size) in statements], - "buffers": { - name: { - "size": size, - "dtype": dtype.name, - "id": weight_names[_key] if _key in weight_names else "" - } for name, (size,dtype,_key) in bufs.items() if name not in ["input", "outputs"] - } - }) - return prg, {input:bufs[input][0] for input in input_names}, {output:bufs[output][0] for output in output_names}, state +def export_model(model, target: str, *inputs): + assert ( + Device.DEFAULT in EXPORT_SUPPORTED_DEVICE + ), "only WEBGPU, CLANG, CUDA, GPU, METAL are supported" + run, special_names = jit_model(model, *inputs) + functions, statements, bufs, bufs_to_save = compile_net(run, special_names) + state = get_state_dict(model) + weight_names = {id(x.lazydata.realized): name for name, x in state.items()} + input_names = [name for _, name in special_names.items() if "input" in name] + output_names = [name for _, name in special_names.items() if "output" in name] + prg = "" + if target == "clang": + prg = export_model_clang( + functions, statements, bufs, bufs_to_save, input_names, output_names + ) + elif target == "webgpu": + prg = export_model_webgpu( + functions, + statements, + bufs, + bufs_to_save, + weight_names, + input_names, + output_names, + ) + else: + prg = json.dumps( + { + "backend": Device.DEFAULT, + "inputs": [ + {"size": bufs[name][0], "dtype": bufs[name][1].name} + for name in input_names + ], + "outputs": [ + {"size": bufs[name][0], "dtype": bufs[name][1].name} + for name in output_names + ], + "functions": functions, + "statements": [ + { + "kernel": kernel, + "args": args, + "global_size": global_size, + "local_size": local_size, + } + for (kernel, args, global_size, local_size) in statements + ], + "buffers": { + name: { + "size": size, + "dtype": dtype.name, + "id": weight_names[_key] if _key in weight_names else "", + } + for name, (size, dtype, _key) in bufs.items() + if name not in ["input", "outputs"] + }, + } + ) + + return ( + prg, + {input: bufs[input][0] for input in input_names}, + {output: bufs[output][0] for output in output_names}, + state, + ) diff --git a/extra/gemm/amx.py b/extra/gemm/amx.py index 3058b6639..dc6d5d002 100755 --- a/extra/gemm/amx.py +++ b/extra/gemm/amx.py @@ -2,6 +2,7 @@ import numpy as np import time import sys + np.set_printoptions(linewidth=160) np.set_printoptions(linewidth=1000, threshold=10000000000, suppress=False) from tinygrad.runtime.ops_llvm import LLVM, LLVMBuffer, int_const @@ -11,28 +12,71 @@ from llvmlite import ir # type: ignore # https://github.com/corsix/amx/blob/main/Instructions.md # 12 lines for AMX support from functools import partialmethod + + class AMX: - @staticmethod - def nop_op_imm5(op, imm5, builder): builder.asm(ir.FunctionType(ir.VoidType(), []), f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}", "", tuple(), True) - @staticmethod - def op_gpr(op, builder, gpr): builder.asm(ir.FunctionType(ir.VoidType(), [ir.IntType(64)]), f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0", "r", (gpr,), True) - set, clr = partialmethod(nop_op_imm5, 17, 0), partialmethod(nop_op_imm5, 17, 1) - ldx, ldy, stx, sty = partialmethod(op_gpr, 0), partialmethod(op_gpr, 1), partialmethod(op_gpr, 2), partialmethod(op_gpr, 3) - ldz, stz, ldzi, stzi = partialmethod(op_gpr, 4), partialmethod(op_gpr, 5), partialmethod(op_gpr, 6), partialmethod(op_gpr, 7) - extrx, extry = partialmethod(op_gpr, 8), partialmethod(op_gpr, 9) - fma64, fms64, fma32, fms32 = partialmethod(op_gpr, 10), partialmethod(op_gpr, 11), partialmethod(op_gpr, 12), partialmethod(op_gpr, 13) - mac16, fma16, fms16 = partialmethod(op_gpr, 14), partialmethod(op_gpr, 15), partialmethod(op_gpr, 16) - vecint, vecfp, matint, matfp, genlut = partialmethod(op_gpr, 18), partialmethod(op_gpr, 19), partialmethod(op_gpr, 20), partialmethod(op_gpr, 21), partialmethod(op_gpr, 22) + @staticmethod + def nop_op_imm5(op, imm5, builder): + builder.asm( + ir.FunctionType(ir.VoidType(), []), + f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}", + "", + tuple(), + True, + ) + + @staticmethod + def op_gpr(op, builder, gpr): + builder.asm( + ir.FunctionType(ir.VoidType(), [ir.IntType(64)]), + f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0", + "r", + (gpr,), + True, + ) + + set, clr = partialmethod(nop_op_imm5, 17, 0), partialmethod(nop_op_imm5, 17, 1) + ldx, ldy, stx, sty = ( + partialmethod(op_gpr, 0), + partialmethod(op_gpr, 1), + partialmethod(op_gpr, 2), + partialmethod(op_gpr, 3), + ) + ldz, stz, ldzi, stzi = ( + partialmethod(op_gpr, 4), + partialmethod(op_gpr, 5), + partialmethod(op_gpr, 6), + partialmethod(op_gpr, 7), + ) + extrx, extry = partialmethod(op_gpr, 8), partialmethod(op_gpr, 9) + fma64, fms64, fma32, fms32 = ( + partialmethod(op_gpr, 10), + partialmethod(op_gpr, 11), + partialmethod(op_gpr, 12), + partialmethod(op_gpr, 13), + ) + mac16, fma16, fms16 = ( + partialmethod(op_gpr, 14), + partialmethod(op_gpr, 15), + partialmethod(op_gpr, 16), + ) + vecint, vecfp, matint, matfp, genlut = ( + partialmethod(op_gpr, 18), + partialmethod(op_gpr, 19), + partialmethod(op_gpr, 20), + partialmethod(op_gpr, 21), + partialmethod(op_gpr, 22), + ) N = 4096 -#N = 1024 -#N = 64 +# N = 1024 +# N = 64 -#an = np.arange(N*N).reshape(N, N) - 43*64 -#bn = np.arange(N*N).reshape(N, N) -#an = np.ones((N, N)).astype(np.float32) -#bn = np.ones((N, N)).astype(np.float32) +# an = np.arange(N*N).reshape(N, N) - 43*64 +# bn = np.arange(N*N).reshape(N, N) +# an = np.ones((N, N)).astype(np.float32) +# bn = np.ones((N, N)).astype(np.float32) # matrix is 64M, max load bandwidth is 57 GB/s # cache line looks like 256 bytes (64 floats) @@ -49,12 +93,16 @@ cn = (an.T @ bn).T a = LLVMBuffer.fromCPU(an) b = LLVMBuffer.fromCPU(bn) -#c = LLVMBuffer.fromCPU(np.zeros((N, N))) +# c = LLVMBuffer.fromCPU(np.zeros((N, N))) c = LLVMBuffer.fromCPU(np.zeros(256)) -bufs = [c,a,b] +bufs = [c, a, b] module = ir.Module(name=__file__) -func = ir.Function(module, ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()]*3), name='exec') +func = ir.Function( + module, + ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()] * 3), + name="exec", +) # load all entry = ir.IRBuilder(func.append_basic_block(name="entry")) @@ -66,25 +114,42 @@ exit = ir.IRBuilder(func.append_basic_block(name="exit")) y = loop_1.phi(ir.IntType(64), name="y") y.add_incoming(int_const(0), entry._block) -yp = loop_1_exit.add(y, int_const(32*2)) +yp = loop_1_exit.add(y, int_const(32 * 2)) y.add_incoming(yp, loop_1_exit._block) -prefetch_function = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.PointerType(ir.FloatType()), ir.IntType(32), ir.IntType(32), ir.IntType(32)]), name="llvm.prefetch") +prefetch_function = ir.Function( + module, + ir.FunctionType( + ir.VoidType(), + [ + ir.PointerType(ir.FloatType()), + ir.IntType(32), + ir.IntType(32), + ir.IntType(32), + ], + ), + name="llvm.prefetch", +) xptr = y addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr)) -#prefetch_ptr = loop_1_exit.inttoptr(loop_1_exit.add(addr, int_const(128)), ir.PointerType(ir.FloatType())) -#loop_1_exit.call(prefetch_function, [prefetch_ptr, ir.IntType(32)(0), ir.IntType(32)(2), ir.IntType(32)(1)]) +# prefetch_ptr = loop_1_exit.inttoptr(loop_1_exit.add(addr, int_const(128)), ir.PointerType(ir.FloatType())) +# loop_1_exit.call(prefetch_function, [prefetch_ptr, ir.IntType(32)(0), ir.IntType(32)(2), ir.IntType(32)(1)]) -AMX.ldx(loop_1_exit, loop_1_exit.add(int_const(1<<62), addr)) +AMX.ldx(loop_1_exit, loop_1_exit.add(int_const(1 << 62), addr)) xptr = loop_1_exit.add(xptr, int_const(32)) -AMX.ldy(loop_1_exit, loop_1_exit.add(int_const(1<<62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr)))) +AMX.ldy( + loop_1_exit, + loop_1_exit.add( + int_const(1 << 62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr)) + ), +) AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28)) -AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28 | 1 << 20 | (16*4)<<10)) +AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28 | 1 << 20 | (16 * 4) << 10)) AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29)) -AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29 | 1 << 20 | (16*4))) +AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29 | 1 << 20 | (16 * 4))) AMX.set(entry) @@ -93,7 +158,9 @@ AMX.clr(exit) entry.branch(loop_1._block) loop_1.branch(loop_1_exit._block) -loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N*N)), exit._block, loop_1._block) +loop_1_exit.cbranch( + loop_1_exit.icmp_unsigned("==", yp, int_const(N * N)), exit._block, loop_1._block +) exit.ret(int_const(0)) cfunc = LLVM().exec(module, bufs, N**2) @@ -168,21 +235,20 @@ cfunc = LLVM().exec(module, bufs, N**3 * 2) times = [] for i in range(50): - st = time.monotonic() - cfunc(*[x._buf for x in bufs]) - et = time.monotonic() - st - times.append(et) + st = time.monotonic() + cfunc(*[x._buf for x in bufs]) + et = time.monotonic() - st + times.append(et) print(f"{min(times)*1000:.2f} ms min time, {np.median(times)*1000:.2f} ms median time") -print("%.2f GB/s" % ((N*N*4*1e-9)/min(times))) +print("%.2f GB/s" % ((N * N * 4 * 1e-9) / min(times))) -print(c.toCPU().astype(np.int64)[:sn.shape[0]]) +print(c.toCPU().astype(np.int64)[: sn.shape[0]]) print(sn.astype(np.int64)) -np.testing.assert_allclose(c.toCPU()[:sn.shape[0]], sn, atol=1e-4, rtol=1e-4) +np.testing.assert_allclose(c.toCPU()[: sn.shape[0]], sn, atol=1e-4, rtol=1e-4) """ print(cn.astype(np.int64)) np.testing.assert_allclose(c.toCPU(), cn, atol=1e-4, rtol=1e-5) """ - diff --git a/extra/gemm/cuda_matmul.py b/extra/gemm/cuda_matmul.py index 1aa089635..164389249 100644 --- a/extra/gemm/cuda_matmul.py +++ b/extra/gemm/cuda_matmul.py @@ -1,5 +1,6 @@ import os import numpy as np + os.environ["CUDA"] = "1" from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram, compile_cuda @@ -7,21 +8,24 @@ FLOAT16 = True ACC_FLOAT16 = False N = 4096 -na = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) -nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) +na = np.random.default_rng().standard_normal(size=(N, N), dtype=np.float32) +nb = np.random.default_rng().standard_normal(size=(N, N), dtype=np.float32) if FLOAT16: - na = na.astype(np.float16) - nb = nb.astype(np.float16) + na = na.astype(np.float16) + nb = nb.astype(np.float16) a = RawCUDABuffer.fromCPU(na) b = RawCUDABuffer.fromCPU(nb) -c = RawCUDABuffer.fromCPU(np.ones((N,N),dtype=np.float32)) +c = RawCUDABuffer.fromCPU(np.ones((N, N), dtype=np.float32)) -FLOPS = N*N*N*2 -BW = N*N*3*4 +FLOPS = N * N * N * 2 +BW = N * N * 3 * 4 -prog = CUDAProgram("wmma_example", compile_cuda(f""" +prog = CUDAProgram( + "wmma_example", + compile_cuda( + f""" #include using namespace nvcuda; @@ -88,10 +92,23 @@ __global__ void wmma_example({'half' if FLOAT16 else 'float'} *a, {'half' if FLO }} }} }} -""")) +""" + ), +) -global_size, local_size = [(N//16)//4, (N//16)//4], [32, 1, 1] -tm = min([prog(a, b, c, global_size=global_size, local_size=local_size, wait=True) for _ in range(20)]) -print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s") +global_size, local_size = [(N // 16) // 4, (N // 16) // 4], [32, 1, 1] +tm = min( + [ + prog(a, b, c, global_size=global_size, local_size=local_size, wait=True) + for _ in range(20) + ] +) +print( + f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s" +) -np.testing.assert_allclose(na.T.astype(np.float32) @ nb.T.astype(np.float32), c.toCPU().reshape((N,N)).T, atol=1e-2) \ No newline at end of file +np.testing.assert_allclose( + na.T.astype(np.float32) @ nb.T.astype(np.float32), + c.toCPU().reshape((N, N)).T, + atol=1e-2, +) diff --git a/extra/gemm/fastvits/fastvits_speed.py b/extra/gemm/fastvits/fastvits_speed.py index 9037f1aac..30c935baf 100644 --- a/extra/gemm/fastvits/fastvits_speed.py +++ b/extra/gemm/fastvits/fastvits_speed.py @@ -15,39 +15,50 @@ from tinygrad.helpers import partition, GlobalCounters, Context, getenv, prod, d from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram from tinygrad.ops import LoadOps, ReduceOps -def single_kernel(): - # single kernel - sz1, sz2, sz3 = (32, 1024, 4), (32, 4096, 4), (16, 256, 4) - out = CLBuffer(prod(sz1), dtypes.imageh(sz1)) - x = CLBuffer(prod(sz2), dtypes.imageh(sz2)) - w = CLBuffer(prod(sz3), dtypes.imageh(sz3)) - old = CLProgram("r_32_16_16_64_4_4_4", open(pathlib.Path(__file__).parent / "conv1_reorder.cl").read()) - old_tms = [old([1,1,32], [16,16,1], out, x, w, wait=True)*1e6 for _ in range(5)] - print(old_tms, 67.107/min(old_tms)*1e3) - exit(0) +def single_kernel(): + # single kernel + sz1, sz2, sz3 = (32, 1024, 4), (32, 4096, 4), (16, 256, 4) + out = CLBuffer(prod(sz1), dtypes.imageh(sz1)) + x = CLBuffer(prod(sz2), dtypes.imageh(sz2)) + w = CLBuffer(prod(sz3), dtypes.imageh(sz3)) + + old = CLProgram( + "r_32_16_16_64_4_4_4", + open(pathlib.Path(__file__).parent / "conv1_reorder.cl").read(), + ) + old_tms = [ + old([1, 1, 32], [16, 16, 1], out, x, w, wait=True) * 1e6 for _ in range(5) + ] + print(old_tms, 67.107 / min(old_tms) * 1e3) + exit(0) + # CONV=0 PYTHONPATH="." LATEDEBUG=5 OPT=99 IMAGE=2 FLOAT16=1 NOLOCALS=1 python3 extra/fastvits/fastvits_speed.py if __name__ == "__main__": - #single_kernel() + # single_kernel() - # this is stage 1 in fastvits - c1 = Conv2d(256, 64, (1,1), bias=False) - c2 = Conv2d(64, 64, (3,3), groups=64, padding=1, bias=False) - c3 = Conv2d(64, 64, (7,7), groups=64, padding=3, bias=False) - c4 = Conv2d(64, 256, (1,1), bias=False) - c5 = Conv2d(256, 64, (1,1), bias=False) + # this is stage 1 in fastvits + c1 = Conv2d(256, 64, (1, 1), bias=False) + c2 = Conv2d(64, 64, (3, 3), groups=64, padding=1, bias=False) + c3 = Conv2d(64, 64, (7, 7), groups=64, padding=3, bias=False) + c4 = Conv2d(64, 256, (1, 1), bias=False) + c5 = Conv2d(256, 64, (1, 1), bias=False) - # TODO: the elementwise ops shouldn't rerun with normal realize - x = Tensor.randn(1, 256, 32, 64) - out = x.sequential([c1,c2,c3,c4,c5]) - schedule = out.lazydata.schedule() + # TODO: the elementwise ops shouldn't rerun with normal realize + x = Tensor.randn(1, 256, 32, 64) + out = x.sequential([c1, c2, c3, c4, c5]) + schedule = out.lazydata.schedule() - schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps and any(y.op in ReduceOps for y in x.ast.get_lazyops())) - run_schedule(schedule_input) - run_schedule(schedule[:getenv("CONV")]) - print("*** init done ***") + schedule, schedule_input = partition( + schedule, + lambda x: x.ast.op not in LoadOps + and any(y.op in ReduceOps for y in x.ast.get_lazyops()), + ) + run_schedule(schedule_input) + run_schedule(schedule[: getenv("CONV")]) + print("*** init done ***") - GlobalCounters.reset() - with Context(DEBUG=getenv("LATEDEBUG", 2), BEAM=getenv("LATEBEAM")): - run_schedule(schedule[getenv("CONV"):getenv("CONV")+1]) + GlobalCounters.reset() + with Context(DEBUG=getenv("LATEDEBUG", 2), BEAM=getenv("LATEBEAM")): + run_schedule(schedule[getenv("CONV") : getenv("CONV") + 1]) diff --git a/extra/gemm/gemm.py b/extra/gemm/gemm.py index f6f13d8f9..a99b33912 100755 --- a/extra/gemm/gemm.py +++ b/extra/gemm/gemm.py @@ -1,28 +1,29 @@ #!/usr/bin/env python3 import os -#os.environ['OMP_NUM_THREADS'] = '1' + +# os.environ['OMP_NUM_THREADS'] = '1' import time import numpy as np N = 2048 if __name__ == "__main__": - # N^2 - A = np.random.randn(N, N).astype(np.float32) - # N^2 - B = np.random.randn(N, N).astype(np.float32) + # N^2 + A = np.random.randn(N, N).astype(np.float32) + # N^2 + B = np.random.randn(N, N).astype(np.float32) - # 2N compute in N^2 output cells - flop = 2*N*N*N - #print(f"{flop / 1e9:.2f} GFLOP") + # 2N compute in N^2 output cells + flop = 2 * N * N * N + # print(f"{flop / 1e9:.2f} GFLOP") - for i in range(4): - st = time.monotonic() - C = A @ B.T - et = time.monotonic() - s = et-st - print(f"{flop/s * 1e-9:.2f} GFLOP/S, {s*1e3:.2f} ms") + for i in range(4): + st = time.monotonic() + C = A @ B.T + et = time.monotonic() + s = et - st + print(f"{flop/s * 1e-9:.2f} GFLOP/S, {s*1e3:.2f} ms") - with open("/tmp/matmul", "wb") as f: - f.write(A.data) - f.write(B.data) - f.write(C.data) + with open("/tmp/matmul", "wb") as f: + f.write(A.data) + f.write(B.data) + f.write(C.data) diff --git a/extra/gemm/gemv_845.py b/extra/gemm/gemv_845.py index fffbbd3b2..b9445e710 100644 --- a/extra/gemm/gemv_845.py +++ b/extra/gemm/gemv_845.py @@ -62,21 +62,19 @@ from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram from tinygrad.helpers import dtypes, prod if __name__ == "__main__": - out = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1,128,4))) - x = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1,128,4))) - w = CLBuffer(prod((256, 512, 4)), dtypes.imageh((256, 512, 4))) - b = CLBuffer(1024, dtypes.float) + out = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1, 128, 4))) + x = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1, 128, 4))) + w = CLBuffer(prod((256, 512, 4)), dtypes.imageh((256, 512, 4))) + b = CLBuffer(1024, dtypes.float) - old = CLProgram("re_S256_16_8", old) - new = CLProgram("r_256_16_4_8_4", new) + old = CLProgram("re_S256_16_8", old) + new = CLProgram("r_256_16_4_8_4", new) - old_tms = [] - new_tms = [] - - for i in range(5): - old_tms.append(old([1,1,256], [4,16,1], out, x, w, b, wait=True)) - new_tms.append(new([256,1,1], [4,16,1], out, x, w, b, wait=True)) - - print(f"old: {min(old_tms)*1e6:.2f} us new: {min(new_tms)*1e6:.2f} us") + old_tms = [] + new_tms = [] + for i in range(5): + old_tms.append(old([1, 1, 256], [4, 16, 1], out, x, w, b, wait=True)) + new_tms.append(new([256, 1, 1], [4, 16, 1], out, x, w, b, wait=True)) + print(f"old: {min(old_tms)*1e6:.2f} us new: {min(new_tms)*1e6:.2f} us") diff --git a/extra/gemm/hip_matmul.py b/extra/gemm/hip_matmul.py index 7d1669638..80b47eac7 100644 --- a/extra/gemm/hip_matmul.py +++ b/extra/gemm/hip_matmul.py @@ -18,24 +18,33 @@ from tinygrad.runtime.ops_hip import HIPAllocator, HIPProgram, compile_hip N = getenv("N", 2048) KX = getenv("KX", 4) KY = getenv("KY", 4) -assert N%(16*KX) == 0, f"N must be multiple of {16*KX}" -assert N%(16*KY) == 0, f"N must be multiple of {16*KY}" -FLOPS = N*N*N*2 -BW = N*N*3*4 +assert N % (16 * KX) == 0, f"N must be multiple of {16*KX}" +assert N % (16 * KY) == 0, f"N must be multiple of {16*KY}" +FLOPS = N * N * N * 2 +BW = N * N * 3 * 4 # Can HIPAllocator initialized as device=0 by default? device = 0 hipallocator = HIPAllocator(device) -a = hipallocator.alloc(N*N*4) -b = hipallocator.alloc(N*N*2) -c = hipallocator.alloc(N*N*2) -na = np.empty(N*N, np.float32) -nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16) -nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16) +a = hipallocator.alloc(N * N * 4) +b = hipallocator.alloc(N * N * 2) +c = hipallocator.alloc(N * N * 2) +na = np.empty(N * N, np.float32) +nb = ( + np.random.default_rng() + .standard_normal(size=(N, N), dtype=np.float32) + .astype(np.float16) +) +nc = ( + np.random.default_rng() + .standard_normal(size=(N, N), dtype=np.float32) + .astype(np.float16) +) hipallocator.copyin(b, bytearray(nb)) hipallocator.copyin(c, bytearray(nc)) -lib = compile_hip(f""" +lib = compile_hip( + f""" #define F32 typedef float float8 __attribute__((ext_vector_type(8))); typedef _Float16 half16 __attribute__((ext_vector_type(16))); @@ -92,22 +101,41 @@ extern "C" __global__ void __launch_bounds__ (128, 1) test(float* c, __half* a, }} }} }} -}}""") +}}""" +) prog = HIPProgram(device, "test", lib) -def timeit(fxn): - st = time.perf_counter() - et = fxn() - ret = time.perf_counter() - st # NOTE: et doesn't contain the launch overhead - #print(f"{ret*1e6:.2f} us") - return et -global_size, local_size = [N//(KX*16*2), N//(KY*16*2), 1], [32, 2, 2] -print("global/local size", global_size, local_size, f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}") -tm = min([timeit(lambda: prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)) for _ in range(1000)]) -hipallocator.copyout(flat_mv(na.data),a) -na = na.reshape(N,N) +def timeit(fxn): + st = time.perf_counter() + et = fxn() + ret = time.perf_counter() - st # NOTE: et doesn't contain the launch overhead + # print(f"{ret*1e6:.2f} us") + return et + + +global_size, local_size = [N // (KX * 16 * 2), N // (KY * 16 * 2), 1], [32, 2, 2] +print( + "global/local size", + global_size, + local_size, + f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}", +) +tm = min( + [ + timeit( + lambda: prog( + a, b, c, global_size=global_size, local_size=local_size, wait=True + ) + ) + for _ in range(1000) + ] +) +hipallocator.copyout(flat_mv(na.data), a) +na = na.reshape(N, N) comp = nb.astype(np.float32) @ nc.astype(np.float32) -print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s") +print( + f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s" +) np.testing.assert_allclose(na, comp, atol=1e-2, rtol=1e-2) diff --git a/extra/gemm/jax_pmatmul.py b/extra/gemm/jax_pmatmul.py index b69a2b9b4..903c7bb9e 100755 --- a/extra/gemm/jax_pmatmul.py +++ b/extra/gemm/jax_pmatmul.py @@ -13,15 +13,21 @@ B = jnp.zeros((1, 1, N, N), dtype) A = jax.device_put_sharded([A[i] for i in range(DEVICES)], jax.devices()) B = jax.device_put_sharded([B for i in range(DEVICES)], jax.devices()) -OPS = DEVICES*BS*N*N*N*2 -def matmul(A,B): return jnp.matmul(A,B,preferred_element_type=jnp.float32) +OPS = DEVICES * BS * N * N * N * 2 + + +def matmul(A, B): + return jnp.matmul(A, B, preferred_element_type=jnp.float32) + + pmatmul = jax.pmap(matmul) -MAX_TFLOPS = 123*DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX) +MAX_TFLOPS = 123 * DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX) for i in range(10): - st = time.perf_counter() - C = pmatmul(A,B).block_until_ready() - et = time.perf_counter()-st - tflops = (OPS*1e-12)/et - print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}") - + st = time.perf_counter() + C = pmatmul(A, B).block_until_ready() + et = time.perf_counter() - st + tflops = (OPS * 1e-12) / et + print( + f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}" + ) diff --git a/extra/gemm/metal_conv.py b/extra/gemm/metal_conv.py index 9b3df4628..eb22dc1a4 100644 --- a/extra/gemm/metal_conv.py +++ b/extra/gemm/metal_conv.py @@ -1,5 +1,6 @@ import os -#os.environ["METAL"] = "1" + +# os.environ["METAL"] = "1" import numpy as np BS = 64 @@ -11,39 +12,48 @@ PADDING = 0 # TODO: this is doing some trick, since with CIN=256 COUT=256 it's over 10.4 TFLOPS. # are winograd convs less flops? it appears so if they are batched # https://www.cse.ust.hk/~weiwa/papers/yan-ppopp20.pdf -FLOPS = BS*K*K*CIN*HW*HW*COUT*2 +FLOPS = BS * K * K * CIN * HW * HW * COUT * 2 -nb = np.random.default_rng().standard_normal(size=(BS,CIN,HW,HW), dtype=np.float32) -nc = np.random.default_rng().standard_normal(size=(COUT,CIN,K,K), dtype=np.float32) +nb = np.random.default_rng().standard_normal(size=(BS, CIN, HW, HW), dtype=np.float32) +nc = np.random.default_rng().standard_normal(size=(COUT, CIN, K, K), dtype=np.float32) try: - import time, torch, torch.mps - b = torch.from_numpy(nb).to('mps') - c = torch.from_numpy(nc).to('mps') + import time, torch, torch.mps - def torch_prog(b, c): - st = time.perf_counter() - a = torch.nn.functional.conv2d(b, c, padding=PADDING) - torch.mps.synchronize() - return time.perf_counter() - st - tm = min([torch_prog(b, c) for _ in range(20)]) - print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in torch") + b = torch.from_numpy(nb).to("mps") + c = torch.from_numpy(nc).to("mps") + + def torch_prog(b, c): + st = time.perf_counter() + a = torch.nn.functional.conv2d(b, c, padding=PADDING) + torch.mps.synchronize() + return time.perf_counter() - st + + tm = min([torch_prog(b, c) for _ in range(20)]) + print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in torch") except RuntimeError: - print("no torch metal conv") + print("no torch metal conv") from tinygrad.tensor import Tensor from tinygrad.jit import TinyJit from tinygrad import Device + b = Tensor(nb) c = Tensor(nc) + + # TODO: slowness without the JIT I suspect comes from a lack of a caching allocator @TinyJit def tiny_jit(b, c): - return b.conv2d(c, padding=PADDING).realize() + return b.conv2d(c, padding=PADDING).realize() + + def tiny_prog(b, c): - st = time.perf_counter() - a = tiny_jit(b, c) - Device[a.device].synchronize() - return time.perf_counter() - st + st = time.perf_counter() + a = tiny_jit(b, c) + Device[a.device].synchronize() + return time.perf_counter() - st + + tm = min([tiny_prog(b, c) for _ in range(5)]) print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in tinygrad") diff --git a/extra/gemm/metal_matmul.py b/extra/gemm/metal_matmul.py index 6f1899890..ee850615a 100644 --- a/extra/gemm/metal_matmul.py +++ b/extra/gemm/metal_matmul.py @@ -1,4 +1,5 @@ import os + os.environ["METAL"] = "1" import time import numpy as np @@ -8,17 +9,24 @@ from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram, compile_met N = getenv("N", 2048) LID = 2 -a = RawMetalBuffer(N*N, dtypes.float32) +a = RawMetalBuffer(N * N, dtypes.float32) -nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32) -nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32) +nb = np.random.default_rng().standard_normal( + size=(N, N), dtype=np.float32 +) # .astype(np.int32).astype(np.float32) +nc = np.random.default_rng().standard_normal( + size=(N, N), dtype=np.float32 +) # .astype(np.int32).astype(np.float32) b = RawMetalBuffer.fromCPU(nb) c = RawMetalBuffer.fromCPU(nc) -FLOPS = N*N*N*2 -BW = N*N*3*4 +FLOPS = N * N * N * 2 +BW = N * N * 3 * 4 -prog = MetalProgram("test", compile_metal(f""" +prog = MetalProgram( + "test", + compile_metal( + f""" #include #include // Available from Metal version 2.3 released with OS X 11.0+ using namespace metal; @@ -80,46 +88,83 @@ kernel void test(device float *a, device const float *data1, device const float simdgroup_store(acc[1][3], a+{8+24*N}, {N}, ulong2(0, 0)); simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0)); simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0)); -}}""")) +}}""" + ), +) + + def timeit(fxn): - st = time.perf_counter() - et = fxn() - # NOTE: et doesn't contain the launch overhead - return time.perf_counter() - st -tm = min([timeit(lambda: prog(a, b, c, global_size=[N//(8*4), N//(8*4*LID), 1], local_size=[32, LID, 1], wait=True)) for _ in range(20)]) -na = a.toCPU().reshape(N,N) -comp = nb@nc + st = time.perf_counter() + et = fxn() + # NOTE: et doesn't contain the launch overhead + return time.perf_counter() - st + + +tm = min( + [ + timeit( + lambda: prog( + a, + b, + c, + global_size=[N // (8 * 4), N // (8 * 4 * LID), 1], + local_size=[32, LID, 1], + wait=True, + ) + ) + for _ in range(20) + ] +) +na = a.toCPU().reshape(N, N) +comp = nb @ nc if N <= 32: - print(na) - print(comp) -print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s") + print(na) + print(comp) +print( + f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s" +) np.testing.assert_allclose(na, comp, atol=1e-3) import torch, torch.mps -b = torch.from_numpy(nb).to('mps') -c = torch.from_numpy(nc).to('mps') + +b = torch.from_numpy(nb).to("mps") +c = torch.from_numpy(nc).to("mps") + def torch_prog(b, c): - st = time.perf_counter() - a = b@c - torch.mps.synchronize() - return time.perf_counter() - st + st = time.perf_counter() + a = b @ c + torch.mps.synchronize() + return time.perf_counter() - st + + tm = min([torch_prog(b, c) for _ in range(20)]) -print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch") +print( + f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch" +) from tinygrad.tensor import Tensor from tinygrad.jit import TinyJit from tinygrad.runtime.ops_metal import METAL + b = Tensor(nb) c = Tensor(nc) + + # TODO: slowness without the JIT I suspect comes from a lack of a caching allocator @TinyJit def tiny_jit(b, c): - return (b@c).realize() + return (b @ c).realize() + + def tiny_prog(b, c): - st = time.perf_counter() - a = tiny_jit(b, c) - METAL.synchronize() - return time.perf_counter() - st + st = time.perf_counter() + a = tiny_jit(b, c) + METAL.synchronize() + return time.perf_counter() - st + + tm = min([tiny_prog(b, c) for _ in range(20)]) -print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad") +print( + f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad" +) diff --git a/extra/gemm/metal_matvec.py b/extra/gemm/metal_matvec.py index 60df010d4..c9c2a2bc7 100644 --- a/extra/gemm/metal_matvec.py +++ b/extra/gemm/metal_matvec.py @@ -1,5 +1,6 @@ import os -#os.environ["METAL"] = "1" + +# os.environ["METAL"] = "1" import numpy as np import time, torch, torch.mps @@ -10,6 +11,7 @@ from tinygrad import Device from tinygrad.helpers import colored, getenv, CI import os + os.environ["METAL"] = "1" import time import numpy as np @@ -18,29 +20,40 @@ from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram, compile_met N = 16384 M = 4096 -FLOPS = N*M*2 +FLOPS = N * M * 2 -nb = np.random.default_rng().standard_normal(size=(N), dtype=np.float32) #.astype(np.int32).astype(np.float32) -nc = np.random.default_rng().standard_normal(size=(N,M), dtype=np.float32) #.astype(np.int32).astype(np.float32) +nb = np.random.default_rng().standard_normal( + size=(N), dtype=np.float32 +) # .astype(np.int32).astype(np.float32) +nc = np.random.default_rng().standard_normal( + size=(N, M), dtype=np.float32 +) # .astype(np.int32).astype(np.float32) import torch, torch.mps -b = torch.from_numpy(nb).to('mps') -c = torch.from_numpy(nc).to('mps') + +b = torch.from_numpy(nb).to("mps") +c = torch.from_numpy(nc).to("mps") + def torch_prog(b, c): - st = time.perf_counter() - a = b@c - torch.mps.synchronize() - return time.perf_counter() - st + st = time.perf_counter() + a = b @ c + torch.mps.synchronize() + return time.perf_counter() - st + + tm = min([torch_prog(b, c) for _ in range(200)]) -print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch") -torch_a = (b@c).cpu() +print( + f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch" +) +torch_a = (b @ c).cpu() WORKSIZE_ROW = 16 WORKSIZE_COL = 1 LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW] -GLOBAL_SIZE = [M//(LOCAL_SIZE[0]*LOCAL_SIZE[1]*4), 1, 1] -prog = compile_metal(f""" +GLOBAL_SIZE = [M // (LOCAL_SIZE[0] * LOCAL_SIZE[1] * 4), 1, 1] +prog = compile_metal( + f""" #include using namespace metal; kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{ @@ -86,41 +99,59 @@ kernel void test(device float* data0, const device float* data1, const device fl *( (device float4 *) (data0 + (gidx0*{M//GLOBAL_SIZE[0]}) + ( ( (lidx1*{LOCAL_SIZE[1]})+lidx2 ) * 4 ) ) ) = out; }} }} -""") +""" +) prog = MetalProgram("test", prog) # print(prog_string) na = np.zeros(M, dtype=np.float32) b = RawMetalBuffer.fromCPU(nb) c = RawMetalBuffer.fromCPU(nc) + + def metalrun(): - a = RawMetalBuffer.fromCPU(na) - prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True) - return a + a = RawMetalBuffer.fromCPU(na) + prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True) + return a + + def timeit(fxn): - st = time.perf_counter() - et = fxn() - # NOTE: et doesn't contain the launch overhead - return time.perf_counter() - st + st = time.perf_counter() + et = fxn() + # NOTE: et doesn't contain the launch overhead + return time.perf_counter() - st + + tm = min([timeit(metalrun) for _ in range(200)]) -print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal") +print( + f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal" +) metal_a = metalrun().toCPU().reshape(M) np.testing.assert_allclose(metal_a, torch_a, atol=5e-3) from tinygrad.tensor import Tensor from tinygrad.jit import TinyJit from tinygrad.runtime.ops_metal import METAL + b = Tensor(nb) c = Tensor(nc) + + # TODO: slowness without the JIT I suspect comes from a lack of a caching allocator @TinyJit def tiny_jit(b, c): - return (b@c).realize() + return (b @ c).realize() + + def tiny_prog(b, c): - st = time.perf_counter() - a = tiny_jit(b, c) - METAL.synchronize() - return time.perf_counter() - st + st = time.perf_counter() + a = tiny_jit(b, c) + METAL.synchronize() + return time.perf_counter() - st + + tm = min([tiny_prog(b, c) for _ in range(200)]) -print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad") +print( + f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad" +) tiny_a = tiny_jit(b, c).numpy() -np.testing.assert_allclose(tiny_a, torch_a, atol=5e-3) \ No newline at end of file +np.testing.assert_allclose(tiny_a, torch_a, atol=5e-3) diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py index 825089a9e..b2ed37903 100644 --- a/extra/gemm/simple_matmul.py +++ b/extra/gemm/simple_matmul.py @@ -2,14 +2,28 @@ import numpy as np from tinygrad.helpers import getenv from tinygrad.tensor import Tensor from tinygrad.helpers import dtypes + dtype_in = dtypes.half if getenv("HALF") else dtypes.float N = getenv("N", 4096) CNT = getenv("CNT", 10) -a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize() +a, b = ( + Tensor.rand(N, N, dtype=dtype_in).realize(), + Tensor.rand(N, N, dtype=dtype_in).realize(), +) for i in range(CNT): - if i > 0 and getenv("RAND", 0) != 0: - a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize() - c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize() if getenv("ACCUM_FP32") else (a @ b).realize() + if i > 0 and getenv("RAND", 0) != 0: + a, b = ( + Tensor.rand(N, N, dtype=dtype_in).realize(), + Tensor.rand(N, N, dtype=dtype_in).realize(), + ) + c = ( + (a.reshape(N, 1, N) * b.permute(1, 0).reshape(1, N, N)) + .float() + .sum(axis=2) + .realize() + if getenv("ACCUM_FP32") + else (a @ b).realize() + ) comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32) nc = c.numpy() np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=1e-2) diff --git a/extra/gemm/tf_gemm.py b/extra/gemm/tf_gemm.py index 802b34435..180759f02 100644 --- a/extra/gemm/tf_gemm.py +++ b/extra/gemm/tf_gemm.py @@ -1,33 +1,37 @@ import time import tensorflow as tf -gpus = tf.config.list_physical_devices('GPU') +gpus = tf.config.list_physical_devices("GPU") if gpus: - try: - # Currently, memory growth needs to be the same across GPUs - for gpu in gpus: - tf.config.experimental.set_memory_growth(gpu, True) - logical_gpus = tf.config.list_logical_devices('GPU') - print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") - except RuntimeError as e: - # Memory growth must be set before GPUs have been initialized - print(e) + try: + # Currently, memory growth needs to be the same across GPUs + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + logical_gpus = tf.config.list_logical_devices("GPU") + print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") + except RuntimeError as e: + # Memory growth must be set before GPUs have been initialized + print(e) for dtype in [tf.float16, tf.float32]: - for N in [256, 512, 1024, 2048, 4096, 8192]: - FLOPS = N*N*N*2 + for N in [256, 512, 1024, 2048, 4096, 8192]: + FLOPS = N * N * N * 2 - b = tf.random.uniform((N, N), dtype=dtype) - c = tf.random.uniform((N, N), dtype=dtype) + b = tf.random.uniform((N, N), dtype=dtype) + c = tf.random.uniform((N, N), dtype=dtype) - b = tf.Variable(b) - c = tf.Variable(c) + b = tf.Variable(b) + c = tf.Variable(c) - def tf_prog(b, c): - st = time.perf_counter() - a = tf.matmul(b, c) - tf.debugging.check_numerics(a, "Nan or Inf in result") # Ensures that the calculation is done. - return time.perf_counter() - st + def tf_prog(b, c): + st = time.perf_counter() + a = tf.matmul(b, c) + tf.debugging.check_numerics( + a, "Nan or Inf in result" + ) # Ensures that the calculation is done. + return time.perf_counter() - st - tm = min([tf_prog(b, c) for _ in range(20)]) - print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}") \ No newline at end of file + tm = min([tf_prog(b, c) for _ in range(20)]) + print( + f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}" + ) diff --git a/extra/gemm/torch_gemm.py b/extra/gemm/torch_gemm.py index a87f2757e..31ad6fe4c 100644 --- a/extra/gemm/torch_gemm.py +++ b/extra/gemm/torch_gemm.py @@ -2,16 +2,19 @@ import time import torch for dtype in [torch.float16, torch.float32]: - for N in [256, 512, 1024, 2048, 4096]: - FLOPS = N*N*N*2 + for N in [256, 512, 1024, 2048, 4096]: + FLOPS = N * N * N * 2 - b = torch.rand((N,N), dtype=dtype).cuda() - c = torch.rand((N,N), dtype=dtype).cuda() + b = torch.rand((N, N), dtype=dtype).cuda() + c = torch.rand((N, N), dtype=dtype).cuda() - def torch_prog(b, c): - st = time.perf_counter() - a = b@c - torch.cuda.synchronize() - return time.perf_counter() - st - tm = min([torch_prog(b, c) for _ in range(20)]) - print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}") + def torch_prog(b, c): + st = time.perf_counter() + a = b @ c + torch.cuda.synchronize() + return time.perf_counter() - st + + tm = min([torch_prog(b, c) for _ in range(20)]) + print( + f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}" + ) diff --git a/extra/gemm/tvm_gemm.py b/extra/gemm/tvm_gemm.py index 11318a348..f594b8323 100644 --- a/extra/gemm/tvm_gemm.py +++ b/extra/gemm/tvm_gemm.py @@ -3,28 +3,29 @@ M, N, K = 1024, 1024, 1024 try: - import tvm - from tvm import te - #print(tvm.target.Target.list_kinds()) + import tvm + from tvm import te - # c, opencl - target = tvm.target.Target(target="c") + # print(tvm.target.Target.list_kinds()) - # TVM Matrix Multiplication using TE - k = te.reduce_axis((0, K), "k") - A = te.placeholder((M, K), name="A") - B = te.placeholder((K, N), name="B") - C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C") + # c, opencl + target = tvm.target.Target(target="c") - # Default schedule - s = te.create_schedule(C.op) - #print(tvm.lower(s, [A, B, C], simple_mode=True)) + # TVM Matrix Multiplication using TE + k = te.reduce_axis((0, K), "k") + A = te.placeholder((M, K), name="A") + B = te.placeholder((K, N), name="B") + C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C") - # Output C code - func = tvm.build(s, [A, B, C], target=target, name="mmult") - print(func.get_source()) + # Default schedule + s = te.create_schedule(C.op) + # print(tvm.lower(s, [A, B, C], simple_mode=True)) + + # Output C code + func = tvm.build(s, [A, B, C], target=target, name="mmult") + print(func.get_source()) except ImportError: - print("** please install TVM for TVM output") + print("** please install TVM for TVM output") # tinygrad version @@ -34,14 +35,18 @@ from tinygrad.tensor import Tensor # define the compute A = Tensor.rand(M, K, device="clang") B = Tensor.rand(K, N, device="clang") -C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2) +C = (A.reshape(M, 1, K) * B.permute(1, 0).reshape(1, N, K)).sum(axis=2) sched = C.lazydata.schedule() from tinygrad.codegen.linearizer import Linearizer from tinygrad.codegen.kernel import LinearizerOptions -lin = Linearizer(sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False)) -#lin.hand_coded_optimizations() + +lin = Linearizer( + sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False) +) +# lin.hand_coded_optimizations() lin.linearize() from tinygrad.runtime.ops_clang import renderer + src = renderer("mmult", lin.uops) print(src) diff --git a/extra/gradcheck.py b/extra/gradcheck.py index 4d99726cc..545f11ff7 100644 --- a/extra/gradcheck.py +++ b/extra/gradcheck.py @@ -1,50 +1,58 @@ import numpy as np from tinygrad.tensor import Tensor -def mask_like(like, mask_inx, mask_value = 1.0): - mask = np.zeros_like(like).reshape(-1) - mask[mask_inx] = mask_value - return mask.reshape(like.shape) + +def mask_like(like, mask_inx, mask_value=1.0): + mask = np.zeros_like(like).reshape(-1) + mask[mask_inx] = mask_value + return mask.reshape(like.shape) + def jacobian(func, input): - output = func(input) - - ji = input.numpy().reshape(-1).shape[-1] - jo = output.numpy().reshape(-1).shape[-1] - J = np.zeros((jo,ji), dtype=np.float32) - - for o in range(jo): - input.grad = None output = func(input) - # tinygrad doesn't support slicing, tiny-hack to select - # the needed scalar an backpropagate only through it - o_scalar = Tensor(mask_like(output.numpy(), o, 1.)).mul(output).sum() - o_scalar.backward() + ji = input.numpy().reshape(-1).shape[-1] + jo = output.numpy().reshape(-1).shape[-1] + J = np.zeros((jo, ji), dtype=np.float32) - for i, grad in enumerate(input.grad.numpy().reshape(-1)): - J[o,i] = grad - return J + for o in range(jo): + input.grad = None + output = func(input) -def numerical_jacobian(func, input, eps = 1e-3): - output = func(input) + # tinygrad doesn't support slicing, tiny-hack to select + # the needed scalar an backpropagate only through it + o_scalar = Tensor(mask_like(output.numpy(), o, 1.0)).mul(output).sum() + o_scalar.backward() - ji = input.numpy().reshape(-1).shape[-1] - jo = output.numpy().reshape(-1).shape[-1] - NJ = np.zeros((jo, ji), dtype=np.float32) + for i, grad in enumerate(input.grad.numpy().reshape(-1)): + J[o, i] = grad + return J - for i in range(ji): - eps_perturb = mask_like(input.numpy(), i, mask_value = eps) - output_perturb_add = func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1) - output_perturb_sub = func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1) +def numerical_jacobian(func, input, eps=1e-3): + output = func(input) - grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2*eps) + ji = input.numpy().reshape(-1).shape[-1] + jo = output.numpy().reshape(-1).shape[-1] + NJ = np.zeros((jo, ji), dtype=np.float32) - NJ[:,i] = grad_approx - return NJ + for i in range(ji): + eps_perturb = mask_like(input.numpy(), i, mask_value=eps) -def gradcheck(func, input, eps = 1e-3, atol = 1e-3, rtol = 1e-3): - NJ = numerical_jacobian(func, input, eps) - J = jacobian(func, input) - return np.allclose(J, NJ, atol = atol, rtol = rtol) + output_perturb_add = ( + func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1) + ) + output_perturb_sub = ( + func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1) + ) + + grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2 * eps) + + NJ[:, i] = grad_approx + return NJ + + +def gradcheck(func, input, eps=1e-3, atol=1e-3, rtol=1e-3): + NJ = numerical_jacobian(func, input, eps) + J = jacobian(func, input) + return np.allclose(J, NJ, atol=atol, rtol=rtol) diff --git a/extra/helpers.py b/extra/helpers.py index c580a1a69..685f68301 100644 --- a/extra/helpers.py +++ b/extra/helpers.py @@ -2,49 +2,71 @@ import multiprocessing, subprocess import cloudpickle from typing import Any + def _early_exec_process(qin, qout): - while True: - path, inp = qin.get() - try: - qout.put(subprocess.check_output(path, input=inp)) - except Exception as e: - qout.put(e) + while True: + path, inp = qin.get() + try: + qout.put(subprocess.check_output(path, input=inp)) + except Exception as e: + qout.put(e) + def enable_early_exec(): - qin: multiprocessing.Queue = multiprocessing.Queue() - qout: multiprocessing.Queue = multiprocessing.Queue() - p = multiprocessing.Process(target=_early_exec_process, args=(qin, qout)) - p.daemon = True - p.start() - def early_exec(x): - qin.put(x) - ret = qout.get() - if isinstance(ret, Exception): raise ret - else: return ret - return early_exec + qin: multiprocessing.Queue = multiprocessing.Queue() + qout: multiprocessing.Queue = multiprocessing.Queue() + p = multiprocessing.Process(target=_early_exec_process, args=(qin, qout)) + p.daemon = True + p.start() + + def early_exec(x): + qin.put(x) + ret = qout.get() + if isinstance(ret, Exception): + raise ret + else: + return ret + + return early_exec + def proc(itermaker, q) -> None: - try: - for x in itermaker(): q.put(x) - except Exception as e: - q.put(e) - finally: - q.put(None) - q.close() + try: + for x in itermaker(): + q.put(x) + except Exception as e: + q.put(e) + finally: + q.put(None) + q.close() + class _CloudpickleFunctionWrapper: - def __init__(self, fn): self.fn = fn - def __getstate__(self): return cloudpickle.dumps(self.fn) - def __setstate__(self, pfn): self.fn = cloudpickle.loads(pfn) - def __call__(self, *args, **kwargs) -> Any: return self.fn(*args, **kwargs) + def __init__(self, fn): + self.fn = fn + + def __getstate__(self): + return cloudpickle.dumps(self.fn) + + def __setstate__(self, pfn): + self.fn = cloudpickle.loads(pfn) + + def __call__(self, *args, **kwargs) -> Any: + return self.fn(*args, **kwargs) + def cross_process(itermaker, maxsize=16): - q: multiprocessing.Queue = multiprocessing.Queue(maxsize) - # multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle. - p = multiprocessing.Process(target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q)) - p.start() - while True: - ret = q.get() - if isinstance(ret, Exception): raise ret - elif ret is None: break - else: yield ret \ No newline at end of file + q: multiprocessing.Queue = multiprocessing.Queue(maxsize) + # multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle. + p = multiprocessing.Process( + target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q) + ) + p.start() + while True: + ret = q.get() + if isinstance(ret, Exception): + raise ret + elif ret is None: + break + else: + yield ret diff --git a/extra/introspection.py b/extra/introspection.py index 5c791ea7b..9a6f57384 100644 --- a/extra/introspection.py +++ b/extra/introspection.py @@ -6,37 +6,45 @@ from tinygrad.lazy import LazyBuffer from tinygrad.runtime.ops_gpu import CLBuffer from tinygrad.helpers import GlobalCounters + def print_objects(): - #gc.collect() - tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)] - tensor_ram_used = sum([prod(x.shape)*4 for x in tensors]) - lazybuffers = [x for x in gc.get_objects() if isinstance(x, LazyBuffer)] - gpubuffers = [x for x in gc.get_objects() if isinstance(x, CLBuffer)] - realized_buffers = [x.realized for x in lazybuffers if x.realized] - gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers] + # gc.collect() + tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)] + tensor_ram_used = sum([prod(x.shape) * 4 for x in tensors]) + lazybuffers = [x for x in gc.get_objects() if isinstance(x, LazyBuffer)] + gpubuffers = [x for x in gc.get_objects() if isinstance(x, CLBuffer)] + realized_buffers = [x.realized for x in lazybuffers if x.realized] + gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers] - print(f"{len(tensors)} tensors allocated in {tensor_ram_used/1e9:.2f} GB, GPU using {GlobalCounters.mem_used/1e9:.2f} GB") - print(f"{len(lazybuffers)} lazybuffers {len(realized_buffers)} realized, {len(gpubuffers)} GPU buffers") - print(f"{len(gpubuffers_orphaned)} GPU buffers are orphaned") + print( + f"{len(tensors)} tensors allocated in {tensor_ram_used/1e9:.2f} GB, GPU using {GlobalCounters.mem_used/1e9:.2f} GB" + ) + print( + f"{len(lazybuffers)} lazybuffers {len(realized_buffers)} realized, {len(gpubuffers)} GPU buffers" + ) + print(f"{len(gpubuffers_orphaned)} GPU buffers are orphaned") - cnt = 0 - for tb in gpubuffers_orphaned: - bb = gc.get_referrers(tb) - for b in bb: - if b is not gpubuffers and b is not gpubuffers_orphaned: - print(tb, "\nreference", type(b), len(b), str(b)[0:150]) - for x in gc.get_referrers(b): - print("double reference", str(x)[0:100]) - print("\n") - if cnt == 10: - break - cnt += 1 + cnt = 0 + for tb in gpubuffers_orphaned: + bb = gc.get_referrers(tb) + for b in bb: + if b is not gpubuffers and b is not gpubuffers_orphaned: + print(tb, "\nreference", type(b), len(b), str(b)[0:150]) + for x in gc.get_referrers(b): + print("double reference", str(x)[0:100]) + print("\n") + if cnt == 10: + break + cnt += 1 - for x in gpubuffers_orphaned: - if getattr(x, '_buf', None): del x._buf - if getattr(x, '_image', None): del x._image + for x in gpubuffers_orphaned: + if getattr(x, "_buf", None): + del x._buf + if getattr(x, "_image", None): + del x._image + + return len(gpubuffers_orphaned) - return len(gpubuffers_orphaned) """ import gc diff --git a/extra/junk/sentencepiece_model_pb2.py b/extra/junk/sentencepiece_model_pb2.py index 5de978fad..c500f8322 100644 --- a/extra/junk/sentencepiece_model_pb2.py +++ b/extra/junk/sentencepiece_model_pb2.py @@ -7,39 +7,44 @@ from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18. \x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18. \x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03' +) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sentencepiece_model_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "sentencepiece_model_pb2", _globals) if _descriptor._USE_C_DESCRIPTORS == False: - _globals['DESCRIPTOR']._options = None - _globals['DESCRIPTOR']._serialized_options = b'H\003' - _globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._options = None - _globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._serialized_options = b'\030\001' - _globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._options = None - _globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._serialized_options = b'\030\001' - _globals['_TRAINERSPEC']._serialized_start=45 - _globals['_TRAINERSPEC']._serialized_end=1581 - _globals['_TRAINERSPEC_MODELTYPE']._serialized_start=1517 - _globals['_TRAINERSPEC_MODELTYPE']._serialized_end=1570 - _globals['_NORMALIZERSPEC']._serialized_start=1584 - _globals['_NORMALIZERSPEC']._serialized_end=1793 - _globals['_SELFTESTDATA']._serialized_start=1795 - _globals['_SELFTESTDATA']._serialized_end=1916 - _globals['_SELFTESTDATA_SAMPLE']._serialized_start=1864 - _globals['_SELFTESTDATA_SAMPLE']._serialized_end=1905 - _globals['_MODELPROTO']._serialized_start=1919 - _globals['_MODELPROTO']._serialized_end=2429 - _globals['_MODELPROTO_SENTENCEPIECE']._serialized_start=2208 - _globals['_MODELPROTO_SENTENCEPIECE']._serialized_end=2418 - _globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_start=2323 - _globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_end=2407 + _globals["DESCRIPTOR"]._options = None + _globals["DESCRIPTOR"]._serialized_options = b"H\003" + _globals["_TRAINERSPEC"].fields_by_name["mining_sentence_size"]._options = None + _globals["_TRAINERSPEC"].fields_by_name[ + "mining_sentence_size" + ]._serialized_options = b"\030\001" + _globals["_TRAINERSPEC"].fields_by_name["training_sentence_size"]._options = None + _globals["_TRAINERSPEC"].fields_by_name[ + "training_sentence_size" + ]._serialized_options = b"\030\001" + _globals["_TRAINERSPEC"]._serialized_start = 45 + _globals["_TRAINERSPEC"]._serialized_end = 1581 + _globals["_TRAINERSPEC_MODELTYPE"]._serialized_start = 1517 + _globals["_TRAINERSPEC_MODELTYPE"]._serialized_end = 1570 + _globals["_NORMALIZERSPEC"]._serialized_start = 1584 + _globals["_NORMALIZERSPEC"]._serialized_end = 1793 + _globals["_SELFTESTDATA"]._serialized_start = 1795 + _globals["_SELFTESTDATA"]._serialized_end = 1916 + _globals["_SELFTESTDATA_SAMPLE"]._serialized_start = 1864 + _globals["_SELFTESTDATA_SAMPLE"]._serialized_end = 1905 + _globals["_MODELPROTO"]._serialized_start = 1919 + _globals["_MODELPROTO"]._serialized_end = 2429 + _globals["_MODELPROTO_SENTENCEPIECE"]._serialized_start = 2208 + _globals["_MODELPROTO_SENTENCEPIECE"]._serialized_end = 2418 + _globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_start = 2323 + _globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_end = 2407 # @@protoc_insertion_point(module_scope) diff --git a/extra/lr_scheduler.py b/extra/lr_scheduler.py index a1bb93165..923b01ecf 100644 --- a/extra/lr_scheduler.py +++ b/extra/lr_scheduler.py @@ -3,84 +3,138 @@ from typing import List from tinygrad.nn.optim import Optimizer from tinygrad.tensor import Tensor + class LR_Scheduler: - def __init__(self, optimizer: Optimizer): - self.optimizer = optimizer - self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device) + def __init__(self, optimizer: Optimizer): + self.optimizer = optimizer + self.epoch_counter = Tensor( + [0], requires_grad=False, device=self.optimizer.device + ) - def get_lr(self): pass + def get_lr(self): + pass + + def step(self) -> None: + self.epoch_counter.assign(self.epoch_counter + 1).realize() + self.optimizer.lr.assign(self.get_lr()).realize() - def step(self) -> None: - self.epoch_counter.assign(self.epoch_counter + 1).realize() - self.optimizer.lr.assign(self.get_lr()).realize() class MultiStepLR(LR_Scheduler): - def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1): - super().__init__(optimizer) - self.milestones = milestones - self.gamma = gamma + def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1): + super().__init__(optimizer) + self.milestones = milestones + self.gamma = gamma + + def get_lr(self) -> Tensor: + if self.epoch_counter.numpy()[0] not in self.milestones: + return self.optimizer.lr + return self.optimizer.lr * self.gamma - def get_lr(self) -> Tensor: - if self.epoch_counter.numpy()[0] not in self.milestones: - return self.optimizer.lr - return self.optimizer.lr * self.gamma class ReduceLROnPlateau(LR_Scheduler): - def __init__(self, optimizer: Optimizer, mode="min", factor=0.1, patience=10, threshold=1e-4, threshold_mode="rel"): - assert mode in ["min", "max"] and threshold_mode in ["rel", "abs"] - super().__init__(optimizer) - self.mode, self.factor, self.patience, self.threshold, self.threshold_mode = mode, factor, patience, threshold, threshold_mode - self.best = float('inf') if mode == "min" else float('-inf') - self.bad_epoch = 0 + def __init__( + self, + optimizer: Optimizer, + mode="min", + factor=0.1, + patience=10, + threshold=1e-4, + threshold_mode="rel", + ): + assert mode in ["min", "max"] and threshold_mode in ["rel", "abs"] + super().__init__(optimizer) + self.mode, self.factor, self.patience, self.threshold, self.threshold_mode = ( + mode, + factor, + patience, + threshold, + threshold_mode, + ) + self.best = float("inf") if mode == "min" else float("-inf") + self.bad_epoch = 0 - if mode == "min": self.threshold *= -1 + if mode == "min": + self.threshold *= -1 - def is_better(self, current: float) -> bool: - dynamic_threshold = self.best*(1+self.threshold) if self.threshold_mode == "rel" else self.best+self.threshold - if self.mode == "min": - return current < dynamic_threshold - return current > dynamic_threshold + def is_better(self, current: float) -> bool: + dynamic_threshold = ( + self.best * (1 + self.threshold) + if self.threshold_mode == "rel" + else self.best + self.threshold + ) + if self.mode == "min": + return current < dynamic_threshold + return current > dynamic_threshold - def step(self, current: float) -> None: - self.epoch_counter.assign(self.epoch_counter + 1).realize() - if self.is_better(current): - self.bad_epoch = 0 - self.best = current - else: - self.bad_epoch += 1 + def step(self, current: float) -> None: + self.epoch_counter.assign(self.epoch_counter + 1).realize() + if self.is_better(current): + self.bad_epoch = 0 + self.best = current + else: + self.bad_epoch += 1 + + if self.bad_epoch > self.patience: + self.optimizer.lr *= self.factor + self.bad_epoch = 0 - if self.bad_epoch > self.patience: - self.optimizer.lr *= self.factor - self.bad_epoch = 0 class CosineAnnealingLR(LR_Scheduler): - def __init__(self, optimizer: Optimizer, T_max: int, eta_min=0): - super().__init__(optimizer) - self.T_max = T_max - self.eta_min = eta_min - self.eta_max = optimizer.lr.numpy()[0] + def __init__(self, optimizer: Optimizer, T_max: int, eta_min=0): + super().__init__(optimizer) + self.T_max = T_max + self.eta_min = eta_min + self.eta_max = optimizer.lr.numpy()[0] + + def get_lr(self) -> Tensor: + return Tensor( + [ + self.eta_min + + 0.5 + * (self.eta_max - self.eta_min) + * (1 + math.cos((self.epoch_counter.numpy()[0] / self.T_max) * math.pi)) + ], + device=self.optimizer.device, + ) - def get_lr(self) -> Tensor: - return Tensor([self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))], device=self.optimizer.device) class OneCycleLR(LR_Scheduler): - def __init__(self, optimizer: Optimizer, max_lr: float, div_factor: float, final_div_factor: float, total_steps: int, pct_start: float, - anneal_strategy: str = 'linear', cycle_momentum: bool = False): - self.initial_lr = Tensor([max_lr / div_factor]).contiguous() - self.max_lr = Tensor([max_lr]).contiguous() - self.min_lr = self.initial_lr/final_div_factor - super().__init__(optimizer) - self.total_steps = total_steps - self.pct_start = pct_start - assert anneal_strategy == 'linear', 'only linear annealing supported' - assert not cycle_momentum, 'cycle momentum not supported' - self.optimizer.lr.assign(self.get_lr()).realize() # update the initial LR + def __init__( + self, + optimizer: Optimizer, + max_lr: float, + div_factor: float, + final_div_factor: float, + total_steps: int, + pct_start: float, + anneal_strategy: str = "linear", + cycle_momentum: bool = False, + ): + self.initial_lr = Tensor([max_lr / div_factor]).contiguous() + self.max_lr = Tensor([max_lr]).contiguous() + self.min_lr = self.initial_lr / final_div_factor + super().__init__(optimizer) + self.total_steps = total_steps + self.pct_start = pct_start + assert anneal_strategy == "linear", "only linear annealing supported" + assert not cycle_momentum, "cycle momentum not supported" + self.optimizer.lr.assign(self.get_lr()).realize() # update the initial LR - @staticmethod - def _annealing_linear(start: Tensor, end: Tensor, pct: Tensor) -> Tensor: return ((end - start) * pct + start) + @staticmethod + def _annealing_linear(start: Tensor, end: Tensor, pct: Tensor) -> Tensor: + return (end - start) * pct + start - def get_lr(self) -> Tensor: - return (self.epoch_counter < self.total_steps*self.pct_start).where( - self._annealing_linear(self.initial_lr, self.max_lr, self.epoch_counter/(self.total_steps*self.pct_start)), - self._annealing_linear(self.max_lr, self.min_lr, (self.epoch_counter-(self.total_steps*self.pct_start))/(self.total_steps*(1-self.pct_start))) - ) + def get_lr(self) -> Tensor: + return (self.epoch_counter < self.total_steps * self.pct_start).where( + self._annealing_linear( + self.initial_lr, + self.max_lr, + self.epoch_counter / (self.total_steps * self.pct_start), + ), + self._annealing_linear( + self.max_lr, + self.min_lr, + (self.epoch_counter - (self.total_steps * self.pct_start)) + / (self.total_steps * (1 - self.pct_start)), + ), + ) diff --git a/extra/models/bert.py b/extra/models/bert.py index dd155126c..441d8be5d 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -5,167 +5,290 @@ from pathlib import Path class BertForQuestionAnswering: - def __init__(self, hidden_size=1024, intermediate_size=4096, max_position_embeddings=512, num_attention_heads=16, num_hidden_layers=24, type_vocab_size=2, vocab_size=30522, attention_probs_dropout_prob=0.1, hidden_dropout_prob=0.1): - self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob) - self.qa_outputs = Linear(hidden_size, 2) + def __init__( + self, + hidden_size=1024, + intermediate_size=4096, + max_position_embeddings=512, + num_attention_heads=16, + num_hidden_layers=24, + type_vocab_size=2, + vocab_size=30522, + attention_probs_dropout_prob=0.1, + hidden_dropout_prob=0.1, + ): + self.bert = Bert( + hidden_size, + intermediate_size, + max_position_embeddings, + num_attention_heads, + num_hidden_layers, + type_vocab_size, + vocab_size, + attention_probs_dropout_prob, + hidden_dropout_prob, + ) + self.qa_outputs = Linear(hidden_size, 2) - def load_from_pretrained(self): - fn = Path(__file__).parents[1] / "weights/bert_for_qa.pt" - fetch("https://zenodo.org/record/3733896/files/model.pytorch?download=1", fn) - fn_vocab = Path(__file__).parents[1] / "weights/bert_vocab.txt" - fetch("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab) + def load_from_pretrained(self): + fn = Path(__file__).parents[1] / "weights/bert_for_qa.pt" + fetch("https://zenodo.org/record/3733896/files/model.pytorch?download=1", fn) + fn_vocab = Path(__file__).parents[1] / "weights/bert_vocab.txt" + fetch("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab) - import torch - with open(fn, "rb") as f: - state_dict = torch.load(f, map_location="cpu") + import torch - for k, v in state_dict.items(): - if "dropout" in k: continue # skip dropout - if "pooler" in k: continue # skip pooler - get_child(self, k).assign(v.numpy()).realize() + with open(fn, "rb") as f: + state_dict = torch.load(f, map_location="cpu") - def __call__(self, input_ids:Tensor, attention_mask:Tensor, token_type_ids:Tensor): - sequence_output = self.bert(input_ids, attention_mask, token_type_ids) - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.chunk(2, dim=-1) - start_logits = start_logits.reshape(-1, 1) - end_logits = end_logits.reshape(-1, 1) + for k, v in state_dict.items(): + if "dropout" in k: + continue # skip dropout + if "pooler" in k: + continue # skip pooler + get_child(self, k).assign(v.numpy()).realize() + + def __call__( + self, input_ids: Tensor, attention_mask: Tensor, token_type_ids: Tensor + ): + sequence_output = self.bert(input_ids, attention_mask, token_type_ids) + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.chunk(2, dim=-1) + start_logits = start_logits.reshape(-1, 1) + end_logits = end_logits.reshape(-1, 1) + + return Tensor.stack([start_logits, end_logits]) - return Tensor.stack([start_logits, end_logits]) class Bert: - def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob): - self.embeddings = BertEmbeddings(hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob) - self.encoder = BertEncoder(hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob) + def __init__( + self, + hidden_size, + intermediate_size, + max_position_embeddings, + num_attention_heads, + num_hidden_layers, + type_vocab_size, + vocab_size, + attention_probs_dropout_prob, + hidden_dropout_prob, + ): + self.embeddings = BertEmbeddings( + hidden_size, + max_position_embeddings, + type_vocab_size, + vocab_size, + hidden_dropout_prob, + ) + self.encoder = BertEncoder( + hidden_size, + intermediate_size, + num_attention_heads, + num_hidden_layers, + attention_probs_dropout_prob, + hidden_dropout_prob, + ) - def __call__(self, input_ids, attention_mask, token_type_ids): - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + def __call__(self, input_ids, attention_mask, token_type_ids): + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - embedding_output = self.embeddings(input_ids, token_type_ids) - encoder_outputs = self.encoder(embedding_output, extended_attention_mask) + embedding_output = self.embeddings(input_ids, token_type_ids) + encoder_outputs = self.encoder(embedding_output, extended_attention_mask) + + return encoder_outputs - return encoder_outputs class BertEmbeddings: - def __init__(self, hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob): - self.word_embeddings = Embedding(vocab_size, hidden_size) - self.position_embeddings = Embedding(max_position_embeddings, hidden_size) - self.token_type_embeddings = Embedding(type_vocab_size, hidden_size) - self.LayerNorm = LayerNorm(hidden_size, eps=1e-12) - self.dropout = hidden_dropout_prob + def __init__( + self, + hidden_size, + max_position_embeddings, + type_vocab_size, + vocab_size, + hidden_dropout_prob, + ): + self.word_embeddings = Embedding(vocab_size, hidden_size) + self.position_embeddings = Embedding(max_position_embeddings, hidden_size) + self.token_type_embeddings = Embedding(type_vocab_size, hidden_size) + self.LayerNorm = LayerNorm(hidden_size, eps=1e-12) + self.dropout = hidden_dropout_prob - def __call__(self, input_ids, token_type_ids): - input_shape = input_ids.shape - seq_length = input_shape[1] + def __call__(self, input_ids, token_type_ids): + input_shape = input_ids.shape + seq_length = input_shape[1] - position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape) - words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) + position_ids = ( + Tensor.arange(seq_length, requires_grad=False) + .unsqueeze(0) + .expand(*input_shape) + ) + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = embeddings.dropout(self.dropout) + return embeddings - embeddings = words_embeddings + position_embeddings + token_type_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = embeddings.dropout(self.dropout) - return embeddings class BertEncoder: - def __init__(self, hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob): - self.layer = [BertLayer(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob) for _ in range(num_hidden_layers)] + def __init__( + self, + hidden_size, + intermediate_size, + num_attention_heads, + num_hidden_layers, + attention_probs_dropout_prob, + hidden_dropout_prob, + ): + self.layer = [ + BertLayer( + hidden_size, + intermediate_size, + num_attention_heads, + attention_probs_dropout_prob, + hidden_dropout_prob, + ) + for _ in range(num_hidden_layers) + ] + + def __call__(self, hidden_states, attention_mask): + for layer in self.layer: + hidden_states = layer(hidden_states, attention_mask) + return hidden_states - def __call__(self, hidden_states, attention_mask): - for layer in self.layer: - hidden_states = layer(hidden_states, attention_mask) - return hidden_states class BertLayer: - def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): - self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob) - self.intermediate = BertIntermediate(hidden_size, intermediate_size) - self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) + def __init__( + self, + hidden_size, + intermediate_size, + num_attention_heads, + attention_probs_dropout_prob, + hidden_dropout_prob, + ): + self.attention = BertAttention( + hidden_size, + num_attention_heads, + attention_probs_dropout_prob, + hidden_dropout_prob, + ) + self.intermediate = BertIntermediate(hidden_size, intermediate_size) + self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) + + def __call__(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output - def __call__(self, hidden_states, attention_mask): - attention_output = self.attention(hidden_states, attention_mask) - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - return layer_output class BertOutput: - def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): - self.dense = Linear(intermediate_size, hidden_size) - self.LayerNorm = LayerNorm(hidden_size, eps=1e-12) - self.dropout = hidden_dropout_prob + def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): + self.dense = Linear(intermediate_size, hidden_size) + self.LayerNorm = LayerNorm(hidden_size, eps=1e-12) + self.dropout = hidden_dropout_prob + + def __call__(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = hidden_states.dropout(self.dropout) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states - def __call__(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = hidden_states.dropout(self.dropout) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states # approximation of the error function def erf(x): - t = (1 + 0.3275911 * x.abs()).reciprocal() - return x.sign() * (1 - ((((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736) * t + 0.254829592) * t * (-(x.square())).exp()) + t = (1 + 0.3275911 * x.abs()).reciprocal() + return x.sign() * ( + 1 + - ( + (((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736) + * t + + 0.254829592 + ) + * t + * (-(x.square())).exp() + ) + class BertIntermediate: - def __init__(self, hidden_size, intermediate_size): - self.dense = Linear(hidden_size, intermediate_size) + def __init__(self, hidden_size, intermediate_size): + self.dense = Linear(hidden_size, intermediate_size) + + def __call__(self, hidden_states): + x = self.dense(hidden_states) + # tinygrad gelu is openai gelu but we need the original bert gelu + return x * 0.5 * (1.0 + erf(x / 1.41421)) - def __call__(self, hidden_states): - x = self.dense(hidden_states) - # tinygrad gelu is openai gelu but we need the original bert gelu - return x * 0.5 * (1.0 + erf(x / 1.41421)) class BertAttention: - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): - self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) - self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) + def __init__( + self, + hidden_size, + num_attention_heads, + attention_probs_dropout_prob, + hidden_dropout_prob, + ): + self.self = BertSelfAttention( + hidden_size, num_attention_heads, attention_probs_dropout_prob + ) + self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) + + def __call__(self, hidden_states, attention_mask): + self_output = self.self(hidden_states, attention_mask) + attention_output = self.output(self_output, hidden_states) + return attention_output - def __call__(self, hidden_states, attention_mask): - self_output = self.self(hidden_states, attention_mask) - attention_output = self.output(self_output, hidden_states) - return attention_output class BertSelfAttention: - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): - self.num_attention_heads = num_attention_heads - self.attention_head_size = int(hidden_size / num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query = Linear(hidden_size, self.all_head_size) - self.key = Linear(hidden_size, self.all_head_size) - self.value = Linear(hidden_size, self.all_head_size) + self.query = Linear(hidden_size, self.all_head_size) + self.key = Linear(hidden_size, self.all_head_size) + self.value = Linear(hidden_size, self.all_head_size) - self.dropout = attention_probs_dropout_prob + self.dropout = attention_probs_dropout_prob - def __call__(self, hidden_states, attention_mask): - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) + def __call__(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) - context_layer = Tensor.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask, self.dropout) + context_layer = Tensor.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask, self.dropout + ) - context_layer = context_layer.transpose(1, 2) - context_layer = context_layer.reshape(context_layer.shape[0], context_layer.shape[1], self.all_head_size) + context_layer = context_layer.transpose(1, 2) + context_layer = context_layer.reshape( + context_layer.shape[0], context_layer.shape[1], self.all_head_size + ) - return context_layer + return context_layer + + def transpose_for_scores(self, x): + x = x.reshape( + x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size + ) + return x.transpose(1, 2) - def transpose_for_scores(self, x): - x = x.reshape(x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size) - return x.transpose(1, 2) class BertSelfOutput: - def __init__(self, hidden_size, hidden_dropout_prob): - self.dense = Linear(hidden_size, hidden_size) - self.LayerNorm = LayerNorm(hidden_size, eps=1e-12) - self.dropout = hidden_dropout_prob + def __init__(self, hidden_size, hidden_dropout_prob): + self.dense = Linear(hidden_size, hidden_size) + self.LayerNorm = LayerNorm(hidden_size, eps=1e-12) + self.dropout = hidden_dropout_prob - def __call__(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = hidden_states.dropout(self.dropout) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states + def __call__(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = hidden_states.dropout(self.dropout) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states diff --git a/extra/models/convnext.py b/extra/models/convnext.py index 591112ad1..a24e6afb0 100644 --- a/extra/models/convnext.py +++ b/extra/models/convnext.py @@ -2,64 +2,99 @@ from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear from tinygrad.helpers import fetch, get_child -class Block: - def __init__(self, dim): - self.dwconv = Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) - self.norm = LayerNorm(dim, eps=1e-6) - self.pwconv1 = Linear(dim, 4 * dim) - self.pwconv2 = Linear(4 * dim, dim) - self.gamma = Tensor.ones(dim) - def __call__(self, x:Tensor): - return x + x.sequential([ - self.dwconv, lambda x: x.permute(0, 2, 3, 1), self.norm, - self.pwconv1, Tensor.gelu, self.pwconv2, lambda x: (self.gamma * x).permute(0, 3, 1, 2) - ]) +class Block: + def __init__(self, dim): + self.dwconv = Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = Linear(dim, 4 * dim) + self.pwconv2 = Linear(4 * dim, dim) + self.gamma = Tensor.ones(dim) + + def __call__(self, x: Tensor): + return x + x.sequential( + [ + self.dwconv, + lambda x: x.permute(0, 2, 3, 1), + self.norm, + self.pwconv1, + Tensor.gelu, + self.pwconv2, + lambda x: (self.gamma * x).permute(0, 3, 1, 2), + ] + ) + class ConvNeXt: - def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]): - self.downsample_layers = [ - [Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)], - *[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)] - ] - self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))] - self.norm = LayerNorm(dims[-1]) - self.head = Linear(dims[-1], num_classes) + def __init__( + self, + in_chans=3, + num_classes=1000, + depths=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + ): + self.downsample_layers = [ + [ + Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm2d(dims[0], eps=1e-6), + ], + *[ + [ + LayerNorm2d(dims[i], eps=1e-6), + Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), + ] + for i in range(len(dims) - 1) + ], + ] + self.stages = [ + [Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims)) + ] + self.norm = LayerNorm(dims[-1]) + self.head = Linear(dims[-1], num_classes) + + def __call__(self, x: Tensor): + for downsample, stage in zip(self.downsample_layers, self.stages): + x = x.sequential(downsample).sequential(stage) + return x.mean([-2, -1]).sequential([self.norm, self.head]) - def __call__(self, x:Tensor): - for downsample, stage in zip(self.downsample_layers, self.stages): - x = x.sequential(downsample).sequential(stage) - return x.mean([-2, -1]).sequential([self.norm, self.head]) # *** model definition is done *** versions = { - "tiny": {"depths": [3, 3, 9, 3], "dims": [96, 192, 384, 768]}, - "small": {"depths": [3, 3, 27, 3], "dims": [96, 192, 384, 768]}, - "base": {"depths": [3, 3, 9, 3], "dims": [128, 256, 512, 1024]}, - "large": {"depths": [3, 3, 27, 3], "dims": [192, 384, 768, 1536]}, - "xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]} + "tiny": {"depths": [3, 3, 9, 3], "dims": [96, 192, 384, 768]}, + "small": {"depths": [3, 3, 27, 3], "dims": [96, 192, 384, 768]}, + "base": {"depths": [3, 3, 9, 3], "dims": [128, 256, 512, 1024]}, + "large": {"depths": [3, 3, 27, 3], "dims": [192, 384, 768, 1536]}, + "xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]}, } + def get_model(version, load_weights=False): - model = ConvNeXt(**versions[version]) - if load_weights: - from tinygrad.nn.state import torch_load - weights = torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model'] - for k,v in weights.items(): - mv = get_child(model, k) - mv.assign(v.reshape(mv.shape).to(mv.device)).realize() - return model + model = ConvNeXt(**versions[version]) + if load_weights: + from tinygrad.nn.state import torch_load + + weights = torch_load( + fetch( + f"https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth" + ) + )["model"] + for k, v in weights.items(): + mv = get_child(model, k) + mv.assign(v.reshape(mv.shape).to(mv.device)).realize() + return model + if __name__ == "__main__": - model = get_model("tiny", True) + model = get_model("tiny", True) - # load image - from test.models.test_efficientnet import chicken_img, preprocess, _LABELS - img = Tensor(preprocess(chicken_img)) + # load image + from test.models.test_efficientnet import chicken_img, preprocess, _LABELS - Tensor.training = False - Tensor.no_grad = True + img = Tensor(preprocess(chicken_img)) - out = model(img).numpy() - print(_LABELS[out.argmax()]) + Tensor.training = False + Tensor.no_grad = True + + out = model(img).numpy() + print(_LABELS[out.argmax()]) diff --git a/extra/models/efficientnet.py b/extra/models/efficientnet.py index a8c81fe96..817128b79 100644 --- a/extra/models/efficientnet.py +++ b/extra/models/efficientnet.py @@ -4,161 +4,218 @@ from tinygrad.nn import BatchNorm2d from tinygrad.helpers import get_child, fetch from tinygrad.nn.state import torch_load + class MBConvBlock: - def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True): - oup = expand_ratio * input_filters - if expand_ratio != 1: - self._expand_conv = Tensor.glorot_uniform(oup, input_filters, 1, 1) - self._bn0 = BatchNorm2d(oup, track_running_stats=track_running_stats) - else: - self._expand_conv = None + def __init__( + self, + kernel_size, + strides, + expand_ratio, + input_filters, + output_filters, + se_ratio, + has_se, + track_running_stats=True, + ): + oup = expand_ratio * input_filters + if expand_ratio != 1: + self._expand_conv = Tensor.glorot_uniform(oup, input_filters, 1, 1) + self._bn0 = BatchNorm2d(oup, track_running_stats=track_running_stats) + else: + self._expand_conv = None - self.strides = strides - if strides == (2,2): - self.pad = [(kernel_size-1)//2-1, (kernel_size-1)//2]*2 - else: - self.pad = [(kernel_size-1)//2]*4 + self.strides = strides + if strides == (2, 2): + self.pad = [(kernel_size - 1) // 2 - 1, (kernel_size - 1) // 2] * 2 + else: + self.pad = [(kernel_size - 1) // 2] * 4 - self._depthwise_conv = Tensor.glorot_uniform(oup, 1, kernel_size, kernel_size) - self._bn1 = BatchNorm2d(oup, track_running_stats=track_running_stats) + self._depthwise_conv = Tensor.glorot_uniform(oup, 1, kernel_size, kernel_size) + self._bn1 = BatchNorm2d(oup, track_running_stats=track_running_stats) - self.has_se = has_se - if self.has_se: - num_squeezed_channels = max(1, int(input_filters * se_ratio)) - self._se_reduce = Tensor.glorot_uniform(num_squeezed_channels, oup, 1, 1) - self._se_reduce_bias = Tensor.zeros(num_squeezed_channels) - self._se_expand = Tensor.glorot_uniform(oup, num_squeezed_channels, 1, 1) - self._se_expand_bias = Tensor.zeros(oup) + self.has_se = has_se + if self.has_se: + num_squeezed_channels = max(1, int(input_filters * se_ratio)) + self._se_reduce = Tensor.glorot_uniform(num_squeezed_channels, oup, 1, 1) + self._se_reduce_bias = Tensor.zeros(num_squeezed_channels) + self._se_expand = Tensor.glorot_uniform(oup, num_squeezed_channels, 1, 1) + self._se_expand_bias = Tensor.zeros(oup) - self._project_conv = Tensor.glorot_uniform(output_filters, oup, 1, 1) - self._bn2 = BatchNorm2d(output_filters, track_running_stats=track_running_stats) + self._project_conv = Tensor.glorot_uniform(output_filters, oup, 1, 1) + self._bn2 = BatchNorm2d(output_filters, track_running_stats=track_running_stats) - def __call__(self, inputs): - x = inputs - if self._expand_conv: - x = self._bn0(x.conv2d(self._expand_conv)).swish() - x = x.conv2d(self._depthwise_conv, padding=self.pad, stride=self.strides, groups=self._depthwise_conv.shape[0]) - x = self._bn1(x).swish() + def __call__(self, inputs): + x = inputs + if self._expand_conv: + x = self._bn0(x.conv2d(self._expand_conv)).swish() + x = x.conv2d( + self._depthwise_conv, + padding=self.pad, + stride=self.strides, + groups=self._depthwise_conv.shape[0], + ) + x = self._bn1(x).swish() - if self.has_se: - x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4]) - x_squeezed = x_squeezed.conv2d(self._se_reduce, self._se_reduce_bias).swish() - x_squeezed = x_squeezed.conv2d(self._se_expand, self._se_expand_bias) - x = x.mul(x_squeezed.sigmoid()) + if self.has_se: + x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4]) + x_squeezed = x_squeezed.conv2d( + self._se_reduce, self._se_reduce_bias + ).swish() + x_squeezed = x_squeezed.conv2d(self._se_expand, self._se_expand_bias) + x = x.mul(x_squeezed.sigmoid()) + + x = self._bn2(x.conv2d(self._project_conv)) + if x.shape == inputs.shape: + x = x.add(inputs) + return x - x = self._bn2(x.conv2d(self._project_conv)) - if x.shape == inputs.shape: - x = x.add(inputs) - return x class EfficientNet: - def __init__(self, number=0, classes=1000, has_se=True, track_running_stats=True, input_channels=3, has_fc_output=True): - self.number = number - global_params = [ - # width, depth - (1.0, 1.0), # b0 - (1.0, 1.1), # b1 - (1.1, 1.2), # b2 - (1.2, 1.4), # b3 - (1.4, 1.8), # b4 - (1.6, 2.2), # b5 - (1.8, 2.6), # b6 - (2.0, 3.1), # b7 - (2.2, 3.6), # b8 - (4.3, 5.3), # l2 - ][max(number,0)] + def __init__( + self, + number=0, + classes=1000, + has_se=True, + track_running_stats=True, + input_channels=3, + has_fc_output=True, + ): + self.number = number + global_params = [ + # width, depth + (1.0, 1.0), # b0 + (1.0, 1.1), # b1 + (1.1, 1.2), # b2 + (1.2, 1.4), # b3 + (1.4, 1.8), # b4 + (1.6, 2.2), # b5 + (1.8, 2.6), # b6 + (2.0, 3.1), # b7 + (2.2, 3.6), # b8 + (4.3, 5.3), # l2 + ][max(number, 0)] - def round_filters(filters): - multiplier = global_params[0] - divisor = 8 - filters *= multiplier - new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor) - if new_filters < 0.9 * filters: # prevent rounding by more than 10% - new_filters += divisor - return int(new_filters) + def round_filters(filters): + multiplier = global_params[0] + divisor = 8 + filters *= multiplier + new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) - def round_repeats(repeats): - return int(math.ceil(global_params[1] * repeats)) + def round_repeats(repeats): + return int(math.ceil(global_params[1] * repeats)) - out_channels = round_filters(32) - self._conv_stem = Tensor.glorot_uniform(out_channels, input_channels, 3, 3) - self._bn0 = BatchNorm2d(out_channels, track_running_stats=track_running_stats) - blocks_args = [ - [1, 3, (1,1), 1, 32, 16, 0.25], - [2, 3, (2,2), 6, 16, 24, 0.25], - [2, 5, (2,2), 6, 24, 40, 0.25], - [3, 3, (2,2), 6, 40, 80, 0.25], - [3, 5, (1,1), 6, 80, 112, 0.25], - [4, 5, (2,2), 6, 112, 192, 0.25], - [1, 3, (1,1), 6, 192, 320, 0.25], - ] + out_channels = round_filters(32) + self._conv_stem = Tensor.glorot_uniform(out_channels, input_channels, 3, 3) + self._bn0 = BatchNorm2d(out_channels, track_running_stats=track_running_stats) + blocks_args = [ + [1, 3, (1, 1), 1, 32, 16, 0.25], + [2, 3, (2, 2), 6, 16, 24, 0.25], + [2, 5, (2, 2), 6, 24, 40, 0.25], + [3, 3, (2, 2), 6, 40, 80, 0.25], + [3, 5, (1, 1), 6, 80, 112, 0.25], + [4, 5, (2, 2), 6, 112, 192, 0.25], + [1, 3, (1, 1), 6, 192, 320, 0.25], + ] - if self.number == -1: - blocks_args = [ - [1, 3, (2,2), 1, 32, 40, 0.25], - [1, 3, (2,2), 1, 40, 80, 0.25], - [1, 3, (2,2), 1, 80, 192, 0.25], - [1, 3, (2,2), 1, 192, 320, 0.25], - ] - elif self.number == -2: - blocks_args = [ - [1, 9, (8,8), 1, 32, 320, 0.25], - ] + if self.number == -1: + blocks_args = [ + [1, 3, (2, 2), 1, 32, 40, 0.25], + [1, 3, (2, 2), 1, 40, 80, 0.25], + [1, 3, (2, 2), 1, 80, 192, 0.25], + [1, 3, (2, 2), 1, 192, 320, 0.25], + ] + elif self.number == -2: + blocks_args = [ + [1, 9, (8, 8), 1, 32, 320, 0.25], + ] - self._blocks = [] - for num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio in blocks_args: - input_filters, output_filters = round_filters(input_filters), round_filters(output_filters) - for n in range(round_repeats(num_repeats)): - self._blocks.append(MBConvBlock(kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se=has_se, track_running_stats=track_running_stats)) - input_filters = output_filters - strides = (1,1) + self._blocks = [] + for ( + num_repeats, + kernel_size, + strides, + expand_ratio, + input_filters, + output_filters, + se_ratio, + ) in blocks_args: + input_filters, output_filters = round_filters(input_filters), round_filters( + output_filters + ) + for n in range(round_repeats(num_repeats)): + self._blocks.append( + MBConvBlock( + kernel_size, + strides, + expand_ratio, + input_filters, + output_filters, + se_ratio, + has_se=has_se, + track_running_stats=track_running_stats, + ) + ) + input_filters = output_filters + strides = (1, 1) - in_channels = round_filters(320) - out_channels = round_filters(1280) - self._conv_head = Tensor.glorot_uniform(out_channels, in_channels, 1, 1) - self._bn1 = BatchNorm2d(out_channels, track_running_stats=track_running_stats) - if has_fc_output: - self._fc = Tensor.glorot_uniform(out_channels, classes) - self._fc_bias = Tensor.zeros(classes) - else: - self._fc = None + in_channels = round_filters(320) + out_channels = round_filters(1280) + self._conv_head = Tensor.glorot_uniform(out_channels, in_channels, 1, 1) + self._bn1 = BatchNorm2d(out_channels, track_running_stats=track_running_stats) + if has_fc_output: + self._fc = Tensor.glorot_uniform(out_channels, classes) + self._fc_bias = Tensor.zeros(classes) + else: + self._fc = None - def forward(self, x): - x = self._bn0(x.conv2d(self._conv_stem, padding=(0,1,0,1), stride=2)).swish() - x = x.sequential(self._blocks) - x = self._bn1(x.conv2d(self._conv_head)).swish() - x = x.avg_pool2d(kernel_size=x.shape[2:4]) - x = x.reshape(shape=(-1, x.shape[1])) - return x.linear(self._fc, self._fc_bias) if self._fc is not None else x + def forward(self, x): + x = self._bn0(x.conv2d(self._conv_stem, padding=(0, 1, 0, 1), stride=2)).swish() + x = x.sequential(self._blocks) + x = self._bn1(x.conv2d(self._conv_head)).swish() + x = x.avg_pool2d(kernel_size=x.shape[2:4]) + x = x.reshape(shape=(-1, x.shape[1])) + return x.linear(self._fc, self._fc_bias) if self._fc is not None else x - def load_from_pretrained(self): - model_urls = { - 0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth", - 1: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth", - 2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth", - 3: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth", - 4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth", - 5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth", - 6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth", - 7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth" - } + def load_from_pretrained(self): + model_urls = { + 0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth", + 1: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth", + 2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth", + 3: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth", + 4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth", + 5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth", + 6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth", + 7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth", + } - b0 = torch_load(fetch(model_urls[self.number])) - for k,v in b0.items(): - if k.endswith("num_batches_tracked"): continue - for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']: - if cat in k: - k = k.replace('.bias', '_bias') - k = k.replace('.weight', '') + b0 = torch_load(fetch(model_urls[self.number])) + for k, v in b0.items(): + if k.endswith("num_batches_tracked"): + continue + for cat in [ + "_conv_head", + "_conv_stem", + "_depthwise_conv", + "_expand_conv", + "_fc", + "_project_conv", + "_se_reduce", + "_se_expand", + ]: + if cat in k: + k = k.replace(".bias", "_bias") + k = k.replace(".weight", "") - #print(k, v.shape) - mv = get_child(self, k) - vnp = v #.astype(np.float32) - vnp = vnp if k != '_fc' else vnp.cpu().T - #vnp = vnp if vnp.shape != () else np.array([vnp]) - - if mv.shape == vnp.shape: - mv.assign(vnp.to(mv.device)) - else: - print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape)) + # print(k, v.shape) + mv = get_child(self, k) + vnp = v # .astype(np.float32) + vnp = vnp if k != "_fc" else vnp.cpu().T + # vnp = vnp if vnp.shape != () else np.array([vnp]) + if mv.shape == vnp.shape: + mv.assign(vnp.to(mv.device)) + else: + print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape)) diff --git a/extra/models/llama.py b/extra/models/llama.py index d8604c0e1..1af60a961 100644 --- a/extra/models/llama.py +++ b/extra/models/llama.py @@ -2,151 +2,275 @@ from typing import Tuple, Union, Optional, Dict from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device from tinygrad.helpers import getenv + # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor: - freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim)) - freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0) - return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2) + freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[: (dim // 2)] / dim)) + freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0) + return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape( + 1, end, 1, dim // 2, 2 + ) + # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc) def complex_mult(A, c, d): - a,b = A[:, :, :, :, 0:1], A[:, :, :, :, 1:2] - ro = a*c - b*d - co = a*d + b*c - return ro.cat(co, dim=-1) + a, b = A[:, :, :, :, 0:1], A[:, :, :, :, 1:2] + ro = a * c - b * d + co = a * d + b * c + return ro.cat(co, dim=-1) + def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]: - assert freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}" - xq = xq.reshape(*xq.shape[0:-1], -1, 2) - xk = xk.reshape(*xk.shape[0:-1], -1, 2) - assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5 - c, d = freqs_cis[:, :xq.shape[1], :, :, 0:1], freqs_cis[:, :xq.shape[1], :, :, 1:2] - xq_out = complex_mult(xq, c, d) - xk_out = complex_mult(xk, c, d) - return xq_out.flatten(3), xk_out.flatten(3) + assert ( + freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1] + ), f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}" + xq = xq.reshape(*xq.shape[0:-1], -1, 2) + xk = xk.reshape(*xk.shape[0:-1], -1, 2) + assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5 + c, d = ( + freqs_cis[:, : xq.shape[1], :, :, 0:1], + freqs_cis[:, : xq.shape[1], :, :, 1:2], + ) + xq_out = complex_mult(xq, c, d) + xk_out = complex_mult(xk, c, d) + return xq_out.flatten(3), xk_out.flatten(3) + + +def repeat_kv(x: Tensor, n_rep: int) -> Tensor: + bs, seqlen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x.reshape(bs, seqlen, n_kv_heads, 1, head_dim) + .expand(bs, seqlen, n_kv_heads, n_rep, head_dim) + .reshape(bs, seqlen, n_kv_heads * n_rep, head_dim) + ) -def repeat_kv(x:Tensor, n_rep:int) -> Tensor: - bs, seqlen, n_kv_heads, head_dim = x.shape - if n_rep == 1: return x - return x.reshape(bs, seqlen, n_kv_heads, 1, head_dim).expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim) class RMSNorm: - def __init__(self, dim, eps=1e-6): - self.eps = eps - self.weight = Tensor.ones(dim) + def __init__(self, dim, eps=1e-6): + self.eps = eps + self.weight = Tensor.ones(dim) + + def __call__(self, x: Tensor): + # TODO: convert to float? + return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight - def __call__(self, x:Tensor): - # TODO: convert to float? - return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight class Attention: - def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear): - self.n_heads = n_heads - self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1] - self.head_dim = dim // n_heads - self.n_rep = self.n_heads // self.n_kv_heads - self.max_context = max_context + def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear): + self.n_heads = n_heads + self.n_kv_heads = ( + n_kv_heads if n_kv_heads is not None else n_heads + ) # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1] + self.head_dim = dim // n_heads + self.n_rep = self.n_heads // self.n_kv_heads + self.max_context = max_context - self.wq = linear(dim, self.n_heads * self.head_dim, bias=False) - self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = linear(self.n_heads * self.head_dim, dim, bias=False) + self.wq = linear(dim, self.n_heads * self.head_dim, bias=False) + self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = linear(self.n_heads * self.head_dim, dim, bias=False) - def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor: - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim) - xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim) - xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim) - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - bsz, seqlen, n_heads, head_dim = xq.shape + def __call__( + self, + x: Tensor, + start_pos: Union[Variable, int], + freqs_cis: Tensor, + mask: Optional[Tensor], + ) -> Tensor: + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim) + xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim) + xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + bsz, seqlen, n_heads, head_dim = xq.shape - # create kv cache - if not hasattr(self, "cache_k"): - self.cache_k, self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim) + # create kv cache + if not hasattr(self, "cache_k"): + self.cache_k, self.cache_v = Tensor.zeros( + bsz, self.max_context, self.n_kv_heads, self.head_dim + ), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim) - keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1) - values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1) + keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1) + values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1) - # update the cache - self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize() - self.cache_v.assign(values.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize() + # update the cache + self.cache_k.assign( + keys.pad( + (None, (0, self.max_context - start_pos - seqlen), None, None) + ).contiguous() + ).realize() + self.cache_v.assign( + values.pad( + (None, (0, self.max_context - start_pos - seqlen), None, None) + ).contiguous() + ).realize() - keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep) + keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep) + + xq, keys, values = ( + xq.transpose(1, 2), + keys.transpose(1, 2), + values.transpose(1, 2), + ) + attn = ( + xq.scaled_dot_product_attention(keys, values, mask) + .transpose(1, 2) + .reshape(bsz, seqlen, -1) + ) + return self.wo(attn) - xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2) - attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1) - return self.wo(attn) class FeedForward: - def __init__(self, dim, hidden_dim, linear=nn.Linear): - self.w1 = linear(dim, hidden_dim, bias=False) - self.w2 = linear(hidden_dim, dim, bias=False) - self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit + def __init__(self, dim, hidden_dim, linear=nn.Linear): + self.w1 = linear(dim, hidden_dim, bias=False) + self.w2 = linear(hidden_dim, dim, bias=False) + self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit + + def __call__(self, x: Tensor) -> Tensor: + return self.w2( + self.w1(x).silu() * self.w3(x) + ) # SwiGLU [arxiv/2002.05202, eq (5)] - def __call__(self, x:Tensor) -> Tensor: - return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)] class TransformerBlock: - def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear): - self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear) - self.feed_forward = FeedForward(dim, hidden_dim, linear) - self.attention_norm = RMSNorm(dim, norm_eps) - self.ffn_norm = RMSNorm(dim, norm_eps) + def __init__( + self, + dim: int, + hidden_dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + max_context: int, + linear=nn.Linear, + ): + self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear) + self.feed_forward = FeedForward(dim, hidden_dim, linear) + self.attention_norm = RMSNorm(dim, norm_eps) + self.ffn_norm = RMSNorm(dim, norm_eps) + + def __call__( + self, + x: Tensor, + start_pos: Union[Variable, int], + freqs_cis: Tensor, + mask: Optional[Tensor], + ): + h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) + return (h + self.feed_forward(self.ffn_norm(h))).realize() - def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]): - h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) - return (h + self.feed_forward(self.ffn_norm(h))).realize() class Transformer: - def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True): - self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear) for _ in range(n_layers)] - self.norm = RMSNorm(dim, norm_eps) - self.tok_embeddings = nn.Embedding(vocab_size, dim) - self.output = linear(dim, vocab_size, bias=False) - self.max_context = max_context - self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta) - self.forward_jit = TinyJit(self.forward) if jit else None + def __init__( + self, + dim: int, + hidden_dim: int, + n_heads: int, + n_layers: int, + norm_eps: float, + vocab_size, + linear=nn.Linear, + n_kv_heads=None, + rope_theta=10000, + max_context=1024, + jit=True, + ): + self.layers = [ + TransformerBlock( + dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear + ) + for _ in range(n_layers) + ] + self.norm = RMSNorm(dim, norm_eps) + self.tok_embeddings = nn.Embedding(vocab_size, dim) + self.output = linear(dim, vocab_size, bias=False) + self.max_context = max_context + self.freqs_cis = precompute_freqs_cis( + dim // n_heads, self.max_context * 2, rope_theta + ) + self.forward_jit = TinyJit(self.forward) if jit else None - def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float=0.0): - _bsz, seqlen = tokens.shape - freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None)) - mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None + def forward( + self, tokens: Tensor, start_pos: Union[Variable, int], temperature: float = 0.0 + ): + _bsz, seqlen = tokens.shape + freqs_cis = self.freqs_cis.shrink( + (None, (start_pos, start_pos + seqlen), None, None, None) + ) + mask = ( + Tensor.full( + (1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32 + ) + .triu(start_pos + 1) + .realize() + if seqlen > 1 + else None + ) - h = self.tok_embeddings(tokens) - for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask) - logits = self.output(self.norm(h)) - return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize() + h = self.tok_embeddings(tokens) + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + logits = self.output(self.norm(h)) + return (logits[:, -1, :] / (temperature + 1e-10)).softmax().flatten().realize() + + def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0): + # TODO: better way to handle the first call v.s. the rest? + if tokens.shape[0:2] == (1, 1) and self.forward_jit and getenv("JIT", 1): + assert start_pos > 0 + return self.forward_jit( + tokens, + Variable("start_pos", 1, self.max_context).bind(start_pos), + temperature, + ) + return self.forward(tokens, start_pos, temperature) - def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0): - # TODO: better way to handle the first call v.s. the rest? - if tokens.shape[0:2] == (1,1) and self.forward_jit and getenv("JIT", 1): - assert start_pos > 0 - return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature) - return self.forward(tokens, start_pos, temperature) # *** helpers *** -def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int): - def permute(v: Tensor, n_heads: int): - return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2]) - keymap = { - "model.embed_tokens.weight": "tok_embeddings.weight", - **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))}, - **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))}, - **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))}, - **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))}, - "model.norm.weight": "norm.weight", - "lm_head.weight": "output.weight", - } - sd = {} - for k, v in weights.items(): - if ".rotary_emb." in k: continue - v = v.to(Device.DEFAULT) - if "model.layers" in k: - if "q_proj" in k: - v = permute(v, n_heads) - elif "k_proj" in k: - v = permute(v, n_kv_heads) - sd[keymap[k]] = v - return sd \ No newline at end of file +def convert_from_huggingface( + weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int +): + def permute(v: Tensor, n_heads: int): + return ( + v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]) + .transpose(1, 2) + .reshape(*v.shape[:2]) + ) + + keymap = { + "model.embed_tokens.weight": "tok_embeddings.weight", + **{ + f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" + for l in range(len(model.layers)) + }, + **{ + f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" + for x in ["q", "k", "v", "o"] + for l in range(len(model.layers)) + }, + **{ + f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" + for l in range(len(model.layers)) + }, + **{ + f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" + for x, y in {"gate": "1", "down": "2", "up": "3"}.items() + for l in range(len(model.layers)) + }, + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + sd = {} + for k, v in weights.items(): + if ".rotary_emb." in k: + continue + v = v.to(Device.DEFAULT) + if "model.layers" in k: + if "q_proj" in k: + v = permute(v, n_heads) + elif "k_proj" in k: + v = permute(v, n_kv_heads) + sd[keymap[k]] = v + return sd diff --git a/extra/models/mask_rcnn.py b/extra/models/mask_rcnn.py index e352ad758..c4fd7959b 100644 --- a/extra/models/mask_rcnn.py +++ b/extra/models/mask_rcnn.py @@ -10,94 +10,126 @@ from tinygrad.nn.state import torch_load from extra.models.resnet import ResNet from extra.models.retinanet import nms as _box_nms -USE_NP_GATHER = os.getenv('FULL_TINYGRAD', '0') == '0' +USE_NP_GATHER = os.getenv("FULL_TINYGRAD", "0") == "0" + def rint(tensor): - x = (tensor*2).cast(dtypes.int32).contiguous().cast(dtypes.float32)/2 - return (x<0).where(x.floor(), x.ceil()) + x = (tensor * 2).cast(dtypes.int32).contiguous().cast(dtypes.float32) / 2 + return (x < 0).where(x.floor(), x.ceil()) + def nearest_interpolate(tensor, scale_factor): - bs, c, py, px = tensor.shape - return tensor.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, scale_factor, px, scale_factor).reshape(bs, c, py * scale_factor, px * scale_factor) + bs, c, py, px = tensor.shape + return ( + tensor.reshape(bs, c, py, 1, px, 1) + .expand(bs, c, py, scale_factor, px, scale_factor) + .reshape(bs, c, py * scale_factor, px * scale_factor) + ) + def meshgrid(x, y): - grid_x = Tensor.cat(*[x[idx:idx+1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])]) - grid_y = Tensor.cat(*[y.unsqueeze(0)]*x.shape[0]) - return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1) + grid_x = Tensor.cat( + *[x[idx : idx + 1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])] + ) + grid_y = Tensor.cat(*[y.unsqueeze(0)] * x.shape[0]) + return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1) + def topk(input_, k, dim=-1, largest=True, sorted=False): - k = min(k, input_.shape[dim]-1) - input_ = input_.numpy() - if largest: input_ *= -1 - ind = np.argpartition(input_, k, axis=dim) - if largest: input_ *= -1 - ind = np.take(ind, np.arange(k), axis=dim) # k non-sorted indices - input_ = np.take_along_axis(input_, ind, axis=dim) # k non-sorted values - if not sorted: return Tensor(input_), ind - if largest: input_ *= -1 - ind_part = np.argsort(input_, axis=dim) - ind = np.take_along_axis(ind, ind_part, axis=dim) - if largest: input_ *= -1 - val = np.take_along_axis(input_, ind_part, axis=dim) - return Tensor(val), ind + k = min(k, input_.shape[dim] - 1) + input_ = input_.numpy() + if largest: + input_ *= -1 + ind = np.argpartition(input_, k, axis=dim) + if largest: + input_ *= -1 + ind = np.take(ind, np.arange(k), axis=dim) # k non-sorted indices + input_ = np.take_along_axis(input_, ind, axis=dim) # k non-sorted values + if not sorted: + return Tensor(input_), ind + if largest: + input_ *= -1 + ind_part = np.argsort(input_, axis=dim) + ind = np.take_along_axis(ind, ind_part, axis=dim) + if largest: + input_ *= -1 + val = np.take_along_axis(input_, ind_part, axis=dim) + return Tensor(val), ind + # This is very slow for large arrays, or indices def _gather(array, indices): - indices = indices.float().to(array.device) - reshape_arg = [1]*array.ndim + [array.shape[-1]] - return Tensor.where( - indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1]) == Tensor.arange(array.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, array.shape[-1]), - array, 0, - ).sum(indices.ndim) + indices = indices.float().to(array.device) + reshape_arg = [1] * array.ndim + [array.shape[-1]] + return Tensor.where( + indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1]) + == Tensor.arange(array.shape[-1]) + .reshape(*reshape_arg) + .expand(*indices.shape, array.shape[-1]), + array, + 0, + ).sum(indices.ndim) + # TODO: replace npgather with a faster gather using tinygrad only # NOTE: this blocks the gradient -def npgather(array,indices): - if isinstance(array, Tensor): array = array.numpy() - if isinstance(indices, Tensor): indices = indices.numpy() - if isinstance(indices, list): indices = np.asarray(indices) - return Tensor(array[indices.astype(int)]) +def npgather(array, indices): + if isinstance(array, Tensor): + array = array.numpy() + if isinstance(indices, Tensor): + indices = indices.numpy() + if isinstance(indices, list): + indices = np.asarray(indices) + return Tensor(array[indices.astype(int)]) + def get_strides(shape): - prod = [1] - for idx in range(len(shape)-1, -1, -1): prod.append(prod[-1] * shape[idx]) - # something about ints is broken with gpu, cuda - return Tensor(prod[::-1][1:], dtype=dtypes.int32).unsqueeze(0).cpu() + prod = [1] + for idx in range(len(shape) - 1, -1, -1): + prod.append(prod[-1] * shape[idx]) + # something about ints is broken with gpu, cuda + return Tensor(prod[::-1][1:], dtype=dtypes.int32).unsqueeze(0).cpu() + # with keys as integer array for all axes def tensor_getitem(tensor, *keys): - # something about ints is broken with gpu, cuda - flat_keys = Tensor.stack([key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cpu().cast(dtypes.int32) - strides = get_strides(tensor.shape) - idxs = (flat_keys * strides).sum(1) - gatherer = npgather if USE_NP_GATHER else _gather - return gatherer(tensor.reshape(-1), idxs).reshape(sum(keys).shape) + # something about ints is broken with gpu, cuda + flat_keys = ( + Tensor.stack([key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1) + .cpu() + .cast(dtypes.int32) + ) + strides = get_strides(tensor.shape) + idxs = (flat_keys * strides).sum(1) + gatherer = npgather if USE_NP_GATHER else _gather + return gatherer(tensor.reshape(-1), idxs).reshape(sum(keys).shape) # for gather with indicies only on axis=0 def tensor_gather(tensor, indices): - if not isinstance(indices, Tensor): - indices = Tensor(indices, requires_grad=False) - if len(tensor.shape) > 2: - rem_shape = list(tensor.shape)[1:] - tensor = tensor.reshape(tensor.shape[0], -1) - else: - rem_shape = None - if len(tensor.shape) > 1: - tensor = tensor.T - repeat_arg = [1]*(tensor.ndim-1) + [tensor.shape[-2]] - indices = indices.unsqueeze(indices.ndim).repeat(repeat_arg) - ret = _gather(tensor, indices) - if rem_shape: - ret = ret.reshape([indices.shape[0]] + rem_shape) - else: - ret = _gather(tensor, indices) - del indices - return ret + if not isinstance(indices, Tensor): + indices = Tensor(indices, requires_grad=False) + if len(tensor.shape) > 2: + rem_shape = list(tensor.shape)[1:] + tensor = tensor.reshape(tensor.shape[0], -1) + else: + rem_shape = None + if len(tensor.shape) > 1: + tensor = tensor.T + repeat_arg = [1] * (tensor.ndim - 1) + [tensor.shape[-2]] + indices = indices.unsqueeze(indices.ndim).repeat(repeat_arg) + ret = _gather(tensor, indices) + if rem_shape: + ret = ret.reshape([indices.shape[0]] + rem_shape) + else: + ret = _gather(tensor, indices) + del indices + return ret class LastLevelMaxPool: - def __call__(self, x): return [Tensor.max_pool2d(x, 1, 2)] + def __call__(self, x): + return [Tensor.max_pool2d(x, 1, 2)] # transpose @@ -105,1167 +137,1229 @@ FLIP_LEFT_RIGHT = 0 FLIP_TOP_BOTTOM = 1 -def permute_and_flatten(layer:Tensor, N, A, C, H, W): - layer = layer.reshape(N, -1, C, H, W) - layer = layer.permute(0, 3, 4, 1, 2) - layer = layer.reshape(N, -1, C) - return layer +def permute_and_flatten(layer: Tensor, N, A, C, H, W): + layer = layer.reshape(N, -1, C, H, W) + layer = layer.permute(0, 3, 4, 1, 2) + layer = layer.reshape(N, -1, C) + return layer class BoxList: - def __init__(self, bbox, image_size, mode="xyxy"): - if not isinstance(bbox, Tensor): - bbox = Tensor(bbox) - if bbox.ndim != 2: - raise ValueError( - "bbox should have 2 dimensions, got {}".format(bbox.ndim) - ) - if bbox.shape[-1] != 4: - raise ValueError( - "last dimenion of bbox should have a " - "size of 4, got {}".format(bbox.shape[-1]) - ) - if mode not in ("xyxy", "xywh"): - raise ValueError("mode should be 'xyxy' or 'xywh'") + def __init__(self, bbox, image_size, mode="xyxy"): + if not isinstance(bbox, Tensor): + bbox = Tensor(bbox) + if bbox.ndim != 2: + raise ValueError("bbox should have 2 dimensions, got {}".format(bbox.ndim)) + if bbox.shape[-1] != 4: + raise ValueError( + "last dimenion of bbox should have a " + "size of 4, got {}".format(bbox.shape[-1]) + ) + if mode not in ("xyxy", "xywh"): + raise ValueError("mode should be 'xyxy' or 'xywh'") - self.bbox = bbox - self.size = image_size # (image_width, image_height) - self.mode = mode - self.extra_fields = {} + self.bbox = bbox + self.size = image_size # (image_width, image_height) + self.mode = mode + self.extra_fields = {} - def __repr__(self): - s = self.__class__.__name__ + "(" - s += "num_boxes={}, ".format(len(self)) - s += "image_width={}, ".format(self.size[0]) - s += "image_height={}, ".format(self.size[1]) - s += "mode={})".format(self.mode) - return s + def __repr__(self): + s = self.__class__.__name__ + "(" + s += "num_boxes={}, ".format(len(self)) + s += "image_width={}, ".format(self.size[0]) + s += "image_height={}, ".format(self.size[1]) + s += "mode={})".format(self.mode) + return s - def area(self): - box = self.bbox - if self.mode == "xyxy": - TO_REMOVE = 1 - area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE) - elif self.mode == "xywh": - area = box[:, 2] * box[:, 3] - return area + def area(self): + box = self.bbox + if self.mode == "xyxy": + TO_REMOVE = 1 + area = (box[:, 2] - box[:, 0] + TO_REMOVE) * ( + box[:, 3] - box[:, 1] + TO_REMOVE + ) + elif self.mode == "xywh": + area = box[:, 2] * box[:, 3] + return area - def add_field(self, field, field_data): - self.extra_fields[field] = field_data + def add_field(self, field, field_data): + self.extra_fields[field] = field_data - def get_field(self, field): - return self.extra_fields[field] + def get_field(self, field): + return self.extra_fields[field] - def has_field(self, field): - return field in self.extra_fields + def has_field(self, field): + return field in self.extra_fields - def fields(self): - return list(self.extra_fields.keys()) + def fields(self): + return list(self.extra_fields.keys()) - def _copy_extra_fields(self, bbox): - for k, v in bbox.extra_fields.items(): - self.extra_fields[k] = v + def _copy_extra_fields(self, bbox): + for k, v in bbox.extra_fields.items(): + self.extra_fields[k] = v - def convert(self, mode): - if mode == self.mode: - return self - xmin, ymin, xmax, ymax = self._split_into_xyxy() - if mode == "xyxy": - bbox = Tensor.cat(*(xmin, ymin, xmax, ymax), dim=-1) - bbox = BoxList(bbox, self.size, mode=mode) - else: - TO_REMOVE = 1 - bbox = Tensor.cat( - *(xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1 - ) - bbox = BoxList(bbox, self.size, mode=mode) - bbox._copy_extra_fields(self) - return bbox + def convert(self, mode): + if mode == self.mode: + return self + xmin, ymin, xmax, ymax = self._split_into_xyxy() + if mode == "xyxy": + bbox = Tensor.cat(*(xmin, ymin, xmax, ymax), dim=-1) + bbox = BoxList(bbox, self.size, mode=mode) + else: + TO_REMOVE = 1 + bbox = Tensor.cat( + *(xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1 + ) + bbox = BoxList(bbox, self.size, mode=mode) + bbox._copy_extra_fields(self) + return bbox - def _split_into_xyxy(self): - if self.mode == "xyxy": - xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1) - return xmin, ymin, xmax, ymax - if self.mode == "xywh": - TO_REMOVE = 1 - xmin, ymin, w, h = self.bbox.chunk(4, dim=-1) - return ( - xmin, - ymin, - xmin + (w - TO_REMOVE).clamp(min=0), - ymin + (h - TO_REMOVE).clamp(min=0), - ) + def _split_into_xyxy(self): + if self.mode == "xyxy": + xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1) + return xmin, ymin, xmax, ymax + if self.mode == "xywh": + TO_REMOVE = 1 + xmin, ymin, w, h = self.bbox.chunk(4, dim=-1) + return ( + xmin, + ymin, + xmin + (w - TO_REMOVE).clamp(min=0), + ymin + (h - TO_REMOVE).clamp(min=0), + ) - def resize(self, size, *args, **kwargs): - ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) - if ratios[0] == ratios[1]: - ratio = ratios[0] - scaled_box = self.bbox * ratio - bbox = BoxList(scaled_box, size, mode=self.mode) - for k, v in self.extra_fields.items(): - if not isinstance(v, Tensor): - v = v.resize(size, *args, **kwargs) - bbox.add_field(k, v) - return bbox + def resize(self, size, *args, **kwargs): + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) + if ratios[0] == ratios[1]: + ratio = ratios[0] + scaled_box = self.bbox * ratio + bbox = BoxList(scaled_box, size, mode=self.mode) + for k, v in self.extra_fields.items(): + if not isinstance(v, Tensor): + v = v.resize(size, *args, **kwargs) + bbox.add_field(k, v) + return bbox - ratio_width, ratio_height = ratios - xmin, ymin, xmax, ymax = self._split_into_xyxy() - scaled_xmin = xmin * ratio_width - scaled_xmax = xmax * ratio_width - scaled_ymin = ymin * ratio_height - scaled_ymax = ymax * ratio_height - scaled_box = Tensor.cat( - *(scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1 - ) - bbox = BoxList(scaled_box, size, mode="xyxy") - for k, v in self.extra_fields.items(): - if not isinstance(v, Tensor): - v = v.resize(size, *args, **kwargs) - bbox.add_field(k, v) + ratio_width, ratio_height = ratios + xmin, ymin, xmax, ymax = self._split_into_xyxy() + scaled_xmin = xmin * ratio_width + scaled_xmax = xmax * ratio_width + scaled_ymin = ymin * ratio_height + scaled_ymax = ymax * ratio_height + scaled_box = Tensor.cat( + *(scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1 + ) + bbox = BoxList(scaled_box, size, mode="xyxy") + for k, v in self.extra_fields.items(): + if not isinstance(v, Tensor): + v = v.resize(size, *args, **kwargs) + bbox.add_field(k, v) - return bbox.convert(self.mode) + return bbox.convert(self.mode) - def transpose(self, method): - image_width, image_height = self.size - xmin, ymin, xmax, ymax = self._split_into_xyxy() - if method == FLIP_LEFT_RIGHT: - TO_REMOVE = 1 - transposed_xmin = image_width - xmax - TO_REMOVE - transposed_xmax = image_width - xmin - TO_REMOVE - transposed_ymin = ymin - transposed_ymax = ymax - elif method == FLIP_TOP_BOTTOM: - transposed_xmin = xmin - transposed_xmax = xmax - transposed_ymin = image_height - ymax - transposed_ymax = image_height - ymin + def transpose(self, method): + image_width, image_height = self.size + xmin, ymin, xmax, ymax = self._split_into_xyxy() + if method == FLIP_LEFT_RIGHT: + TO_REMOVE = 1 + transposed_xmin = image_width - xmax - TO_REMOVE + transposed_xmax = image_width - xmin - TO_REMOVE + transposed_ymin = ymin + transposed_ymax = ymax + elif method == FLIP_TOP_BOTTOM: + transposed_xmin = xmin + transposed_xmax = xmax + transposed_ymin = image_height - ymax + transposed_ymax = image_height - ymin - transposed_boxes = Tensor.cat( - *(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1 - ) - bbox = BoxList(transposed_boxes, self.size, mode="xyxy") - for k, v in self.extra_fields.items(): - if not isinstance(v, Tensor): - v = v.transpose(method) - bbox.add_field(k, v) - return bbox.convert(self.mode) + transposed_boxes = Tensor.cat( + *(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), + dim=-1, + ) + bbox = BoxList(transposed_boxes, self.size, mode="xyxy") + for k, v in self.extra_fields.items(): + if not isinstance(v, Tensor): + v = v.transpose(method) + bbox.add_field(k, v) + return bbox.convert(self.mode) - def clip_to_image(self, remove_empty=True): - TO_REMOVE = 1 - bb1 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 0] - bb2 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 1] - bb3 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 2] - bb4 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 3] - self.bbox = Tensor.stack((bb1, bb2, bb3, bb4), dim=1) - if remove_empty: - box = self.bbox - keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0]) - return self[keep] - return self - - def __getitem__(self, item): - if isinstance(item, list): - if len(item) == 0: - return [] - if sum(item) == len(item) and isinstance(item[0], bool): + def clip_to_image(self, remove_empty=True): + TO_REMOVE = 1 + bb1 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 0] + bb2 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 1] + bb3 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 2] + bb4 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 3] + self.bbox = Tensor.stack((bb1, bb2, bb3, bb4), dim=1) + if remove_empty: + box = self.bbox + keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0]) + return self[keep] return self - bbox = BoxList(tensor_gather(self.bbox, item), self.size, self.mode) - for k, v in self.extra_fields.items(): - bbox.add_field(k, tensor_gather(v, item)) - return bbox - def __len__(self): - return self.bbox.shape[0] + def __getitem__(self, item): + if isinstance(item, list): + if len(item) == 0: + return [] + if sum(item) == len(item) and isinstance(item[0], bool): + return self + bbox = BoxList(tensor_gather(self.bbox, item), self.size, self.mode) + for k, v in self.extra_fields.items(): + bbox.add_field(k, tensor_gather(v, item)) + return bbox + + def __len__(self): + return self.bbox.shape[0] def cat_boxlist(bboxes): - size = bboxes[0].size - mode = bboxes[0].mode - fields = set(bboxes[0].fields()) - cat_box_list = [bbox.bbox for bbox in bboxes if bbox.bbox.shape[0] > 0] - - if len(cat_box_list) > 0: - cat_boxes = BoxList(Tensor.cat(*cat_box_list, dim=0), size, mode) - else: - cat_boxes = BoxList(bboxes[0].bbox, size, mode) - for field in fields: - cat_field_list = [bbox.get_field(field) for bbox in bboxes if bbox.get_field(field).shape[0] > 0] + size = bboxes[0].size + mode = bboxes[0].mode + fields = set(bboxes[0].fields()) + cat_box_list = [bbox.bbox for bbox in bboxes if bbox.bbox.shape[0] > 0] if len(cat_box_list) > 0: - data = Tensor.cat(*cat_field_list, dim=0) + cat_boxes = BoxList(Tensor.cat(*cat_box_list, dim=0), size, mode) else: - data = bboxes[0].get_field(field) + cat_boxes = BoxList(bboxes[0].bbox, size, mode) + for field in fields: + cat_field_list = [ + bbox.get_field(field) + for bbox in bboxes + if bbox.get_field(field).shape[0] > 0 + ] - cat_boxes.add_field(field, data) + if len(cat_box_list) > 0: + data = Tensor.cat(*cat_field_list, dim=0) + else: + data = bboxes[0].get_field(field) - return cat_boxes + cat_boxes.add_field(field, data) + + return cat_boxes class FPN: - def __init__(self, in_channels_list, out_channels): - self.inner_blocks, self.layer_blocks = [], [] - for in_channels in in_channels_list: - self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1)) - self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)) - self.top_block = LastLevelMaxPool() + def __init__(self, in_channels_list, out_channels): + self.inner_blocks, self.layer_blocks = [], [] + for in_channels in in_channels_list: + self.inner_blocks.append( + nn.Conv2d(in_channels, out_channels, kernel_size=1) + ) + self.layer_blocks.append( + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + ) + self.top_block = LastLevelMaxPool() - def __call__(self, x: Tensor): - last_inner = self.inner_blocks[-1](x[-1]) - results = [] - results.append(self.layer_blocks[-1](last_inner)) - for feature, inner_block, layer_block in zip( + def __call__(self, x: Tensor): + last_inner = self.inner_blocks[-1](x[-1]) + results = [] + results.append(self.layer_blocks[-1](last_inner)) + for feature, inner_block, layer_block in zip( x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1] - ): - if not inner_block: - continue - inner_top_down = nearest_interpolate(last_inner, scale_factor=2) - inner_lateral = inner_block(feature) - last_inner = inner_lateral + inner_top_down - layer_result = layer_block(last_inner) - results.insert(0, layer_result) - last_results = self.top_block(results[-1]) - results.extend(last_results) + ): + if not inner_block: + continue + inner_top_down = nearest_interpolate(last_inner, scale_factor=2) + inner_lateral = inner_block(feature) + last_inner = inner_lateral + inner_top_down + layer_result = layer_block(last_inner) + results.insert(0, layer_result) + last_results = self.top_block(results[-1]) + results.extend(last_results) - return tuple(results) + return tuple(results) class ResNetFPN: - def __init__(self, resnet, out_channels=256): - self.out_channels = out_channels - self.body = resnet - in_channels_stage2 = 256 - in_channels_list = [ - in_channels_stage2, - in_channels_stage2 * 2, - in_channels_stage2 * 4, - in_channels_stage2 * 8, - ] - self.fpn = FPN(in_channels_list, out_channels) + def __init__(self, resnet, out_channels=256): + self.out_channels = out_channels + self.body = resnet + in_channels_stage2 = 256 + in_channels_list = [ + in_channels_stage2, + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + self.fpn = FPN(in_channels_list, out_channels) - def __call__(self, x): - x = self.body(x) - return self.fpn(x) + def __call__(self, x): + x = self.body(x) + return self.fpn(x) class AnchorGenerator: - def __init__( - self, - sizes=(32, 64, 128, 256, 512), - aspect_ratios=(0.5, 1.0, 2.0), - anchor_strides=(4, 8, 16, 32, 64), - straddle_thresh=0, - ): - if len(anchor_strides) == 1: - anchor_stride = anchor_strides[0] - cell_anchors = [ - generate_anchors(anchor_stride, sizes, aspect_ratios) - ] - else: - if len(anchor_strides) != len(sizes): - raise RuntimeError("FPN should have #anchor_strides == #sizes") - - cell_anchors = [ - generate_anchors( - anchor_stride, - size if isinstance(size, (tuple, list)) else (size,), - aspect_ratios - ) - for anchor_stride, size in zip(anchor_strides, sizes) - ] - self.strides = anchor_strides - self.cell_anchors = cell_anchors - self.straddle_thresh = straddle_thresh - - def num_anchors_per_location(self): - return [cell_anchors.shape[0] for cell_anchors in self.cell_anchors] - - def grid_anchors(self, grid_sizes): - anchors = [] - for size, stride, base_anchors in zip( - grid_sizes, self.strides, self.cell_anchors + def __init__( + self, + sizes=(32, 64, 128, 256, 512), + aspect_ratios=(0.5, 1.0, 2.0), + anchor_strides=(4, 8, 16, 32, 64), + straddle_thresh=0, ): - grid_height, grid_width = size - device = base_anchors.device - shifts_x = Tensor.arange( - start=0, stop=grid_width * stride, step=stride, dtype=dtypes.float32, device=device - ) - shifts_y = Tensor.arange( - start=0, stop=grid_height * stride, step=stride, dtype=dtypes.float32, device=device - ) - shift_y, shift_x = meshgrid(shifts_y, shifts_x) - shift_x = shift_x.reshape(-1) - shift_y = shift_y.reshape(-1) - shifts = Tensor.stack((shift_x, shift_y, shift_x, shift_y), dim=1) + if len(anchor_strides) == 1: + anchor_stride = anchor_strides[0] + cell_anchors = [generate_anchors(anchor_stride, sizes, aspect_ratios)] + else: + if len(anchor_strides) != len(sizes): + raise RuntimeError("FPN should have #anchor_strides == #sizes") - anchors.append( - (shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4) - ) + cell_anchors = [ + generate_anchors( + anchor_stride, + size if isinstance(size, (tuple, list)) else (size,), + aspect_ratios, + ) + for anchor_stride, size in zip(anchor_strides, sizes) + ] + self.strides = anchor_strides + self.cell_anchors = cell_anchors + self.straddle_thresh = straddle_thresh - return anchors + def num_anchors_per_location(self): + return [cell_anchors.shape[0] for cell_anchors in self.cell_anchors] - def add_visibility_to(self, boxlist): - image_width, image_height = boxlist.size - anchors = boxlist.bbox - if self.straddle_thresh >= 0: - inds_inside = ( - (anchors[:, 0] >= -self.straddle_thresh) - * (anchors[:, 1] >= -self.straddle_thresh) - * (anchors[:, 2] < image_width + self.straddle_thresh) - * (anchors[:, 3] < image_height + self.straddle_thresh) - ) - else: - device = anchors.device - inds_inside = Tensor.ones(anchors.shape[0], dtype=dtypes.uint8, device=device) - boxlist.add_field("visibility", inds_inside) + def grid_anchors(self, grid_sizes): + anchors = [] + for size, stride, base_anchors in zip( + grid_sizes, self.strides, self.cell_anchors + ): + grid_height, grid_width = size + device = base_anchors.device + shifts_x = Tensor.arange( + start=0, + stop=grid_width * stride, + step=stride, + dtype=dtypes.float32, + device=device, + ) + shifts_y = Tensor.arange( + start=0, + stop=grid_height * stride, + step=stride, + dtype=dtypes.float32, + device=device, + ) + shift_y, shift_x = meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + shifts = Tensor.stack((shift_x, shift_y, shift_x, shift_y), dim=1) - def __call__(self, image_list, feature_maps): - grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] - anchors_over_all_feature_maps = self.grid_anchors(grid_sizes) - anchors = [] - for (image_height, image_width) in image_list.image_sizes: - anchors_in_image = [] - for anchors_per_feature_map in anchors_over_all_feature_maps: - boxlist = BoxList( - anchors_per_feature_map, (image_width, image_height), mode="xyxy" - ) - self.add_visibility_to(boxlist) - anchors_in_image.append(boxlist) - anchors.append(anchors_in_image) - return anchors + anchors.append( + (shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape( + -1, 4 + ) + ) + + return anchors + + def add_visibility_to(self, boxlist): + image_width, image_height = boxlist.size + anchors = boxlist.bbox + if self.straddle_thresh >= 0: + inds_inside = ( + (anchors[:, 0] >= -self.straddle_thresh) + * (anchors[:, 1] >= -self.straddle_thresh) + * (anchors[:, 2] < image_width + self.straddle_thresh) + * (anchors[:, 3] < image_height + self.straddle_thresh) + ) + else: + device = anchors.device + inds_inside = Tensor.ones( + anchors.shape[0], dtype=dtypes.uint8, device=device + ) + boxlist.add_field("visibility", inds_inside) + + def __call__(self, image_list, feature_maps): + grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] + anchors_over_all_feature_maps = self.grid_anchors(grid_sizes) + anchors = [] + for image_height, image_width in image_list.image_sizes: + anchors_in_image = [] + for anchors_per_feature_map in anchors_over_all_feature_maps: + boxlist = BoxList( + anchors_per_feature_map, (image_width, image_height), mode="xyxy" + ) + self.add_visibility_to(boxlist) + anchors_in_image.append(boxlist) + anchors.append(anchors_in_image) + return anchors def generate_anchors( stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2) ): - return _generate_anchors(stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios))) + return _generate_anchors( + stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios)) + ) def _generate_anchors(base_size, scales, aspect_ratios): - anchor = Tensor([1, 1, base_size, base_size]) - 1 - anchors = _ratio_enum(anchor, aspect_ratios) - anchors = Tensor.cat( - *[_scale_enum(anchors[i, :], scales).reshape(-1, 4) for i in range(anchors.shape[0])] - ) - return anchors + anchor = Tensor([1, 1, base_size, base_size]) - 1 + anchors = _ratio_enum(anchor, aspect_ratios) + anchors = Tensor.cat( + *[ + _scale_enum(anchors[i, :], scales).reshape(-1, 4) + for i in range(anchors.shape[0]) + ] + ) + return anchors def _whctrs(anchor): - w = anchor[2] - anchor[0] + 1 - h = anchor[3] - anchor[1] + 1 - x_ctr = anchor[0] + 0.5 * (w - 1) - y_ctr = anchor[1] + 0.5 * (h - 1) - return w, h, x_ctr, y_ctr + w = anchor[2] - anchor[0] + 1 + h = anchor[3] - anchor[1] + 1 + x_ctr = anchor[0] + 0.5 * (w - 1) + y_ctr = anchor[1] + 0.5 * (h - 1) + return w, h, x_ctr, y_ctr def _mkanchors(ws, hs, x_ctr, y_ctr): - ws = ws[:, None] - hs = hs[:, None] - anchors = Tensor.cat(*( - x_ctr - 0.5 * (ws - 1), - y_ctr - 0.5 * (hs - 1), - x_ctr + 0.5 * (ws - 1), - y_ctr + 0.5 * (hs - 1), - ), dim=1) - return anchors + ws = ws[:, None] + hs = hs[:, None] + anchors = Tensor.cat( + *( + x_ctr - 0.5 * (ws - 1), + y_ctr - 0.5 * (hs - 1), + x_ctr + 0.5 * (ws - 1), + y_ctr + 0.5 * (hs - 1), + ), + dim=1, + ) + return anchors def _ratio_enum(anchor, ratios): - w, h, x_ctr, y_ctr = _whctrs(anchor) - size = w * h - size_ratios = size / ratios - ws = rint(Tensor.sqrt(size_ratios)) - hs = rint(ws * ratios) - anchors = _mkanchors(ws, hs, x_ctr, y_ctr) - return anchors + w, h, x_ctr, y_ctr = _whctrs(anchor) + size = w * h + size_ratios = size / ratios + ws = rint(Tensor.sqrt(size_ratios)) + hs = rint(ws * ratios) + anchors = _mkanchors(ws, hs, x_ctr, y_ctr) + return anchors def _scale_enum(anchor, scales): - w, h, x_ctr, y_ctr = _whctrs(anchor) - ws = w * scales - hs = h * scales - anchors = _mkanchors(ws, hs, x_ctr, y_ctr) - return anchors + w, h, x_ctr, y_ctr = _whctrs(anchor) + ws = w * scales + hs = h * scales + anchors = _mkanchors(ws, hs, x_ctr, y_ctr) + return anchors class RPNHead: - def __init__(self, in_channels, num_anchors): - self.conv = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1) - self.cls_logits = nn.Conv2d(256, num_anchors, kernel_size=1) - self.bbox_pred = nn.Conv2d(256, num_anchors * 4, kernel_size=1) + def __init__(self, in_channels, num_anchors): + self.conv = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1) + self.cls_logits = nn.Conv2d(256, num_anchors, kernel_size=1) + self.bbox_pred = nn.Conv2d(256, num_anchors * 4, kernel_size=1) - def __call__(self, x): - logits = [] - bbox_reg = [] - for feature in x: - t = Tensor.relu(self.conv(feature)) - logits.append(self.cls_logits(t)) - bbox_reg.append(self.bbox_pred(t)) - return logits, bbox_reg + def __call__(self, x): + logits = [] + bbox_reg = [] + for feature in x: + t = Tensor.relu(self.conv(feature)) + logits.append(self.cls_logits(t)) + bbox_reg.append(self.bbox_pred(t)) + return logits, bbox_reg class BoxCoder(object): - def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)): - self.weights = weights - self.bbox_xform_clip = bbox_xform_clip + def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)): + self.weights = weights + self.bbox_xform_clip = bbox_xform_clip - def encode(self, reference_boxes, proposals): - TO_REMOVE = 1 # TODO remove - ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE - ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE - ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths - ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights + def encode(self, reference_boxes, proposals): + TO_REMOVE = 1 # TODO remove + ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE + ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE + ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths + ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights - gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE - gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE - gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths - gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights + gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE + gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE + gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths + gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights - wx, wy, ww, wh = self.weights - targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths - targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights - targets_dw = ww * Tensor.log(gt_widths / ex_widths) - targets_dh = wh * Tensor.log(gt_heights / ex_heights) + wx, wy, ww, wh = self.weights + targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths + targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights + targets_dw = ww * Tensor.log(gt_widths / ex_widths) + targets_dh = wh * Tensor.log(gt_heights / ex_heights) - targets = Tensor.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) - return targets + targets = Tensor.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) + return targets - def decode(self, rel_codes, boxes): - boxes = boxes.cast(rel_codes.dtype) - rel_codes = rel_codes + def decode(self, rel_codes, boxes): + boxes = boxes.cast(rel_codes.dtype) + rel_codes = rel_codes - TO_REMOVE = 1 # TODO remove - widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE - heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE - ctr_x = boxes[:, 0] + 0.5 * widths - ctr_y = boxes[:, 1] + 0.5 * heights + TO_REMOVE = 1 # TODO remove + widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE + heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE + ctr_x = boxes[:, 0] + 0.5 * widths + ctr_y = boxes[:, 1] + 0.5 * heights - wx, wy, ww, wh = self.weights - dx = rel_codes[:, 0::4] / wx - dy = rel_codes[:, 1::4] / wy - dw = rel_codes[:, 2::4] / ww - dh = rel_codes[:, 3::4] / wh + wx, wy, ww, wh = self.weights + dx = rel_codes[:, 0::4] / wx + dy = rel_codes[:, 1::4] / wy + dw = rel_codes[:, 2::4] / ww + dh = rel_codes[:, 3::4] / wh - # Prevent sending too large values into Tensor.exp() - dw = dw.clip(min_=dw.min(), max_=self.bbox_xform_clip) - dh = dh.clip(min_=dh.min(), max_=self.bbox_xform_clip) + # Prevent sending too large values into Tensor.exp() + dw = dw.clip(min_=dw.min(), max_=self.bbox_xform_clip) + dh = dh.clip(min_=dh.min(), max_=self.bbox_xform_clip) - pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] - pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] - pred_w = dw.exp() * widths[:, None] - pred_h = dh.exp() * heights[:, None] - x = pred_ctr_x - 0.5 * pred_w - y = pred_ctr_y - 0.5 * pred_h - w = pred_ctr_x + 0.5 * pred_w - 1 - h = pred_ctr_y + 0.5 * pred_h - 1 - pred_boxes = Tensor.stack([x, y, w, h]).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1]) - return pred_boxes + pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] + pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] + pred_w = dw.exp() * widths[:, None] + pred_h = dh.exp() * heights[:, None] + x = pred_ctr_x - 0.5 * pred_w + y = pred_ctr_y - 0.5 * pred_h + w = pred_ctr_x + 0.5 * pred_w - 1 + h = pred_ctr_y + 0.5 * pred_h - 1 + pred_boxes = ( + Tensor.stack([x, y, w, h]) + .permute(1, 2, 0) + .reshape(rel_codes.shape[0], rel_codes.shape[1]) + ) + return pred_boxes def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"): - if nms_thresh <= 0: - return boxlist - mode = boxlist.mode - boxlist = boxlist.convert("xyxy") - boxes = boxlist.bbox - score = boxlist.get_field(score_field) - keep = _box_nms(boxes.numpy(), score.numpy(), nms_thresh) - if max_proposals > 0: - keep = keep[:max_proposals] - boxlist = boxlist[keep] - return boxlist.convert(mode) + if nms_thresh <= 0: + return boxlist + mode = boxlist.mode + boxlist = boxlist.convert("xyxy") + boxes = boxlist.bbox + score = boxlist.get_field(score_field) + keep = _box_nms(boxes.numpy(), score.numpy(), nms_thresh) + if max_proposals > 0: + keep = keep[:max_proposals] + boxlist = boxlist[keep] + return boxlist.convert(mode) def remove_small_boxes(boxlist, min_size): - xywh_boxes = boxlist.convert("xywh").bbox - _, _, ws, hs = xywh_boxes.chunk(4, dim=1) - keep = (( - (ws >= min_size) * (hs >= min_size) - ) > 0).reshape(-1) - if keep.sum().numpy() == len(boxlist): - return boxlist - else: - keep = keep.numpy().nonzero()[0] - return boxlist[keep] + xywh_boxes = boxlist.convert("xywh").bbox + _, _, ws, hs = xywh_boxes.chunk(4, dim=1) + keep = (((ws >= min_size) * (hs >= min_size)) > 0).reshape(-1) + if keep.sum().numpy() == len(boxlist): + return boxlist + else: + keep = keep.numpy().nonzero()[0] + return boxlist[keep] class RPNPostProcessor: - # Not used in Loss calculation - def __init__( - self, - pre_nms_top_n, - post_nms_top_n, - nms_thresh, - min_size, - box_coder=None, - fpn_post_nms_top_n=None, - ): - self.pre_nms_top_n = pre_nms_top_n - self.post_nms_top_n = post_nms_top_n - self.nms_thresh = nms_thresh - self.min_size = min_size + # Not used in Loss calculation + def __init__( + self, + pre_nms_top_n, + post_nms_top_n, + nms_thresh, + min_size, + box_coder=None, + fpn_post_nms_top_n=None, + ): + self.pre_nms_top_n = pre_nms_top_n + self.post_nms_top_n = post_nms_top_n + self.nms_thresh = nms_thresh + self.min_size = min_size - if box_coder is None: - box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) - self.box_coder = box_coder + if box_coder is None: + box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + self.box_coder = box_coder - if fpn_post_nms_top_n is None: - fpn_post_nms_top_n = post_nms_top_n - self.fpn_post_nms_top_n = fpn_post_nms_top_n + if fpn_post_nms_top_n is None: + fpn_post_nms_top_n = post_nms_top_n + self.fpn_post_nms_top_n = fpn_post_nms_top_n - def forward_for_single_feature_map(self, anchors, objectness, box_regression): - device = objectness.device - N, A, H, W = objectness.shape - objectness = permute_and_flatten(objectness, N, A, 1, H, W).reshape(N, -1) - objectness = objectness.sigmoid() + def forward_for_single_feature_map(self, anchors, objectness, box_regression): + device = objectness.device + N, A, H, W = objectness.shape + objectness = permute_and_flatten(objectness, N, A, 1, H, W).reshape(N, -1) + objectness = objectness.sigmoid() - box_regression = permute_and_flatten(box_regression, N, A, 4, H, W) + box_regression = permute_and_flatten(box_regression, N, A, 4, H, W) - num_anchors = A * H * W + num_anchors = A * H * W - pre_nms_top_n = min(self.pre_nms_top_n, num_anchors) - objectness, topk_idx = topk(objectness, pre_nms_top_n, dim=1, sorted=False) - concat_anchors = Tensor.cat(*[a.bbox for a in anchors], dim=0).reshape(N, -1, 4) - image_shapes = [box.size for box in anchors] + pre_nms_top_n = min(self.pre_nms_top_n, num_anchors) + objectness, topk_idx = topk(objectness, pre_nms_top_n, dim=1, sorted=False) + concat_anchors = Tensor.cat(*[a.bbox for a in anchors], dim=0).reshape(N, -1, 4) + image_shapes = [box.size for box in anchors] - box_regression_list = [] - concat_anchors_list = [] - for batch_idx in range(N): - box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx])) - concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx])) + box_regression_list = [] + concat_anchors_list = [] + for batch_idx in range(N): + box_regression_list.append( + tensor_gather(box_regression[batch_idx], topk_idx[batch_idx]) + ) + concat_anchors_list.append( + tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx]) + ) - box_regression = Tensor.stack(box_regression_list) - concat_anchors = Tensor.stack(concat_anchors_list) + box_regression = Tensor.stack(box_regression_list) + concat_anchors = Tensor.stack(concat_anchors_list) - proposals = self.box_coder.decode( - box_regression.reshape(-1, 4), concat_anchors.reshape(-1, 4) - ) + proposals = self.box_coder.decode( + box_regression.reshape(-1, 4), concat_anchors.reshape(-1, 4) + ) - proposals = proposals.reshape(N, -1, 4) + proposals = proposals.reshape(N, -1, 4) - result = [] - for proposal, score, im_shape in zip(proposals, objectness, image_shapes): - boxlist = BoxList(proposal, im_shape, mode="xyxy") - boxlist.add_field("objectness", score) - boxlist = boxlist.clip_to_image(remove_empty=False) - boxlist = remove_small_boxes(boxlist, self.min_size) - boxlist = boxlist_nms( - boxlist, - self.nms_thresh, - max_proposals=self.post_nms_top_n, - score_field="objectness", - ) - result.append(boxlist) - return result + result = [] + for proposal, score, im_shape in zip(proposals, objectness, image_shapes): + boxlist = BoxList(proposal, im_shape, mode="xyxy") + boxlist.add_field("objectness", score) + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = remove_small_boxes(boxlist, self.min_size) + boxlist = boxlist_nms( + boxlist, + self.nms_thresh, + max_proposals=self.post_nms_top_n, + score_field="objectness", + ) + result.append(boxlist) + return result - def __call__(self, anchors, objectness, box_regression): - sampled_boxes = [] - num_levels = len(objectness) - anchors = list(zip(*anchors)) - for a, o, b in zip(anchors, objectness, box_regression): - sampled_boxes.append(self.forward_for_single_feature_map(a, o, b)) + def __call__(self, anchors, objectness, box_regression): + sampled_boxes = [] + num_levels = len(objectness) + anchors = list(zip(*anchors)) + for a, o, b in zip(anchors, objectness, box_regression): + sampled_boxes.append(self.forward_for_single_feature_map(a, o, b)) - boxlists = list(zip(*sampled_boxes)) - boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] + boxlists = list(zip(*sampled_boxes)) + boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] - if num_levels > 1: - boxlists = self.select_over_all_levels(boxlists) + if num_levels > 1: + boxlists = self.select_over_all_levels(boxlists) - return boxlists + return boxlists - def select_over_all_levels(self, boxlists): - num_images = len(boxlists) - for i in range(num_images): - objectness = boxlists[i].get_field("objectness") - post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0]) - _, inds_sorted = topk(objectness, - post_nms_top_n, dim=0, sorted=False - ) - boxlists[i] = boxlists[i][inds_sorted] - return boxlists + def select_over_all_levels(self, boxlists): + num_images = len(boxlists) + for i in range(num_images): + objectness = boxlists[i].get_field("objectness") + post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0]) + _, inds_sorted = topk(objectness, post_nms_top_n, dim=0, sorted=False) + boxlists[i] = boxlists[i][inds_sorted] + return boxlists class RPN: - def __init__(self, in_channels): - self.anchor_generator = AnchorGenerator() + def __init__(self, in_channels): + self.anchor_generator = AnchorGenerator() - in_channels = 256 - head = RPNHead( - in_channels, self.anchor_generator.num_anchors_per_location()[0] - ) - rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) - box_selector_test = RPNPostProcessor( - pre_nms_top_n=1000, - post_nms_top_n=1000, - nms_thresh=0.7, - min_size=0, - box_coder=rpn_box_coder, - fpn_post_nms_top_n=1000 - ) - self.head = head - self.box_selector_test = box_selector_test + in_channels = 256 + head = RPNHead(in_channels, self.anchor_generator.num_anchors_per_location()[0]) + rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + box_selector_test = RPNPostProcessor( + pre_nms_top_n=1000, + post_nms_top_n=1000, + nms_thresh=0.7, + min_size=0, + box_coder=rpn_box_coder, + fpn_post_nms_top_n=1000, + ) + self.head = head + self.box_selector_test = box_selector_test - def __call__(self, images, features, targets=None): - objectness, rpn_box_regression = self.head(features) - anchors = self.anchor_generator(images, features) - boxes = self.box_selector_test(anchors, objectness, rpn_box_regression) - return boxes, {} + def __call__(self, images, features, targets=None): + objectness, rpn_box_regression = self.head(features) + anchors = self.anchor_generator(images, features) + boxes = self.box_selector_test(anchors, objectness, rpn_box_regression) + return boxes, {} def make_conv3x3( - in_channels, - out_channels, - dilation=1, - stride=1, - use_gn=False, -): - conv = nn.Conv2d( in_channels, out_channels, - kernel_size=3, - stride=stride, - padding=dilation, - dilation=dilation, - bias=False if use_gn else True - ) - return conv + dilation=1, + stride=1, + use_gn=False, +): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False if use_gn else True, + ) + return conv class MaskRCNNFPNFeatureExtractor: - def __init__(self): - resolution = 14 - scales = (0.25, 0.125, 0.0625, 0.03125) - sampling_ratio = 2 - pooler = Pooler( - output_size=(resolution, resolution), - scales=scales, - sampling_ratio=sampling_ratio, - ) - input_size = 256 - self.pooler = pooler + def __init__(self): + resolution = 14 + scales = (0.25, 0.125, 0.0625, 0.03125) + sampling_ratio = 2 + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + input_size = 256 + self.pooler = pooler - use_gn = False - layers = (256, 256, 256, 256) - dilation = 1 - self.mask_fcn1 = make_conv3x3(input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn) - self.mask_fcn2 = make_conv3x3(layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn) - self.mask_fcn3 = make_conv3x3(layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn) - self.mask_fcn4 = make_conv3x3(layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn) - self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4] + use_gn = False + layers = (256, 256, 256, 256) + dilation = 1 + self.mask_fcn1 = make_conv3x3( + input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn + ) + self.mask_fcn2 = make_conv3x3( + layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn + ) + self.mask_fcn3 = make_conv3x3( + layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn + ) + self.mask_fcn4 = make_conv3x3( + layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn + ) + self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4] - def __call__(self, x, proposals): - x = self.pooler(x, proposals) - for layer in self.blocks: - if x is not None: - x = Tensor.relu(layer(x)) - return x + def __call__(self, x, proposals): + x = self.pooler(x, proposals) + for layer in self.blocks: + if x is not None: + x = Tensor.relu(layer(x)) + return x class MaskRCNNC4Predictor: - def __init__(self): - num_classes = 81 - dim_reduced = 256 - num_inputs = dim_reduced - self.conv5_mask = nn.ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) - self.mask_fcn_logits = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0) + def __init__(self): + num_classes = 81 + dim_reduced = 256 + num_inputs = dim_reduced + self.conv5_mask = nn.ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) + self.mask_fcn_logits = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0) - def __call__(self, x): - x = Tensor.relu(self.conv5_mask(x)) - return self.mask_fcn_logits(x) + def __call__(self, x): + x = Tensor.relu(self.conv5_mask(x)) + return self.mask_fcn_logits(x) class FPN2MLPFeatureExtractor: - def __init__(self, cfg): - resolution = 7 - scales = (0.25, 0.125, 0.0625, 0.03125) - sampling_ratio = 2 - pooler = Pooler( - output_size=(resolution, resolution), - scales=scales, - sampling_ratio=sampling_ratio, - ) - input_size = 256 * resolution ** 2 - representation_size = 1024 - self.pooler = pooler - self.fc6 = nn.Linear(input_size, representation_size) - self.fc7 = nn.Linear(representation_size, representation_size) + def __init__(self, cfg): + resolution = 7 + scales = (0.25, 0.125, 0.0625, 0.03125) + sampling_ratio = 2 + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + input_size = 256 * resolution**2 + representation_size = 1024 + self.pooler = pooler + self.fc6 = nn.Linear(input_size, representation_size) + self.fc7 = nn.Linear(representation_size, representation_size) - def __call__(self, x, proposals): - x = self.pooler(x, proposals) - x = x.reshape(x.shape[0], -1) - x = Tensor.relu(self.fc6(x)) - x = Tensor.relu(self.fc7(x)) - return x + def __call__(self, x, proposals): + x = self.pooler(x, proposals) + x = x.reshape(x.shape[0], -1) + x = Tensor.relu(self.fc6(x)) + x = Tensor.relu(self.fc7(x)) + return x def _bilinear_interpolate( - input, # [N, C, H, W] - roi_batch_ind, # [K] - y, # [K, PH, IY] - x, # [K, PW, IX] - ymask, # [K, IY] - xmask, # [K, IX] -): - _, channels, height, width = input.shape - y = y.clip(min_=0.0, max_=float(height-1)) - x = x.clip(min_=0.0, max_=float(width-1)) - - # Tensor.where doesnt work well with int32 data so cast to float32 - y_low = y.cast(dtypes.int32).contiguous().float().contiguous() - x_low = x.cast(dtypes.int32).contiguous().float().contiguous() - - y_high = Tensor.where(y_low >= height - 1, float(height - 1), y_low + 1) - y_low = Tensor.where(y_low >= height - 1, float(height - 1), y_low) - - x_high = Tensor.where(x_low >= width - 1, float(width - 1), x_low + 1) - x_low = Tensor.where(x_low >= width - 1, float(width - 1), x_low) - - ly = y - y_low - lx = x - x_low - hy = 1.0 - ly - hx = 1.0 - lx - - def masked_index( + input, # [N, C, H, W] + roi_batch_ind, # [K] y, # [K, PH, IY] x, # [K, PW, IX] - ): - if ymask is not None: - assert xmask is not None - y = Tensor.where(ymask[:, None, :], y, 0) - x = Tensor.where(xmask[:, None, :], x, 0) - key1 = roi_batch_ind[:, None, None, None, None, None] - key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None] - key3 = y[:, None, :, None, :, None] - key4 = x[:, None, None, :, None, :] - return tensor_getitem(input,key1,key2,key3,key4) # [K, C, PH, PW, IY, IX] + ymask, # [K, IY] + xmask, # [K, IX] +): + _, channels, height, width = input.shape + y = y.clip(min_=0.0, max_=float(height - 1)) + x = x.clip(min_=0.0, max_=float(width - 1)) - v1 = masked_index(y_low, x_low) - v2 = masked_index(y_low, x_high) - v3 = masked_index(y_high, x_low) - v4 = masked_index(y_high, x_high) + # Tensor.where doesnt work well with int32 data so cast to float32 + y_low = y.cast(dtypes.int32).contiguous().float().contiguous() + x_low = x.cast(dtypes.int32).contiguous().float().contiguous() - # all ws preemptively [K, C, PH, PW, IY, IX] - def outer_prod(y, x): - return y[:, None, :, None, :, None] * x[:, None, None, :, None, :] + y_high = Tensor.where(y_low >= height - 1, float(height - 1), y_low + 1) + y_low = Tensor.where(y_low >= height - 1, float(height - 1), y_low) - w1 = outer_prod(hy, hx) - w2 = outer_prod(hy, lx) - w3 = outer_prod(ly, hx) - w4 = outer_prod(ly, lx) + x_high = Tensor.where(x_low >= width - 1, float(width - 1), x_low + 1) + x_low = Tensor.where(x_low >= width - 1, float(width - 1), x_low) - val = w1*v1 + w2*v2 + w3*v3 + w4*v4 - return val + ly = y - y_low + lx = x - x_low + hy = 1.0 - ly + hx = 1.0 - lx -#https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align -def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): - orig_dtype = input.dtype - _, _, height, width = input.shape - ph = Tensor.arange(pooled_height, device=input.device) - pw = Tensor.arange(pooled_width, device=input.device) + def masked_index( + y, # [K, PH, IY] + x, # [K, PW, IX] + ): + if ymask is not None: + assert xmask is not None + y = Tensor.where(ymask[:, None, :], y, 0) + x = Tensor.where(xmask[:, None, :], x, 0) + key1 = roi_batch_ind[:, None, None, None, None, None] + key2 = Tensor.arange(channels, device=input.device)[ + None, :, None, None, None, None + ] + key3 = y[:, None, :, None, :, None] + key4 = x[:, None, None, :, None, :] + return tensor_getitem(input, key1, key2, key3, key4) # [K, C, PH, PW, IY, IX] - roi_batch_ind = rois[:, 0].cast(dtypes.int32).contiguous() - offset = 0.5 if aligned else 0.0 - roi_start_w = rois[:, 1] * spatial_scale - offset - roi_start_h = rois[:, 2] * spatial_scale - offset - roi_end_w = rois[:, 3] * spatial_scale - offset - roi_end_h = rois[:, 4] * spatial_scale - offset + v1 = masked_index(y_low, x_low) + v2 = masked_index(y_low, x_high) + v3 = masked_index(y_high, x_low) + v4 = masked_index(y_high, x_high) - roi_width = roi_end_w - roi_start_w - roi_height = roi_end_h - roi_start_h - if not aligned: - roi_width = roi_width.maximum(1.0) - roi_height = roi_height.maximum(1.0) + # all ws preemptively [K, C, PH, PW, IY, IX] + def outer_prod(y, x): + return y[:, None, :, None, :, None] * x[:, None, None, :, None, :] - bin_size_h = roi_height / pooled_height - bin_size_w = roi_width / pooled_width + w1 = outer_prod(hy, hx) + w2 = outer_prod(hy, lx) + w3 = outer_prod(ly, hx) + w4 = outer_prod(ly, lx) - exact_sampling = sampling_ratio > 0 - roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil() - roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil() + val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 + return val - if exact_sampling: - count = max(roi_bin_grid_h * roi_bin_grid_w, 1) - iy = Tensor.arange(roi_bin_grid_h, device=input.device) - ix = Tensor.arange(roi_bin_grid_w, device=input.device) - ymask = None - xmask = None - else: - count = (roi_bin_grid_h * roi_bin_grid_w).maximum(1) - iy = Tensor.arange(height, device=input.device) - ix = Tensor.arange(width, device=input.device) - ymask = iy[None, :] < roi_bin_grid_h[:, None] - xmask = ix[None, :] < roi_bin_grid_w[:, None] - def from_K(t): - return t[:, None, None] +# https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align +def _roi_align( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned +): + orig_dtype = input.dtype + _, _, height, width = input.shape + ph = Tensor.arange(pooled_height, device=input.device) + pw = Tensor.arange(pooled_width, device=input.device) - y = ( - from_K(roi_start_h) - + ph[None, :, None] * from_K(bin_size_h) - + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h) - ) - x = ( - from_K(roi_start_w) - + pw[None, :, None] * from_K(bin_size_w) - + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w) - ) + roi_batch_ind = rois[:, 0].cast(dtypes.int32).contiguous() + offset = 0.5 if aligned else 0.0 + roi_start_w = rois[:, 1] * spatial_scale - offset + roi_start_h = rois[:, 2] * spatial_scale - offset + roi_end_w = rois[:, 3] * spatial_scale - offset + roi_end_h = rois[:, 4] * spatial_scale - offset - val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) - if not exact_sampling: - val = ymask[:, None, None, None, :, None].where(val, 0) - val = xmask[:, None, None, None, None, :].where(val, 0) + roi_width = roi_end_w - roi_start_w + roi_height = roi_end_h - roi_start_h + if not aligned: + roi_width = roi_width.maximum(1.0) + roi_height = roi_height.maximum(1.0) - output = val.sum((-1, -2)) - if isinstance(count, Tensor): - output /= count[:, None, None, None] - else: - output /= count + bin_size_h = roi_height / pooled_height + bin_size_w = roi_width / pooled_width - output = output.cast(orig_dtype) - return output - -class ROIAlign: - def __init__(self, output_size, spatial_scale, sampling_ratio): - self.output_size = output_size - self.spatial_scale = spatial_scale - self.sampling_ratio = sampling_ratio - - def __call__(self, input, rois): - output = _roi_align( - input, rois, self.spatial_scale, self.output_size[0], self.output_size[1], self.sampling_ratio, aligned=False + exact_sampling = sampling_ratio > 0 + roi_bin_grid_h = ( + sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil() ) + roi_bin_grid_w = ( + sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil() + ) + + if exact_sampling: + count = max(roi_bin_grid_h * roi_bin_grid_w, 1) + iy = Tensor.arange(roi_bin_grid_h, device=input.device) + ix = Tensor.arange(roi_bin_grid_w, device=input.device) + ymask = None + xmask = None + else: + count = (roi_bin_grid_h * roi_bin_grid_w).maximum(1) + iy = Tensor.arange(height, device=input.device) + ix = Tensor.arange(width, device=input.device) + ymask = iy[None, :] < roi_bin_grid_h[:, None] + xmask = ix[None, :] < roi_bin_grid_w[:, None] + + def from_K(t): + return t[:, None, None] + + y = ( + from_K(roi_start_h) + + ph[None, :, None] * from_K(bin_size_h) + + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h) + ) + x = ( + from_K(roi_start_w) + + pw[None, :, None] * from_K(bin_size_w) + + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w) + ) + + val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) + if not exact_sampling: + val = ymask[:, None, None, None, :, None].where(val, 0) + val = xmask[:, None, None, None, None, :].where(val, 0) + + output = val.sum((-1, -2)) + if isinstance(count, Tensor): + output /= count[:, None, None, None] + else: + output /= count + + output = output.cast(orig_dtype) return output -class LevelMapper: - def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6): - self.k_min = k_min - self.k_max = k_max - self.s0 = canonical_scale - self.lvl0 = canonical_level - self.eps = eps +class ROIAlign: + def __init__(self, output_size, spatial_scale, sampling_ratio): + self.output_size = output_size + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio - def __call__(self, boxlists): - s = Tensor.sqrt(Tensor.cat(*[boxlist.area() for boxlist in boxlists])) - target_lvls = (self.lvl0 + Tensor.log2(s / self.s0 + self.eps)).floor() - target_lvls = target_lvls.clip(min_=self.k_min, max_=self.k_max) - return target_lvls - self.k_min + def __call__(self, input, rois): + output = _roi_align( + input, + rois, + self.spatial_scale, + self.output_size[0], + self.output_size[1], + self.sampling_ratio, + aligned=False, + ) + return output + + +class LevelMapper: + def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6): + self.k_min = k_min + self.k_max = k_max + self.s0 = canonical_scale + self.lvl0 = canonical_level + self.eps = eps + + def __call__(self, boxlists): + s = Tensor.sqrt(Tensor.cat(*[boxlist.area() for boxlist in boxlists])) + target_lvls = (self.lvl0 + Tensor.log2(s / self.s0 + self.eps)).floor() + target_lvls = target_lvls.clip(min_=self.k_min, max_=self.k_max) + return target_lvls - self.k_min class Pooler: - def __init__(self, output_size, scales, sampling_ratio): - self.output_size = output_size - self.scales = scales - self.sampling_ratio = sampling_ratio - poolers = [] - for scale in scales: - poolers.append( - ROIAlign( - output_size, spatial_scale=scale, sampling_ratio=sampling_ratio + def __init__(self, output_size, scales, sampling_ratio): + self.output_size = output_size + self.scales = scales + self.sampling_ratio = sampling_ratio + poolers = [] + for scale in scales: + poolers.append( + ROIAlign( + output_size, spatial_scale=scale, sampling_ratio=sampling_ratio + ) + ) + self.poolers = poolers + self.output_size = output_size + lvl_min = -math.log2(scales[0]) + lvl_max = -math.log2(scales[-1]) + self.map_levels = LevelMapper(lvl_min, lvl_max) + + def convert_to_roi_format(self, boxes): + concat_boxes = Tensor.cat(*[b.bbox for b in boxes], dim=0) + device, dtype = concat_boxes.device, concat_boxes.dtype + ids = Tensor.cat( + *[ + Tensor.full((len(b), 1), i, dtype=dtype, device=device) + for i, b in enumerate(boxes) + ], + dim=0, ) - ) - self.poolers = poolers - self.output_size = output_size - lvl_min = -math.log2(scales[0]) - lvl_max = -math.log2(scales[-1]) - self.map_levels = LevelMapper(lvl_min, lvl_max) + if concat_boxes.shape[0] != 0: + rois = Tensor.cat(*[ids, concat_boxes], dim=1) + return rois - def convert_to_roi_format(self, boxes): - concat_boxes = Tensor.cat(*[b.bbox for b in boxes], dim=0) - device, dtype = concat_boxes.device, concat_boxes.dtype - ids = Tensor.cat( - *[ - Tensor.full((len(b), 1), i, dtype=dtype, device=device) - for i, b in enumerate(boxes) - ], - dim=0, - ) - if concat_boxes.shape[0] != 0: - rois = Tensor.cat(*[ids, concat_boxes], dim=1) - return rois + def __call__(self, x, boxes): + num_levels = len(self.poolers) + rois = self.convert_to_roi_format(boxes) + if rois: + if num_levels == 1: + return self.poolers[0](x[0], rois) - def __call__(self, x, boxes): - num_levels = len(self.poolers) - rois = self.convert_to_roi_format(boxes) - if rois: - if num_levels == 1: - return self.poolers[0](x[0], rois) + levels = self.map_levels(boxes) + results = [] + all_idxs = [] + for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)): + # this is fine because no grad will flow through index + idx_in_level = (levels.numpy() == level).nonzero()[0] + if len(idx_in_level) > 0: + rois_per_level = tensor_gather(rois, idx_in_level) + pooler_output = pooler(per_level_feature, rois_per_level) + all_idxs.extend(idx_in_level) + results.append(pooler_output) - levels = self.map_levels(boxes) - results = [] - all_idxs = [] - for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)): - # this is fine because no grad will flow through index - idx_in_level = (levels.numpy() == level).nonzero()[0] - if len(idx_in_level) > 0: - rois_per_level = tensor_gather(rois, idx_in_level) - pooler_output = pooler(per_level_feature, rois_per_level) - all_idxs.extend(idx_in_level) - results.append(pooler_output) - - return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i:idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])]) + return tensor_gather( + Tensor.cat(*results), + [ + x[0] + for x in sorted( + {i: idx for i, idx in enumerate(all_idxs)}.items(), + key=lambda x: x[1], + ) + ], + ) class FPNPredictor: - def __init__(self): - num_classes = 81 - representation_size = 1024 - self.cls_score = nn.Linear(representation_size, num_classes) - num_bbox_reg_classes = num_classes - self.bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4) + def __init__(self): + num_classes = 81 + representation_size = 1024 + self.cls_score = nn.Linear(representation_size, num_classes) + num_bbox_reg_classes = num_classes + self.bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4) - def __call__(self, x): - scores = self.cls_score(x) - bbox_deltas = self.bbox_pred(x) - return scores, bbox_deltas + def __call__(self, x): + scores = self.cls_score(x) + bbox_deltas = self.bbox_pred(x) + return scores, bbox_deltas class PostProcessor: - # Not used in training - def __init__( - self, - score_thresh=0.05, - nms=0.5, - detections_per_img=100, - box_coder=None, - cls_agnostic_bbox_reg=False - ): - self.score_thresh = score_thresh - self.nms = nms - self.detections_per_img = detections_per_img - if box_coder is None: - box_coder = BoxCoder(weights=(10., 10., 5., 5.)) - self.box_coder = box_coder - self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg - - def __call__(self, x, boxes): - class_logits, box_regression = x - class_prob = Tensor.softmax(class_logits, -1) - image_shapes = [box.size for box in boxes] - boxes_per_image = [len(box) for box in boxes] - concat_boxes = Tensor.cat(*[a.bbox for a in boxes], dim=0) - - if self.cls_agnostic_bbox_reg: - box_regression = box_regression[:, -4:] - proposals = self.box_coder.decode( - box_regression.reshape(sum(boxes_per_image), -1), concat_boxes - ) - if self.cls_agnostic_bbox_reg: - proposals = proposals.repeat([1, class_prob.shape[1]]) - num_classes = class_prob.shape[1] - proposals = proposals.unsqueeze(0) - class_prob = class_prob.unsqueeze(0) - results = [] - for prob, boxes_per_img, image_shape in zip( - class_prob, proposals, image_shapes - ): - boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape) - boxlist = boxlist.clip_to_image(remove_empty=False) - boxlist = self.filter_results(boxlist, num_classes) - results.append(boxlist) - return results - - def prepare_boxlist(self, boxes, scores, image_shape): - boxes = boxes.reshape(-1, 4) - scores = scores.reshape(-1) - boxlist = BoxList(boxes, image_shape, mode="xyxy") - boxlist.add_field("scores", scores) - return boxlist - - def filter_results(self, boxlist, num_classes): - boxes = boxlist.bbox.reshape(-1, num_classes * 4) - scores = boxlist.get_field("scores").reshape(-1, num_classes) - - device = scores.device - result = [] - scores = scores.numpy() - boxes = boxes.numpy() - inds_all = scores > self.score_thresh - for j in range(1, num_classes): - inds = inds_all[:, j].nonzero()[0] - # This needs to be done in numpy because it can create empty arrays - scores_j = scores[inds, j] - boxes_j = boxes[inds, j * 4: (j + 1) * 4] - boxes_j = Tensor(boxes_j) - scores_j = Tensor(scores_j) - boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") - boxlist_for_class.add_field("scores", scores_j) - if len(boxlist_for_class): - boxlist_for_class = boxlist_nms( - boxlist_for_class, self.nms - ) - num_labels = len(boxlist_for_class) - boxlist_for_class.add_field( - "labels", Tensor.full((num_labels,), j, device=device) - ) - result.append(boxlist_for_class) - - result = cat_boxlist(result) - number_of_detections = len(result) - - if number_of_detections > self.detections_per_img > 0: - cls_scores = result.get_field("scores") - image_thresh, _ = topk(cls_scores, k=self.detections_per_img) - image_thresh = image_thresh.numpy()[-1] - keep = (cls_scores.numpy() >= image_thresh).nonzero()[0] - result = result[keep] - return result - - -class RoIBoxHead: - def __init__(self, in_channels): - self.feature_extractor = FPN2MLPFeatureExtractor(in_channels) - self.predictor = FPNPredictor() - self.post_processor = PostProcessor( + # Not used in training + def __init__( + self, score_thresh=0.05, nms=0.5, detections_per_img=100, - box_coder=BoxCoder(weights=(10., 10., 5., 5.)), - cls_agnostic_bbox_reg=False - ) + box_coder=None, + cls_agnostic_bbox_reg=False, + ): + self.score_thresh = score_thresh + self.nms = nms + self.detections_per_img = detections_per_img + if box_coder is None: + box_coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0)) + self.box_coder = box_coder + self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg - def __call__(self, features, proposals, targets=None): - x = self.feature_extractor(features, proposals) - class_logits, box_regression = self.predictor(x) - if not Tensor.training: - result = self.post_processor((class_logits, box_regression), proposals) - return x, result, {} + def __call__(self, x, boxes): + class_logits, box_regression = x + class_prob = Tensor.softmax(class_logits, -1) + image_shapes = [box.size for box in boxes] + boxes_per_image = [len(box) for box in boxes] + concat_boxes = Tensor.cat(*[a.bbox for a in boxes], dim=0) + + if self.cls_agnostic_bbox_reg: + box_regression = box_regression[:, -4:] + proposals = self.box_coder.decode( + box_regression.reshape(sum(boxes_per_image), -1), concat_boxes + ) + if self.cls_agnostic_bbox_reg: + proposals = proposals.repeat([1, class_prob.shape[1]]) + num_classes = class_prob.shape[1] + proposals = proposals.unsqueeze(0) + class_prob = class_prob.unsqueeze(0) + results = [] + for prob, boxes_per_img, image_shape in zip( + class_prob, proposals, image_shapes + ): + boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape) + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = self.filter_results(boxlist, num_classes) + results.append(boxlist) + return results + + def prepare_boxlist(self, boxes, scores, image_shape): + boxes = boxes.reshape(-1, 4) + scores = scores.reshape(-1) + boxlist = BoxList(boxes, image_shape, mode="xyxy") + boxlist.add_field("scores", scores) + return boxlist + + def filter_results(self, boxlist, num_classes): + boxes = boxlist.bbox.reshape(-1, num_classes * 4) + scores = boxlist.get_field("scores").reshape(-1, num_classes) + + device = scores.device + result = [] + scores = scores.numpy() + boxes = boxes.numpy() + inds_all = scores > self.score_thresh + for j in range(1, num_classes): + inds = inds_all[:, j].nonzero()[0] + # This needs to be done in numpy because it can create empty arrays + scores_j = scores[inds, j] + boxes_j = boxes[inds, j * 4 : (j + 1) * 4] + boxes_j = Tensor(boxes_j) + scores_j = Tensor(scores_j) + boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") + boxlist_for_class.add_field("scores", scores_j) + if len(boxlist_for_class): + boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms) + num_labels = len(boxlist_for_class) + boxlist_for_class.add_field( + "labels", Tensor.full((num_labels,), j, device=device) + ) + result.append(boxlist_for_class) + + result = cat_boxlist(result) + number_of_detections = len(result) + + if number_of_detections > self.detections_per_img > 0: + cls_scores = result.get_field("scores") + image_thresh, _ = topk(cls_scores, k=self.detections_per_img) + image_thresh = image_thresh.numpy()[-1] + keep = (cls_scores.numpy() >= image_thresh).nonzero()[0] + result = result[keep] + return result + + +class RoIBoxHead: + def __init__(self, in_channels): + self.feature_extractor = FPN2MLPFeatureExtractor(in_channels) + self.predictor = FPNPredictor() + self.post_processor = PostProcessor( + score_thresh=0.05, + nms=0.5, + detections_per_img=100, + box_coder=BoxCoder(weights=(10.0, 10.0, 5.0, 5.0)), + cls_agnostic_bbox_reg=False, + ) + + def __call__(self, features, proposals, targets=None): + x = self.feature_extractor(features, proposals) + class_logits, box_regression = self.predictor(x) + if not Tensor.training: + result = self.post_processor((class_logits, box_regression), proposals) + return x, result, {} class MaskPostProcessor: - # Not used in loss calculation - def __call__(self, x, boxes): - mask_prob = x.sigmoid().numpy() - num_masks = x.shape[0] - labels = [bbox.get_field("labels") for bbox in boxes] - labels = Tensor.cat(*labels).numpy().astype(np.int32) - index = np.arange(num_masks) - mask_prob = mask_prob[index, labels][:, None] - boxes_per_image, cumsum = [], 0 - for box in boxes: - cumsum += len(box) - boxes_per_image.append(cumsum) - # using numpy here as Tensor.chunk doesnt have custom chunk sizes - mask_prob = np.split(mask_prob, boxes_per_image, axis=0) - results = [] - for prob, box in zip(mask_prob, boxes): - bbox = BoxList(box.bbox, box.size, mode="xyxy") - for field in box.fields(): - bbox.add_field(field, box.get_field(field)) - prob = Tensor(prob) - bbox.add_field("mask", prob) - results.append(bbox) + # Not used in loss calculation + def __call__(self, x, boxes): + mask_prob = x.sigmoid().numpy() + num_masks = x.shape[0] + labels = [bbox.get_field("labels") for bbox in boxes] + labels = Tensor.cat(*labels).numpy().astype(np.int32) + index = np.arange(num_masks) + mask_prob = mask_prob[index, labels][:, None] + boxes_per_image, cumsum = [], 0 + for box in boxes: + cumsum += len(box) + boxes_per_image.append(cumsum) + # using numpy here as Tensor.chunk doesnt have custom chunk sizes + mask_prob = np.split(mask_prob, boxes_per_image, axis=0) + results = [] + for prob, box in zip(mask_prob, boxes): + bbox = BoxList(box.bbox, box.size, mode="xyxy") + for field in box.fields(): + bbox.add_field(field, box.get_field(field)) + prob = Tensor(prob) + bbox.add_field("mask", prob) + results.append(bbox) - return results + return results class Mask: - def __init__(self): - self.feature_extractor = MaskRCNNFPNFeatureExtractor() - self.predictor = MaskRCNNC4Predictor() - self.post_processor = MaskPostProcessor() + def __init__(self): + self.feature_extractor = MaskRCNNFPNFeatureExtractor() + self.predictor = MaskRCNNC4Predictor() + self.post_processor = MaskPostProcessor() - def __call__(self, features, proposals, targets=None): - x = self.feature_extractor(features, proposals) - if x: - mask_logits = self.predictor(x) - if not Tensor.training: - result = self.post_processor(mask_logits, proposals) - return x, result, {} - return x, [], {} + def __call__(self, features, proposals, targets=None): + x = self.feature_extractor(features, proposals) + if x: + mask_logits = self.predictor(x) + if not Tensor.training: + result = self.post_processor(mask_logits, proposals) + return x, result, {} + return x, [], {} class RoIHeads: - def __init__(self, in_channels): - self.box = RoIBoxHead(in_channels) - self.mask = Mask() + def __init__(self, in_channels): + self.box = RoIBoxHead(in_channels) + self.mask = Mask() - def __call__(self, features, proposals, targets=None): - x, detections, _ = self.box(features, proposals, targets) - x, detections, _ = self.mask(features, detections, targets) - return x, detections, {} + def __call__(self, features, proposals, targets=None): + x, detections, _ = self.box(features, proposals, targets) + x, detections, _ = self.mask(features, detections, targets) + return x, detections, {} class ImageList(object): - def __init__(self, tensors, image_sizes): - self.tensors = tensors - self.image_sizes = image_sizes + def __init__(self, tensors, image_sizes): + self.tensors = tensors + self.image_sizes = image_sizes - def to(self, *args, **kwargs): - cast_tensor = self.tensors.to(*args, **kwargs) - return ImageList(cast_tensor, self.image_sizes) + def to(self, *args, **kwargs): + cast_tensor = self.tensors.to(*args, **kwargs) + return ImageList(cast_tensor, self.image_sizes) def to_image_list(tensors, size_divisible=32): - # Preprocessing - if isinstance(tensors, Tensor) and size_divisible > 0: - tensors = [tensors] + # Preprocessing + if isinstance(tensors, Tensor) and size_divisible > 0: + tensors = [tensors] - if isinstance(tensors, ImageList): - return tensors - elif isinstance(tensors, Tensor): - # single tensor shape can be inferred - assert tensors.ndim == 4 - image_sizes = [tensor.shape[-2:] for tensor in tensors] - return ImageList(tensors, image_sizes) - elif isinstance(tensors, (tuple, list)): - max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) - if size_divisible > 0: + if isinstance(tensors, ImageList): + return tensors + elif isinstance(tensors, Tensor): + # single tensor shape can be inferred + assert tensors.ndim == 4 + image_sizes = [tensor.shape[-2:] for tensor in tensors] + return ImageList(tensors, image_sizes) + elif isinstance(tensors, (tuple, list)): + max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) + if size_divisible > 0: + stride = size_divisible + max_size = list(max_size) + max_size[1] = int(math.ceil(max_size[1] / stride) * stride) + max_size[2] = int(math.ceil(max_size[2] / stride) * stride) + max_size = tuple(max_size) - stride = size_divisible - max_size = list(max_size) - max_size[1] = int(math.ceil(max_size[1] / stride) * stride) - max_size[2] = int(math.ceil(max_size[2] / stride) * stride) - max_size = tuple(max_size) + batch_shape = (len(tensors),) + max_size + batched_imgs = np.zeros(batch_shape, dtype=tensors[0].dtype.np) + for img, pad_img in zip(tensors, batched_imgs): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]] += img.numpy() - batch_shape = (len(tensors),) + max_size - batched_imgs = np.zeros(batch_shape, dtype=tensors[0].dtype.np) - for img, pad_img in zip(tensors, batched_imgs): - pad_img[: img.shape[0], : img.shape[1], : img.shape[2]] += img.numpy() + batched_imgs = Tensor(batched_imgs) + image_sizes = [im.shape[-2:] for im in tensors] - batched_imgs = Tensor(batched_imgs) - image_sizes = [im.shape[-2:] for im in tensors] - - return ImageList(batched_imgs, image_sizes) - else: - raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors))) + return ImageList(batched_imgs, image_sizes) + else: + raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors))) class MaskRCNN: - def __init__(self, backbone: ResNet): - self.backbone = ResNetFPN(backbone, out_channels=256) - self.rpn = RPN(self.backbone.out_channels) - self.roi_heads = RoIHeads(self.backbone.out_channels) + def __init__(self, backbone: ResNet): + self.backbone = ResNetFPN(backbone, out_channels=256) + self.rpn = RPN(self.backbone.out_channels) + self.roi_heads = RoIHeads(self.backbone.out_channels) - def load_from_pretrained(self): - fn = Path('./') / "weights/maskrcnn.pt" - fetch("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn) + def load_from_pretrained(self): + fn = Path("./") / "weights/maskrcnn.pt" + fetch( + "https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", + fn, + ) - state_dict = torch_load(fn)['model'] - loaded_keys = [] - for k, v in state_dict.items(): - if "module." in k: - k = k.replace("module.", "") - if "stem." in k: - k = k.replace("stem.", "") - if "fpn_inner" in k: - block_index = int(re.search(r"fpn_inner(\d+)", k).group(1)) - k = re.sub(r"fpn_inner\d+", f"inner_blocks.{block_index - 1}", k) - if "fpn_layer" in k: - block_index = int(re.search(r"fpn_layer(\d+)", k).group(1)) - k = re.sub(r"fpn_layer\d+", f"layer_blocks.{block_index - 1}", k) - loaded_keys.append(k) - get_child(self, k).assign(v.numpy()).realize() - return loaded_keys + state_dict = torch_load(fn)["model"] + loaded_keys = [] + for k, v in state_dict.items(): + if "module." in k: + k = k.replace("module.", "") + if "stem." in k: + k = k.replace("stem.", "") + if "fpn_inner" in k: + block_index = int(re.search(r"fpn_inner(\d+)", k).group(1)) + k = re.sub(r"fpn_inner\d+", f"inner_blocks.{block_index - 1}", k) + if "fpn_layer" in k: + block_index = int(re.search(r"fpn_layer(\d+)", k).group(1)) + k = re.sub(r"fpn_layer\d+", f"layer_blocks.{block_index - 1}", k) + loaded_keys.append(k) + get_child(self, k).assign(v.numpy()).realize() + return loaded_keys - def __call__(self, images): - images = to_image_list(images) - features = self.backbone(images.tensors) - proposals, _ = self.rpn(images, features) - x, result, _ = self.roi_heads(features, proposals) - return result + def __call__(self, images): + images = to_image_list(images) + features = self.backbone(images.tensors) + proposals, _ = self.rpn(images, features) + x, result, _ = self.roi_heads(features, proposals) + return result -if __name__ == '__main__': - resnet = resnet = ResNet(50, num_classes=None, stride_in_1x1=True) - model = MaskRCNN(backbone=resnet) - model.load_from_pretrained() +if __name__ == "__main__": + resnet = resnet = ResNet(50, num_classes=None, stride_in_1x1=True) + model = MaskRCNN(backbone=resnet) + model.load_from_pretrained() diff --git a/extra/models/resnet.py b/extra/models/resnet.py index 517f1ec9e..2269ad03f 100644 --- a/extra/models/resnet.py +++ b/extra/models/resnet.py @@ -3,150 +3,229 @@ from tinygrad.tensor import Tensor from tinygrad.nn.state import torch_load from tinygrad.helpers import fetch, get_child + class BasicBlock: - expansion = 1 + expansion = 1 - def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64): - assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64" - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.downsample = [] - if stride != 1 or in_planes != self.expansion*planes: - self.downsample = [ - nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) - ] + def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64): + assert ( + groups == 1 and base_width == 64 + ), "BasicBlock only supports groups=1 and base_width=64" + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, padding=1, stride=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = [] + if stride != 1 or in_planes != self.expansion * planes: + self.downsample = [ + nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(self.expansion * planes), + ] - def __call__(self, x): - out = self.bn1(self.conv1(x)).relu() - out = self.bn2(self.conv2(out)) - out = out + x.sequential(self.downsample) - out = out.relu() - return out + def __call__(self, x): + out = self.bn1(self.conv1(x)).relu() + out = self.bn2(self.conv2(out)) + out = out + x.sequential(self.downsample) + out = out.relu() + return out class Bottleneck: - # NOTE: stride_in_1x1=False, this is the v1.5 variant - expansion = 4 + # NOTE: stride_in_1x1=False, this is the v1.5 variant + expansion = 4 - def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64): - width = int(planes * (base_width / 64.0)) * groups - # NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1 - self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False) - self.bn1 = nn.BatchNorm2d(width) - self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False) - self.bn2 = nn.BatchNorm2d(width) - self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(self.expansion*planes) - self.downsample = [] - if stride != 1 or in_planes != self.expansion*planes: - self.downsample = [ - nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion*planes) - ] + def __init__( + self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64 + ): + width = int(planes * (base_width / 64.0)) * groups + # NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1 + self.conv1 = nn.Conv2d( + in_planes, + width, + kernel_size=1, + stride=stride if stride_in_1x1 else 1, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d( + width, + width, + kernel_size=3, + padding=1, + stride=1 if stride_in_1x1 else stride, + groups=groups, + bias=False, + ) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d( + width, self.expansion * planes, kernel_size=1, bias=False + ) + self.bn3 = nn.BatchNorm2d(self.expansion * planes) + self.downsample = [] + if stride != 1 or in_planes != self.expansion * planes: + self.downsample = [ + nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(self.expansion * planes), + ] + + def __call__(self, x): + out = self.bn1(self.conv1(x)).relu() + out = self.bn2(self.conv2(out)).relu() + out = self.bn3(self.conv3(out)) + out = out + x.sequential(self.downsample) + out = out.relu() + return out - def __call__(self, x): - out = self.bn1(self.conv1(x)).relu() - out = self.bn2(self.conv2(out)).relu() - out = self.bn3(self.conv3(out)) - out = out + x.sequential(self.downsample) - out = out.relu() - return out class ResNet: - def __init__(self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False): - self.num = num - self.block = { - 18: BasicBlock, - 34: BasicBlock, - 50: Bottleneck, - 101: Bottleneck, - 152: Bottleneck - }[num] + def __init__( + self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False + ): + self.num = num + self.block = { + 18: BasicBlock, + 34: BasicBlock, + 50: Bottleneck, + 101: Bottleneck, + 152: Bottleneck, + }[num] - self.num_blocks = { - 18: [2,2,2,2], - 34: [3,4,6,3], - 50: [3,4,6,3], - 101: [3,4,23,3], - 152: [3,8,36,3] - }[num] + self.num_blocks = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[num] - self.in_planes = 64 + self.in_planes = 64 - self.groups = groups - self.base_width = width_per_group - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3) - self.bn1 = nn.BatchNorm2d(64) - self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1) - self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1) - self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1) - self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1) - self.fc = nn.Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer( + self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1 + ) + self.layer2 = self._make_layer( + self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1 + ) + self.layer3 = self._make_layer( + self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1 + ) + self.layer4 = self._make_layer( + self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1 + ) + self.fc = ( + nn.Linear(512 * self.block.expansion, num_classes) + if num_classes is not None + else None + ) - def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1): - strides = [stride] + [1] * (num_blocks-1) - layers = [] - for stride in strides: - if block == Bottleneck: - layers.append(block(self.in_planes, planes, stride, stride_in_1x1, self.groups, self.base_width)) - else: - layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width)) - self.in_planes = planes * block.expansion - return layers + def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + if block == Bottleneck: + layers.append( + block( + self.in_planes, + planes, + stride, + stride_in_1x1, + self.groups, + self.base_width, + ) + ) + else: + layers.append( + block(self.in_planes, planes, stride, self.groups, self.base_width) + ) + self.in_planes = planes * block.expansion + return layers - def forward(self, x): - is_feature_only = self.fc is None - if is_feature_only: features = [] - out = self.bn1(self.conv1(x)).relu() - out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2) - out = out.sequential(self.layer1) - if is_feature_only: features.append(out) - out = out.sequential(self.layer2) - if is_feature_only: features.append(out) - out = out.sequential(self.layer3) - if is_feature_only: features.append(out) - out = out.sequential(self.layer4) - if is_feature_only: features.append(out) - if not is_feature_only: - out = out.mean([2,3]) - out = self.fc(out).log_softmax() - return out - return features + def forward(self, x): + is_feature_only = self.fc is None + if is_feature_only: + features = [] + out = self.bn1(self.conv1(x)).relu() + out = out.pad2d([1, 1, 1, 1]).max_pool2d((3, 3), 2) + out = out.sequential(self.layer1) + if is_feature_only: + features.append(out) + out = out.sequential(self.layer2) + if is_feature_only: + features.append(out) + out = out.sequential(self.layer3) + if is_feature_only: + features.append(out) + out = out.sequential(self.layer4) + if is_feature_only: + features.append(out) + if not is_feature_only: + out = out.mean([2, 3]) + out = self.fc(out).log_softmax() + return out + return features - def __call__(self, x:Tensor) -> Tensor: - return self.forward(x) + def __call__(self, x: Tensor) -> Tensor: + return self.forward(x) - def load_from_pretrained(self): - # TODO replace with fake torch load + def load_from_pretrained(self): + # TODO replace with fake torch load - model_urls = { - (18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth', - (34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', - (50, 1, 64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth', - (50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', - (101, 1, 64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', - (152, 1, 64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', - } + model_urls = { + (18, 1, 64): "https://download.pytorch.org/models/resnet18-5c106cde.pth", + (34, 1, 64): "https://download.pytorch.org/models/resnet34-333f7ec4.pth", + (50, 1, 64): "https://download.pytorch.org/models/resnet50-19c8e357.pth", + ( + 50, + 32, + 4, + ): "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + (101, 1, 64): "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", + (152, 1, 64): "https://download.pytorch.org/models/resnet152-b121ed2d.pth", + } - self.url = model_urls[(self.num, self.groups, self.base_width)] - for k, v in torch_load(fetch(self.url)).items(): - obj: Tensor = get_child(self, k) - dat = v.detach().numpy() + self.url = model_urls[(self.num, self.groups, self.base_width)] + for k, v in torch_load(fetch(self.url)).items(): + obj: Tensor = get_child(self, k) + dat = v.detach().numpy() - if 'fc.' in k and obj.shape != dat.shape: - print("skipping fully connected layer") - continue # Skip FC if transfer learning + if "fc." in k and obj.shape != dat.shape: + print("skipping fully connected layer") + continue # Skip FC if transfer learning + + # TODO: remove or when #777 is merged + assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), ( + k, + obj.shape, + dat.shape, + ) + obj.assign(dat) - # TODO: remove or when #777 is merged - assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (k, obj.shape, dat.shape) - obj.assign(dat) ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes) ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes) ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes) ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes) ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes) -ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4) \ No newline at end of file +ResNeXt50_32X4D = lambda num_classes=1000: ResNet( + 50, num_classes=num_classes, groups=32, width_per_group=4 +) diff --git a/extra/models/retinanet.py b/extra/models/retinanet.py index 91d758ff5..2383953e3 100644 --- a/extra/models/retinanet.py +++ b/extra/models/retinanet.py @@ -4,233 +4,379 @@ import tinygrad.nn as nn from extra.models.resnet import ResNet import numpy as np + def nms(boxes, scores, thresh=0.5): - x1, y1, x2, y2 = np.rollaxis(boxes, 1) - areas = (x2 - x1 + 1) * (y2 - y1 + 1) - to_process, keep = scores.argsort()[::-1], [] - while to_process.size > 0: - cur, to_process = to_process[0], to_process[1:] - keep.append(cur) - inter_x1 = np.maximum(x1[cur], x1[to_process]) - inter_y1 = np.maximum(y1[cur], y1[to_process]) - inter_x2 = np.minimum(x2[cur], x2[to_process]) - inter_y2 = np.minimum(y2[cur], y2[to_process]) - inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(0, inter_y2 - inter_y1 + 1) - iou = inter_area / (areas[cur] + areas[to_process] - inter_area) - to_process = to_process[np.where(iou <= thresh)[0]] - return keep + x1, y1, x2, y2 = np.rollaxis(boxes, 1) + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + to_process, keep = scores.argsort()[::-1], [] + while to_process.size > 0: + cur, to_process = to_process[0], to_process[1:] + keep.append(cur) + inter_x1 = np.maximum(x1[cur], x1[to_process]) + inter_y1 = np.maximum(y1[cur], y1[to_process]) + inter_x2 = np.minimum(x2[cur], x2[to_process]) + inter_y2 = np.minimum(y2[cur], y2[to_process]) + inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum( + 0, inter_y2 - inter_y1 + 1 + ) + iou = inter_area / (areas[cur] + areas[to_process] - inter_area) + to_process = to_process[np.where(iou <= thresh)[0]] + return keep + def decode_bbox(offsets, anchors): - dx, dy, dw, dh = np.rollaxis(offsets, 1) - widths, heights = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1] - cx, cy = anchors[:, 0] + 0.5 * widths, anchors[:, 1] + 0.5 * heights - pred_cx, pred_cy = dx * widths + cx, dy * heights + cy - pred_w, pred_h = np.exp(dw) * widths, np.exp(dh) * heights - pred_x1, pred_y1 = pred_cx - 0.5 * pred_w, pred_cy - 0.5 * pred_h - pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h - return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32) + dx, dy, dw, dh = np.rollaxis(offsets, 1) + widths, heights = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1] + cx, cy = anchors[:, 0] + 0.5 * widths, anchors[:, 1] + 0.5 * heights + pred_cx, pred_cy = dx * widths + cx, dy * heights + cy + pred_w, pred_h = np.exp(dw) * widths, np.exp(dh) * heights + pred_x1, pred_y1 = pred_cx - 0.5 * pred_w, pred_cy - 0.5 * pred_h + pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h + return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32) + def generate_anchors(input_size, grid_sizes, scales, aspect_ratios): - assert len(scales) == len(aspect_ratios) == len(grid_sizes) - anchors = [] - for s, ar, gs in zip(scales, aspect_ratios, grid_sizes): - s, ar = np.array(s), np.array(ar) - h_ratios = np.sqrt(ar) - w_ratios = 1 / h_ratios - ws = (w_ratios[:, None] * s[None, :]).reshape(-1) - hs = (h_ratios[:, None] * s[None, :]).reshape(-1) - base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round() - stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1] - shifts_x, shifts_y = np.meshgrid(np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h) - shifts_x = shifts_x.reshape(-1) - shifts_y = shifts_y.reshape(-1) - shifts = np.stack([shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32) - anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4)) - return anchors + assert len(scales) == len(aspect_ratios) == len(grid_sizes) + anchors = [] + for s, ar, gs in zip(scales, aspect_ratios, grid_sizes): + s, ar = np.array(s), np.array(ar) + h_ratios = np.sqrt(ar) + w_ratios = 1 / h_ratios + ws = (w_ratios[:, None] * s[None, :]).reshape(-1) + hs = (h_ratios[:, None] * s[None, :]).reshape(-1) + base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round() + stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1] + shifts_x, shifts_y = np.meshgrid( + np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h + ) + shifts_x = shifts_x.reshape(-1) + shifts_y = shifts_y.reshape(-1) + shifts = np.stack( + [shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32 + ) + anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4)) + return anchors + class RetinaNet: - def __init__(self, backbone: ResNet, num_classes=264, num_anchors=9, scales=None, aspect_ratios=None): - assert isinstance(backbone, ResNet) - scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales - aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios - self.num_anchors, self.num_classes = num_anchors, num_classes - assert len(scales) == len(aspect_ratios) and all(self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios)) + def __init__( + self, + backbone: ResNet, + num_classes=264, + num_anchors=9, + scales=None, + aspect_ratios=None, + ): + assert isinstance(backbone, ResNet) + scales = ( + tuple( + (i, int(i * 2 ** (1 / 3)), int(i * 2 ** (2 / 3))) + for i in 2 ** np.arange(5, 10) + ) + if scales is None + else scales + ) + aspect_ratios = ( + ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios + ) + self.num_anchors, self.num_classes = num_anchors, num_classes + assert len(scales) == len(aspect_ratios) and all( + self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios) + ) - self.backbone = ResNetFPN(backbone) - self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes) - self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios) + self.backbone = ResNetFPN(backbone) + self.head = RetinaHead( + self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes + ) + self.anchor_gen = lambda input_size: generate_anchors( + input_size, + self.backbone.compute_grid_sizes(input_size), + scales, + aspect_ratios, + ) - def __call__(self, x): - return self.forward(x) - def forward(self, x): - return self.head(self.backbone(x)) + def __call__(self, x): + return self.forward(x) - def load_from_pretrained(self): - model_urls = { - (50, 1, 64): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", - (50, 32, 4): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip", - } - self.url = model_urls[(self.backbone.body.num, self.backbone.body.groups, self.backbone.body.base_width)] - from torch.hub import load_state_dict_from_url - state_dict = load_state_dict_from_url(self.url, progress=True, map_location='cpu') - state_dict = state_dict['model'] if 'model' in state_dict.keys() else state_dict - for k, v in state_dict.items(): - obj = get_child(self, k) - dat = v.detach().numpy() - assert obj.shape == dat.shape, (k, obj.shape, dat.shape) - obj.assign(dat) + def forward(self, x): + return self.head(self.backbone(x)) - # predictions: (BS, (H1W1+...+HmWm)A, 4 + K) - def postprocess_detections(self, predictions, input_size=(800, 800), image_sizes=None, orig_image_sizes=None, score_thresh=0.05, topk_candidates=1000, nms_thresh=0.5): - anchors = self.anchor_gen(input_size) - grid_sizes = self.backbone.compute_grid_sizes(input_size) - split_idx = np.cumsum([int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]]) - detections = [] - for i, predictions_per_image in enumerate(predictions): - h, w = input_size if image_sizes is None else image_sizes[i] + def load_from_pretrained(self): + model_urls = { + ( + 50, + 1, + 64, + ): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", + ( + 50, + 32, + 4, + ): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip", + } + self.url = model_urls[ + ( + self.backbone.body.num, + self.backbone.body.groups, + self.backbone.body.base_width, + ) + ] + from torch.hub import load_state_dict_from_url - predictions_per_image = np.split(predictions_per_image, split_idx) - offsets_per_image = [br[:, :4] for br in predictions_per_image] - scores_per_image = [cl[:, 4:] for cl in predictions_per_image] + state_dict = load_state_dict_from_url( + self.url, progress=True, map_location="cpu" + ) + state_dict = state_dict["model"] if "model" in state_dict.keys() else state_dict + for k, v in state_dict.items(): + obj = get_child(self, k) + dat = v.detach().numpy() + assert obj.shape == dat.shape, (k, obj.shape, dat.shape) + obj.assign(dat) - image_boxes, image_scores, image_labels = [], [], [] - for offsets_per_level, scores_per_level, anchors_per_level in zip(offsets_per_image, scores_per_image, anchors): - # remove low scoring boxes - scores_per_level = scores_per_level.flatten() - keep_idxs = scores_per_level > score_thresh - scores_per_level = scores_per_level[keep_idxs] + # predictions: (BS, (H1W1+...+HmWm)A, 4 + K) + def postprocess_detections( + self, + predictions, + input_size=(800, 800), + image_sizes=None, + orig_image_sizes=None, + score_thresh=0.05, + topk_candidates=1000, + nms_thresh=0.5, + ): + anchors = self.anchor_gen(input_size) + grid_sizes = self.backbone.compute_grid_sizes(input_size) + split_idx = np.cumsum( + [int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]] + ) + detections = [] + for i, predictions_per_image in enumerate(predictions): + h, w = input_size if image_sizes is None else image_sizes[i] - # keep topk - topk_idxs = np.where(keep_idxs)[0] - num_topk = min(len(topk_idxs), topk_candidates) - sort_idxs = scores_per_level.argsort()[-num_topk:][::-1] - topk_idxs, scores_per_level = topk_idxs[sort_idxs], scores_per_level[sort_idxs] + predictions_per_image = np.split(predictions_per_image, split_idx) + offsets_per_image = [br[:, :4] for br in predictions_per_image] + scores_per_image = [cl[:, 4:] for cl in predictions_per_image] - # bbox coords from offsets - anchor_idxs = topk_idxs // self.num_classes - labels_per_level = topk_idxs % self.num_classes - boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs]) - # clip to image size - clipped_x = boxes_per_level[:, 0::2].clip(0, w) - clipped_y = boxes_per_level[:, 1::2].clip(0, h) - boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(-1, 4) + image_boxes, image_scores, image_labels = [], [], [] + for offsets_per_level, scores_per_level, anchors_per_level in zip( + offsets_per_image, scores_per_image, anchors + ): + # remove low scoring boxes + scores_per_level = scores_per_level.flatten() + keep_idxs = scores_per_level > score_thresh + scores_per_level = scores_per_level[keep_idxs] - image_boxes.append(boxes_per_level) - image_scores.append(scores_per_level) - image_labels.append(labels_per_level) + # keep topk + topk_idxs = np.where(keep_idxs)[0] + num_topk = min(len(topk_idxs), topk_candidates) + sort_idxs = scores_per_level.argsort()[-num_topk:][::-1] + topk_idxs, scores_per_level = ( + topk_idxs[sort_idxs], + scores_per_level[sort_idxs], + ) - image_boxes = np.concatenate(image_boxes) - image_scores = np.concatenate(image_scores) - image_labels = np.concatenate(image_labels) + # bbox coords from offsets + anchor_idxs = topk_idxs // self.num_classes + labels_per_level = topk_idxs % self.num_classes + boxes_per_level = decode_bbox( + offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs] + ) + # clip to image size + clipped_x = boxes_per_level[:, 0::2].clip(0, w) + clipped_y = boxes_per_level[:, 1::2].clip(0, h) + boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape( + -1, 4 + ) - # nms for each class - keep_mask = np.zeros_like(image_scores, dtype=bool) - for class_id in np.unique(image_labels): - curr_indices = np.where(image_labels == class_id)[0] - curr_keep_indices = nms(image_boxes[curr_indices], image_scores[curr_indices], nms_thresh) - keep_mask[curr_indices[curr_keep_indices]] = True - keep = np.where(keep_mask)[0] - keep = keep[image_scores[keep].argsort()[::-1]] + image_boxes.append(boxes_per_level) + image_scores.append(scores_per_level) + image_labels.append(labels_per_level) - # resize bboxes back to original size - image_boxes = image_boxes[keep] - if orig_image_sizes is not None: - resized_x = image_boxes[:, 0::2] * orig_image_sizes[i][1] / w - resized_y = image_boxes[:, 1::2] * orig_image_sizes[i][0] / h - image_boxes = np.stack([resized_x, resized_y], axis=2).reshape(-1, 4) - # xywh format - image_boxes = np.concatenate([image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1) + image_boxes = np.concatenate(image_boxes) + image_scores = np.concatenate(image_scores) + image_labels = np.concatenate(image_labels) + + # nms for each class + keep_mask = np.zeros_like(image_scores, dtype=bool) + for class_id in np.unique(image_labels): + curr_indices = np.where(image_labels == class_id)[0] + curr_keep_indices = nms( + image_boxes[curr_indices], image_scores[curr_indices], nms_thresh + ) + keep_mask[curr_indices[curr_keep_indices]] = True + keep = np.where(keep_mask)[0] + keep = keep[image_scores[keep].argsort()[::-1]] + + # resize bboxes back to original size + image_boxes = image_boxes[keep] + if orig_image_sizes is not None: + resized_x = image_boxes[:, 0::2] * orig_image_sizes[i][1] / w + resized_y = image_boxes[:, 1::2] * orig_image_sizes[i][0] / h + image_boxes = np.stack([resized_x, resized_y], axis=2).reshape(-1, 4) + # xywh format + image_boxes = np.concatenate( + [image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1 + ) + + detections.append( + { + "boxes": image_boxes, + "scores": image_scores[keep], + "labels": image_labels[keep], + } + ) + return detections - detections.append({"boxes":image_boxes, "scores":image_scores[keep], "labels":image_labels[keep]}) - return detections class ClassificationHead: - def __init__(self, in_channels, num_anchors, num_classes): - self.num_classes = num_classes - self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)]) - self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1) - def __call__(self, x): - out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x] - return out[0].cat(*out[1:], dim=1).sigmoid() + def __init__(self, in_channels, num_anchors, num_classes): + self.num_classes = num_classes + self.conv = flatten( + [ + ( + nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), + lambda x: x.relu(), + ) + for _ in range(4) + ] + ) + self.cls_logits = nn.Conv2d( + in_channels, num_anchors * num_classes, kernel_size=3, padding=1 + ) + + def __call__(self, x): + out = [ + self.cls_logits(feat.sequential(self.conv)) + .permute(0, 2, 3, 1) + .reshape(feat.shape[0], -1, self.num_classes) + for feat in x + ] + return out[0].cat(*out[1:], dim=1).sigmoid() + class RegressionHead: - def __init__(self, in_channels, num_anchors): - self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)]) - self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, padding=1) - def __call__(self, x): - out = [self.bbox_reg(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, 4) for feat in x] - return out[0].cat(*out[1:], dim=1) + def __init__(self, in_channels, num_anchors): + self.conv = flatten( + [ + ( + nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), + lambda x: x.relu(), + ) + for _ in range(4) + ] + ) + self.bbox_reg = nn.Conv2d( + in_channels, num_anchors * 4, kernel_size=3, padding=1 + ) + + def __call__(self, x): + out = [ + self.bbox_reg(feat.sequential(self.conv)) + .permute(0, 2, 3, 1) + .reshape(feat.shape[0], -1, 4) + for feat in x + ] + return out[0].cat(*out[1:], dim=1) + class RetinaHead: - def __init__(self, in_channels, num_anchors, num_classes): - self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes) - self.regression_head = RegressionHead(in_channels, num_anchors) - def __call__(self, x): - pred_bbox, pred_class = self.regression_head(x), self.classification_head(x) - out = pred_bbox.cat(pred_class, dim=-1) - return out + def __init__(self, in_channels, num_anchors, num_classes): + self.classification_head = ClassificationHead( + in_channels, num_anchors, num_classes + ) + self.regression_head = RegressionHead(in_channels, num_anchors) + + def __call__(self, x): + pred_bbox, pred_class = self.regression_head(x), self.classification_head(x) + out = pred_bbox.cat(pred_class, dim=-1) + return out + class ResNetFPN: - def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]): - self.out_channels = out_channels - self.body = resnet - in_channels_list = [(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers] - self.fpn = FPN(in_channels_list, out_channels) + def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]): + self.out_channels = out_channels + self.body = resnet + in_channels_list = [ + (self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers + ] + self.fpn = FPN(in_channels_list, out_channels) - # this is needed to decouple inference from postprocessing (anchors generation) - def compute_grid_sizes(self, input_size): - return np.ceil(np.array(input_size)[None, :] / 2 ** np.arange(3, 8)[:, None]) + # this is needed to decouple inference from postprocessing (anchors generation) + def compute_grid_sizes(self, input_size): + return np.ceil(np.array(input_size)[None, :] / 2 ** np.arange(3, 8)[:, None]) + + def __call__(self, x): + out = self.body.bn1(self.body.conv1(x)).relu() + out = out.pad2d([1, 1, 1, 1]).max_pool2d((3, 3), 2) + out = out.sequential(self.body.layer1) + p3 = out.sequential(self.body.layer2) + p4 = p3.sequential(self.body.layer3) + p5 = p4.sequential(self.body.layer4) + return self.fpn([p3, p4, p5]) - def __call__(self, x): - out = self.body.bn1(self.body.conv1(x)).relu() - out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2) - out = out.sequential(self.body.layer1) - p3 = out.sequential(self.body.layer2) - p4 = p3.sequential(self.body.layer3) - p5 = p4.sequential(self.body.layer4) - return self.fpn([p3, p4, p5]) class ExtraFPNBlock: - def __init__(self, in_channels, out_channels): - self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) - self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) - self.use_P5 = in_channels == out_channels + def __init__(self, in_channels, out_channels): + self.p6 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, padding=1 + ) + self.p7 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=2, padding=1 + ) + self.use_P5 = in_channels == out_channels + + def __call__(self, p, c): + p5, c5 = p[-1], c[-1] + x = p5 if self.use_P5 else c5 + p6 = self.p6(x) + p7 = self.p7(p6.relu()) + p.extend([p6, p7]) + return p - def __call__(self, p, c): - p5, c5 = p[-1], c[-1] - x = p5 if self.use_P5 else c5 - p6 = self.p6(x) - p7 = self.p7(p6.relu()) - p.extend([p6, p7]) - return p class FPN: - def __init__(self, in_channels_list, out_channels, extra_blocks=None): - self.inner_blocks, self.layer_blocks = [], [] - for in_channels in in_channels_list: - self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1)) - self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)) - self.extra_blocks = ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks + def __init__(self, in_channels_list, out_channels, extra_blocks=None): + self.inner_blocks, self.layer_blocks = [], [] + for in_channels in in_channels_list: + self.inner_blocks.append( + nn.Conv2d(in_channels, out_channels, kernel_size=1) + ) + self.layer_blocks.append( + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + ) + self.extra_blocks = ( + ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks + ) - def __call__(self, x): - last_inner = self.inner_blocks[-1](x[-1]) - results = [self.layer_blocks[-1](last_inner)] - for idx in range(len(x) - 2, -1, -1): - inner_lateral = self.inner_blocks[idx](x[idx]) + def __call__(self, x): + last_inner = self.inner_blocks[-1](x[-1]) + results = [self.layer_blocks[-1](last_inner)] + for idx in range(len(x) - 2, -1, -1): + inner_lateral = self.inner_blocks[idx](x[idx]) - # upsample to inner_lateral's shape - (ih, iw), (oh, ow), prefix = last_inner.shape[-2:], inner_lateral.shape[-2:], last_inner.shape[:-2] - eh, ew = math.ceil(oh / ih), math.ceil(ow / iw) - inner_top_down = last_inner.reshape(*prefix, ih, 1, iw, 1).expand(*prefix, ih, eh, iw, ew).reshape(*prefix, ih*eh, iw*ew)[:, :, :oh, :ow] + # upsample to inner_lateral's shape + (ih, iw), (oh, ow), prefix = ( + last_inner.shape[-2:], + inner_lateral.shape[-2:], + last_inner.shape[:-2], + ) + eh, ew = math.ceil(oh / ih), math.ceil(ow / iw) + inner_top_down = ( + last_inner.reshape(*prefix, ih, 1, iw, 1) + .expand(*prefix, ih, eh, iw, ew) + .reshape(*prefix, ih * eh, iw * ew)[:, :, :oh, :ow] + ) + + last_inner = inner_lateral + inner_top_down + results.insert(0, self.layer_blocks[idx](last_inner)) + if self.extra_blocks is not None: + results = self.extra_blocks(results, x) + return results - last_inner = inner_lateral + inner_top_down - results.insert(0, self.layer_blocks[idx](last_inner)) - if self.extra_blocks is not None: - results = self.extra_blocks(results, x) - return results if __name__ == "__main__": - from extra.models.resnet import ResNeXt50_32X4D - backbone = ResNeXt50_32X4D() - retina = RetinaNet(backbone) - retina.load_from_pretrained() + from extra.models.resnet import ResNeXt50_32X4D + + backbone = ResNeXt50_32X4D() + retina = RetinaNet(backbone) + retina.load_from_pretrained() diff --git a/extra/models/rnnt.py b/extra/models/rnnt.py index 589ac75c0..765c91b4d 100644 --- a/extra/models/rnnt.py +++ b/extra/models/rnnt.py @@ -7,196 +7,278 @@ from pathlib import Path class RNNT: - def __init__(self, input_features=240, vocab_size=29, enc_hidden_size=1024, pred_hidden_size=320, joint_hidden_size=512, pre_enc_layers=2, post_enc_layers=3, pred_layers=2, stack_time_factor=2, dropout=0.32): - self.encoder = Encoder(input_features, enc_hidden_size, pre_enc_layers, post_enc_layers, stack_time_factor, dropout) - self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout) - self.joint = Joint(vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout) + def __init__( + self, + input_features=240, + vocab_size=29, + enc_hidden_size=1024, + pred_hidden_size=320, + joint_hidden_size=512, + pre_enc_layers=2, + post_enc_layers=3, + pred_layers=2, + stack_time_factor=2, + dropout=0.32, + ): + self.encoder = Encoder( + input_features, + enc_hidden_size, + pre_enc_layers, + post_enc_layers, + stack_time_factor, + dropout, + ) + self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout) + self.joint = Joint( + vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout + ) - @TinyJit - def __call__(self, x, y, hc=None): - f, _ = self.encoder(x, None) - g, _ = self.prediction(y, hc, Tensor.ones(1, requires_grad=False)) - out = self.joint(f, g) - return out.realize() + @TinyJit + def __call__(self, x, y, hc=None): + f, _ = self.encoder(x, None) + g, _ = self.prediction(y, hc, Tensor.ones(1, requires_grad=False)) + out = self.joint(f, g) + return out.realize() - def decode(self, x, x_lens): - logits, logit_lens = self.encoder(x, x_lens) - outputs = [] - for b in range(logits.shape[0]): - inseq = logits[b, :, :].unsqueeze(1) - logit_len = logit_lens[b] - seq = self._greedy_decode(inseq, int(np.ceil(logit_len.numpy()).item())) - outputs.append(seq) - return outputs + def decode(self, x, x_lens): + logits, logit_lens = self.encoder(x, x_lens) + outputs = [] + for b in range(logits.shape[0]): + inseq = logits[b, :, :].unsqueeze(1) + logit_len = logit_lens[b] + seq = self._greedy_decode(inseq, int(np.ceil(logit_len.numpy()).item())) + outputs.append(seq) + return outputs - def _greedy_decode(self, logits, logit_len): - hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False) - labels = [] - label = Tensor.zeros(1, 1, requires_grad=False) - mask = Tensor.zeros(1, requires_grad=False) - for time_idx in range(logit_len): - logit = logits[time_idx, :, :].unsqueeze(0) - not_blank = True - added = 0 - while not_blank and added < 30: - if len(labels) > 0: - mask = (mask + 1).clip(0, 1) - label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]], requires_grad=False) + 1 - 1 - jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask) - k = jhc[0, 0, :29].argmax(axis=0).numpy() - not_blank = k != 28 - if not_blank: - labels.append(k) - hc = jhc[:, :, 29:] + 1 - 1 - added += 1 - return labels + def _greedy_decode(self, logits, logit_len): + hc = Tensor.zeros( + self.prediction.rnn.layers, + 2, + self.prediction.hidden_size, + requires_grad=False, + ) + labels = [] + label = Tensor.zeros(1, 1, requires_grad=False) + mask = Tensor.zeros(1, requires_grad=False) + for time_idx in range(logit_len): + logit = logits[time_idx, :, :].unsqueeze(0) + not_blank = True + added = 0 + while not_blank and added < 30: + if len(labels) > 0: + mask = (mask + 1).clip(0, 1) + label = ( + Tensor( + [[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]], + requires_grad=False, + ) + + 1 + - 1 + ) + jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask) + k = jhc[0, 0, :29].argmax(axis=0).numpy() + not_blank = k != 28 + if not_blank: + labels.append(k) + hc = jhc[:, :, 29:] + 1 - 1 + added += 1 + return labels - @TinyJit - def _pred_joint(self, logit, label, hc, mask): - g, hc = self.prediction(label, hc, mask) - j = self.joint(logit, g)[0] - j = j.pad(((0, 1), (0, 1), (0, 0))) - out = j.cat(hc, dim=2) - return out.realize() + @TinyJit + def _pred_joint(self, logit, label, hc, mask): + g, hc = self.prediction(label, hc, mask) + j = self.joint(logit, g)[0] + j = j.pad(((0, 1), (0, 1), (0, 0))) + out = j.cat(hc, dim=2) + return out.realize() - def load_from_pretrained(self): - fn = Path(__file__).parents[1] / "weights/rnnt.pt" - fetch("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn) + def load_from_pretrained(self): + fn = Path(__file__).parents[1] / "weights/rnnt.pt" + fetch( + "https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", + fn, + ) - import torch - with open(fn, "rb") as f: - state_dict = torch.load(f, map_location="cpu")["state_dict"] + import torch - # encoder - for i in range(2): - self.encoder.pre_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.pre_rnn.lstm.weight_ih_l{i}"].numpy()) - self.encoder.pre_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.pre_rnn.lstm.weight_hh_l{i}"].numpy()) - self.encoder.pre_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.pre_rnn.lstm.bias_ih_l{i}"].numpy()) - self.encoder.pre_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.pre_rnn.lstm.bias_hh_l{i}"].numpy()) - for i in range(3): - self.encoder.post_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.post_rnn.lstm.weight_ih_l{i}"].numpy()) - self.encoder.post_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.post_rnn.lstm.weight_hh_l{i}"].numpy()) - self.encoder.post_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.post_rnn.lstm.bias_ih_l{i}"].numpy()) - self.encoder.post_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.post_rnn.lstm.bias_hh_l{i}"].numpy()) + with open(fn, "rb") as f: + state_dict = torch.load(f, map_location="cpu")["state_dict"] - # prediction - self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy()) - for i in range(2): - self.prediction.rnn.cells[i].weights_ih.assign(state_dict[f"prediction.dec_rnn.lstm.weight_ih_l{i}"].numpy()) - self.prediction.rnn.cells[i].weights_hh.assign(state_dict[f"prediction.dec_rnn.lstm.weight_hh_l{i}"].numpy()) - self.prediction.rnn.cells[i].bias_ih.assign(state_dict[f"prediction.dec_rnn.lstm.bias_ih_l{i}"].numpy()) - self.prediction.rnn.cells[i].bias_hh.assign(state_dict[f"prediction.dec_rnn.lstm.bias_hh_l{i}"].numpy()) + # encoder + for i in range(2): + self.encoder.pre_rnn.cells[i].weights_ih.assign( + state_dict[f"encoder.pre_rnn.lstm.weight_ih_l{i}"].numpy() + ) + self.encoder.pre_rnn.cells[i].weights_hh.assign( + state_dict[f"encoder.pre_rnn.lstm.weight_hh_l{i}"].numpy() + ) + self.encoder.pre_rnn.cells[i].bias_ih.assign( + state_dict[f"encoder.pre_rnn.lstm.bias_ih_l{i}"].numpy() + ) + self.encoder.pre_rnn.cells[i].bias_hh.assign( + state_dict[f"encoder.pre_rnn.lstm.bias_hh_l{i}"].numpy() + ) + for i in range(3): + self.encoder.post_rnn.cells[i].weights_ih.assign( + state_dict[f"encoder.post_rnn.lstm.weight_ih_l{i}"].numpy() + ) + self.encoder.post_rnn.cells[i].weights_hh.assign( + state_dict[f"encoder.post_rnn.lstm.weight_hh_l{i}"].numpy() + ) + self.encoder.post_rnn.cells[i].bias_ih.assign( + state_dict[f"encoder.post_rnn.lstm.bias_ih_l{i}"].numpy() + ) + self.encoder.post_rnn.cells[i].bias_hh.assign( + state_dict[f"encoder.post_rnn.lstm.bias_hh_l{i}"].numpy() + ) - # joint - self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy()) - self.joint.l1.bias.assign(state_dict["joint_net.0.bias"].numpy()) - self.joint.l2.weight.assign(state_dict["joint_net.3.weight"].numpy()) - self.joint.l2.bias.assign(state_dict["joint_net.3.bias"].numpy()) + # prediction + self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy()) + for i in range(2): + self.prediction.rnn.cells[i].weights_ih.assign( + state_dict[f"prediction.dec_rnn.lstm.weight_ih_l{i}"].numpy() + ) + self.prediction.rnn.cells[i].weights_hh.assign( + state_dict[f"prediction.dec_rnn.lstm.weight_hh_l{i}"].numpy() + ) + self.prediction.rnn.cells[i].bias_ih.assign( + state_dict[f"prediction.dec_rnn.lstm.bias_ih_l{i}"].numpy() + ) + self.prediction.rnn.cells[i].bias_hh.assign( + state_dict[f"prediction.dec_rnn.lstm.bias_hh_l{i}"].numpy() + ) + + # joint + self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy()) + self.joint.l1.bias.assign(state_dict["joint_net.0.bias"].numpy()) + self.joint.l2.weight.assign(state_dict["joint_net.3.weight"].numpy()) + self.joint.l2.bias.assign(state_dict["joint_net.3.bias"].numpy()) class LSTMCell: - def __init__(self, input_size, hidden_size, dropout): - self.dropout = dropout + def __init__(self, input_size, hidden_size, dropout): + self.dropout = dropout - self.weights_ih = Tensor.uniform(hidden_size * 4, input_size) - self.bias_ih = Tensor.uniform(hidden_size * 4) - self.weights_hh = Tensor.uniform(hidden_size * 4, hidden_size) - self.bias_hh = Tensor.uniform(hidden_size * 4) + self.weights_ih = Tensor.uniform(hidden_size * 4, input_size) + self.bias_ih = Tensor.uniform(hidden_size * 4) + self.weights_hh = Tensor.uniform(hidden_size * 4, hidden_size) + self.bias_hh = Tensor.uniform(hidden_size * 4) - def __call__(self, x, hc): - gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[:x.shape[0]].linear(self.weights_hh.T, self.bias_hh) + def __call__(self, x, hc): + gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[: x.shape[0]].linear( + self.weights_hh.T, self.bias_hh + ) - i, f, g, o = gates.chunk(4, 1) - i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid() + i, f, g, o = gates.chunk(4, 1) + i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid() - c = (f * hc[x.shape[0]:]) + (i * g) - h = (o * c.tanh()).dropout(self.dropout) + c = (f * hc[x.shape[0] :]) + (i * g) + h = (o * c.tanh()).dropout(self.dropout) - return Tensor.cat(h, c).realize() + return Tensor.cat(h, c).realize() class LSTM: - def __init__(self, input_size, hidden_size, layers, dropout): - self.input_size = input_size - self.hidden_size = hidden_size - self.layers = layers + def __init__(self, input_size, hidden_size, layers, dropout): + self.input_size = input_size + self.hidden_size = hidden_size + self.layers = layers - self.cells = [LSTMCell(input_size, hidden_size, dropout) if i == 0 else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0) for i in range(layers)] + self.cells = [ + LSTMCell(input_size, hidden_size, dropout) + if i == 0 + else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0) + for i in range(layers) + ] - def __call__(self, x, hc): - @TinyJit - def _do_step(x_, hc_): - return self.do_step(x_, hc_) + def __call__(self, x, hc): + @TinyJit + def _do_step(x_, hc_): + return self.do_step(x_, hc_) - if hc is None: - hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False) + if hc is None: + hc = Tensor.zeros( + self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False + ) - output = None - for t in range(x.shape[0]): - hc = _do_step(x[t] + 1 - 1, hc) # TODO: why do we need to do this? - if output is None: - output = hc[-1:, :x.shape[1]] - else: - output = output.cat(hc[-1:, :x.shape[1]], dim=0).realize() + output = None + for t in range(x.shape[0]): + hc = _do_step(x[t] + 1 - 1, hc) # TODO: why do we need to do this? + if output is None: + output = hc[-1:, : x.shape[1]] + else: + output = output.cat(hc[-1:, : x.shape[1]], dim=0).realize() - return output, hc + return output, hc - def do_step(self, x, hc): - new_hc = [x] - for i, cell in enumerate(self.cells): - new_hc.append(cell(new_hc[i][:x.shape[0]], hc[i])) - return Tensor.stack(new_hc[1:]).realize() + def do_step(self, x, hc): + new_hc = [x] + for i, cell in enumerate(self.cells): + new_hc.append(cell(new_hc[i][: x.shape[0]], hc[i])) + return Tensor.stack(new_hc[1:]).realize() class StackTime: - def __init__(self, factor): - self.factor = factor + def __init__(self, factor): + self.factor = factor - def __call__(self, x, x_lens): - x = x.pad(((0, (-x.shape[0]) % self.factor), (0, 0), (0, 0))) - x = x.reshape(x.shape[0] // self.factor, x.shape[1], x.shape[2] * self.factor) - return x, x_lens / self.factor if x_lens is not None else None + def __call__(self, x, x_lens): + x = x.pad(((0, (-x.shape[0]) % self.factor), (0, 0), (0, 0))) + x = x.reshape(x.shape[0] // self.factor, x.shape[1], x.shape[2] * self.factor) + return x, x_lens / self.factor if x_lens is not None else None class Encoder: - def __init__(self, input_size, hidden_size, pre_layers, post_layers, stack_time_factor, dropout): - self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout) - self.stack_time = StackTime(stack_time_factor) - self.post_rnn = LSTM(stack_time_factor * hidden_size, hidden_size, post_layers, dropout) + def __init__( + self, + input_size, + hidden_size, + pre_layers, + post_layers, + stack_time_factor, + dropout, + ): + self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout) + self.stack_time = StackTime(stack_time_factor) + self.post_rnn = LSTM( + stack_time_factor * hidden_size, hidden_size, post_layers, dropout + ) - def __call__(self, x, x_lens): - x, _ = self.pre_rnn(x, None) - x, x_lens = self.stack_time(x, x_lens) - x, _ = self.post_rnn(x, None) - return x.transpose(0, 1), x_lens + def __call__(self, x, x_lens): + x, _ = self.pre_rnn(x, None) + x, x_lens = self.stack_time(x, x_lens) + x, _ = self.post_rnn(x, None) + return x.transpose(0, 1), x_lens class Prediction: - def __init__(self, vocab_size, hidden_size, layers, dropout): - self.hidden_size = hidden_size + def __init__(self, vocab_size, hidden_size, layers, dropout): + self.hidden_size = hidden_size - self.emb = Embedding(vocab_size - 1, hidden_size) - self.rnn = LSTM(hidden_size, hidden_size, layers, dropout) + self.emb = Embedding(vocab_size - 1, hidden_size) + self.rnn = LSTM(hidden_size, hidden_size, layers, dropout) - def __call__(self, x, hc, m): - emb = self.emb(x) * m - x_, hc = self.rnn(emb.transpose(0, 1), hc) - return x_.transpose(0, 1), hc + def __call__(self, x, hc, m): + emb = self.emb(x) * m + x_, hc = self.rnn(emb.transpose(0, 1), hc) + return x_.transpose(0, 1), hc class Joint: - def __init__(self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout): - self.dropout = dropout + def __init__( + self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout + ): + self.dropout = dropout - self.l1 = Linear(pred_hidden_size + enc_hidden_size, joint_hidden_size) - self.l2 = Linear(joint_hidden_size, vocab_size) + self.l1 = Linear(pred_hidden_size + enc_hidden_size, joint_hidden_size) + self.l2 = Linear(joint_hidden_size, vocab_size) - def __call__(self, f, g): - (_, T, H), (B, U, H2) = f.shape, g.shape - f = f.unsqueeze(2).expand(B, T, U, H) - g = g.unsqueeze(1).expand(B, T, U, H2) + def __call__(self, f, g): + (_, T, H), (B, U, H2) = f.shape, g.shape + f = f.unsqueeze(2).expand(B, T, U, H) + g = g.unsqueeze(1).expand(B, T, U, H2) - inp = f.cat(g, dim=3) - t = self.l1(inp).relu() - t = t.dropout(self.dropout) - return self.l2(t) + inp = f.cat(g, dim=3) + t = self.l1(inp).relu() + t = t.dropout(self.dropout) + return self.l2(t) diff --git a/extra/models/transformer.py b/extra/models/transformer.py index 283e14b47..de08c0f62 100644 --- a/extra/models/transformer.py +++ b/extra/models/transformer.py @@ -1,64 +1,104 @@ import numpy as np from tinygrad.tensor import Tensor + class TransformerBlock: - def __init__(self, embed_dim, num_heads, ff_dim, prenorm=False, act=lambda x: x.relu(), dropout=0.1): - assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + def __init__( + self, + embed_dim, + num_heads, + ff_dim, + prenorm=False, + act=lambda x: x.relu(), + dropout=0.1, + ): + assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" - self.num_heads = num_heads - self.head_size = embed_dim // num_heads - self.prenorm, self.act = prenorm, act - self.dropout = dropout + self.num_heads = num_heads + self.head_size = embed_dim // num_heads + self.prenorm, self.act = prenorm, act + self.dropout = dropout - self.query = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim)) - self.key = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim)) - self.value = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim)) + self.query = ( + Tensor.scaled_uniform(embed_dim, embed_dim), + Tensor.zeros(embed_dim), + ) + self.key = ( + Tensor.scaled_uniform(embed_dim, embed_dim), + Tensor.zeros(embed_dim), + ) + self.value = ( + Tensor.scaled_uniform(embed_dim, embed_dim), + Tensor.zeros(embed_dim), + ) - self.out = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim)) + self.out = ( + Tensor.scaled_uniform(embed_dim, embed_dim), + Tensor.zeros(embed_dim), + ) - self.ff1 = (Tensor.scaled_uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim)) - self.ff2 = (Tensor.scaled_uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim)) + self.ff1 = (Tensor.scaled_uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim)) + self.ff2 = (Tensor.scaled_uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim)) - self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim)) - self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim)) + self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim)) + self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim)) - def attn(self, x): - # x: (bs, time, embed_dim) -> (bs, time, embed_dim) - query, key, value = [x.linear(*y).reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)).transpose(1,2) for y in [self.query, self.key, self.value]] - attention = Tensor.scaled_dot_product_attention(query, key, value).transpose(1,2) - return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out) + def attn(self, x): + # x: (bs, time, embed_dim) -> (bs, time, embed_dim) + query, key, value = [ + x.linear(*y) + .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) + .transpose(1, 2) + for y in [self.query, self.key, self.value] + ] + attention = Tensor.scaled_dot_product_attention(query, key, value).transpose( + 1, 2 + ) + return attention.reshape( + shape=(x.shape[0], -1, self.num_heads * self.head_size) + ).linear(*self.out) + + def __call__(self, x): + if self.prenorm: + x = x + self.attn(x.layernorm().linear(*self.ln1)).dropout(self.dropout) + x = x + self.act(x.layernorm().linear(*self.ln2).linear(*self.ff1)).linear( + *self.ff2 + ).dropout(self.dropout) + else: + x = x + self.attn(x).dropout(self.dropout) + x = x.layernorm().linear(*self.ln1) + x = x + self.act(x.linear(*self.ff1)).linear(*self.ff2).dropout( + self.dropout + ) + x = x.layernorm().linear(*self.ln2) + return x - def __call__(self, x): - if self.prenorm: - x = x + self.attn(x.layernorm().linear(*self.ln1)).dropout(self.dropout) - x = x + self.act(x.layernorm().linear(*self.ln2).linear(*self.ff1)).linear(*self.ff2).dropout(self.dropout) - else: - x = x + self.attn(x).dropout(self.dropout) - x = x.layernorm().linear(*self.ln1) - x = x + self.act(x.linear(*self.ff1)).linear(*self.ff2).dropout(self.dropout) - x = x.layernorm().linear(*self.ln2) - return x class Transformer: - def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim): - self.maxlen, self.syms = maxlen, syms - self.embed = Tensor.scaled_uniform(maxlen+syms, embed_dim, requires_grad=False) - self.tbs = [] - for i in range(layers): - self.tbs.append(TransformerBlock(embed_dim, num_heads, ff_dim)) - self.final = Tensor.scaled_uniform(embed_dim, syms) + def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim): + self.maxlen, self.syms = maxlen, syms + self.embed = Tensor.scaled_uniform( + maxlen + syms, embed_dim, requires_grad=False + ) + self.tbs = [] + for i in range(layers): + self.tbs.append(TransformerBlock(embed_dim, num_heads, ff_dim)) + self.final = Tensor.scaled_uniform(embed_dim, syms) - def forward(self, x): - bs = x.shape[0] - xnp = x.numpy().astype(np.int32) - onehot = np.zeros((bs, x.shape[1], self.maxlen+self.syms), dtype=np.float32) - for i in range(x.shape[1]): - onehot[range(bs), i, i] = 1 - onehot[range(bs), i, self.maxlen + xnp[:, i]] = 1 - onehot = onehot.reshape(bs*x.shape[1], self.maxlen+self.syms) - - x = Tensor(onehot, device=x.device).dot(self.embed).reshape(shape=(bs, x.shape[1], -1)) - x = x.sequential(self.tbs) - x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).log_softmax() - return x.reshape(shape=(bs, -1, x.shape[-1])) + def forward(self, x): + bs = x.shape[0] + xnp = x.numpy().astype(np.int32) + onehot = np.zeros((bs, x.shape[1], self.maxlen + self.syms), dtype=np.float32) + for i in range(x.shape[1]): + onehot[range(bs), i, i] = 1 + onehot[range(bs), i, self.maxlen + xnp[:, i]] = 1 + onehot = onehot.reshape(bs * x.shape[1], self.maxlen + self.syms) + x = ( + Tensor(onehot, device=x.device) + .dot(self.embed) + .reshape(shape=(bs, x.shape[1], -1)) + ) + x = x.sequential(self.tbs) + x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).log_softmax() + return x.reshape(shape=(bs, -1, x.shape[-1])) diff --git a/extra/models/unet3d.py b/extra/models/unet3d.py index 1b2558c87..aeb016f4c 100644 --- a/extra/models/unet3d.py +++ b/extra/models/unet3d.py @@ -4,56 +4,100 @@ from tinygrad import nn from tinygrad.tensor import Tensor from tinygrad.helpers import fetch, get_child -class DownsampleBlock: - def __init__(self, c0, c1, stride=2): - self.conv1 = [nn.Conv2d(c0, c1, kernel_size=(3,3,3), stride=stride, padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu] - self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu] - def __call__(self, x): - return x.sequential(self.conv1).sequential(self.conv2) +class DownsampleBlock: + def __init__(self, c0, c1, stride=2): + self.conv1 = [ + nn.Conv2d( + c0, + c1, + kernel_size=(3, 3, 3), + stride=stride, + padding=(1, 1, 1, 1, 1, 1), + bias=False, + ), + nn.InstanceNorm(c1), + Tensor.relu, + ] + self.conv2 = [ + nn.Conv2d( + c1, c1, kernel_size=(3, 3, 3), padding=(1, 1, 1, 1, 1, 1), bias=False + ), + nn.InstanceNorm(c1), + Tensor.relu, + ] + + def __call__(self, x): + return x.sequential(self.conv1).sequential(self.conv2) + class UpsampleBlock: - def __init__(self, c0, c1): - self.upsample_conv = [nn.ConvTranspose2d(c0, c1, kernel_size=(2,2,2), stride=2)] - self.conv1 = [nn.Conv2d(2 * c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu] - self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu] + def __init__(self, c0, c1): + self.upsample_conv = [ + nn.ConvTranspose2d(c0, c1, kernel_size=(2, 2, 2), stride=2) + ] + self.conv1 = [ + nn.Conv2d( + 2 * c1, + c1, + kernel_size=(3, 3, 3), + padding=(1, 1, 1, 1, 1, 1), + bias=False, + ), + nn.InstanceNorm(c1), + Tensor.relu, + ] + self.conv2 = [ + nn.Conv2d( + c1, c1, kernel_size=(3, 3, 3), padding=(1, 1, 1, 1, 1, 1), bias=False + ), + nn.InstanceNorm(c1), + Tensor.relu, + ] + + def __call__(self, x, skip): + x = x.sequential(self.upsample_conv) + x = Tensor.cat(x, skip, dim=1) + return x.sequential(self.conv1).sequential(self.conv2) - def __call__(self, x, skip): - x = x.sequential(self.upsample_conv) - x = Tensor.cat(x, skip, dim=1) - return x.sequential(self.conv1).sequential(self.conv2) class UNet3D: - def __init__(self, in_channels=1, n_class=3): - filters = [32, 64, 128, 256, 320] - inp, out = filters[:-1], filters[1:] - self.input_block = DownsampleBlock(in_channels, filters[0], stride=1) - self.downsample = [DownsampleBlock(i, o) for i, o in zip(inp, out)] - self.bottleneck = DownsampleBlock(filters[-1], filters[-1]) - self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1])] - self.output = {"conv": nn.Conv2d(filters[0], n_class, kernel_size=(1, 1, 1))} + def __init__(self, in_channels=1, n_class=3): + filters = [32, 64, 128, 256, 320] + inp, out = filters[:-1], filters[1:] + self.input_block = DownsampleBlock(in_channels, filters[0], stride=1) + self.downsample = [DownsampleBlock(i, o) for i, o in zip(inp, out)] + self.bottleneck = DownsampleBlock(filters[-1], filters[-1]) + self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [ + UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1]) + ] + self.output = {"conv": nn.Conv2d(filters[0], n_class, kernel_size=(1, 1, 1))} - def __call__(self, x): - x = self.input_block(x) - outputs = [x] - for downsample in self.downsample: - x = downsample(x) - outputs.append(x) - x = self.bottleneck(x) - for upsample, skip in zip(self.upsample, outputs[::-1]): - x = upsample(x, skip) - x = self.output["conv"](x) - return x + def __call__(self, x): + x = self.input_block(x) + outputs = [x] + for downsample in self.downsample: + x = downsample(x) + outputs.append(x) + x = self.bottleneck(x) + for upsample, skip in zip(self.upsample, outputs[::-1]): + x = upsample(x, skip) + x = self.output["conv"](x) + return x + + def load_from_pretrained(self): + fn = Path(__file__).parents[1] / "weights" / "unet-3d.ckpt" + fetch( + "https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", + fn, + ) + state_dict = torch.jit.load(fn, map_location=torch.device("cpu")).state_dict() + for k, v in state_dict.items(): + obj = get_child(self, k) + assert obj.shape == v.shape, (k, obj.shape, v.shape) + obj.assign(v.numpy()) - def load_from_pretrained(self): - fn = Path(__file__).parents[1] / "weights" / "unet-3d.ckpt" - fetch("https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", fn) - state_dict = torch.jit.load(fn, map_location=torch.device("cpu")).state_dict() - for k, v in state_dict.items(): - obj = get_child(self, k) - assert obj.shape == v.shape, (k, obj.shape, v.shape) - obj.assign(v.numpy()) if __name__ == "__main__": - mdl = UNet3D() - mdl.load_from_pretrained() + mdl = UNet3D() + mdl.load_from_pretrained() diff --git a/extra/models/vit.py b/extra/models/vit.py index a46570847..70cf5aa58 100644 --- a/extra/models/vit.py +++ b/extra/models/vit.py @@ -3,71 +3,126 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import fetch from extra.models.transformer import TransformerBlock + class ViT: - def __init__(self, layers=12, embed_dim=192, num_heads=3): - self.embedding = (Tensor.uniform(embed_dim, 3, 16, 16), Tensor.zeros(embed_dim)) - self.embed_dim = embed_dim - self.cls = Tensor.ones(1, 1, embed_dim) - self.pos_embedding = Tensor.ones(1, 197, embed_dim) - self.tbs = [ - TransformerBlock(embed_dim=embed_dim, num_heads=num_heads, ff_dim=embed_dim*4, - prenorm=True, act=lambda x: x.gelu()) - for i in range(layers)] - self.encoder_norm = (Tensor.uniform(embed_dim), Tensor.zeros(embed_dim)) - self.head = (Tensor.uniform(embed_dim, 1000), Tensor.zeros(1000)) + def __init__(self, layers=12, embed_dim=192, num_heads=3): + self.embedding = (Tensor.uniform(embed_dim, 3, 16, 16), Tensor.zeros(embed_dim)) + self.embed_dim = embed_dim + self.cls = Tensor.ones(1, 1, embed_dim) + self.pos_embedding = Tensor.ones(1, 197, embed_dim) + self.tbs = [ + TransformerBlock( + embed_dim=embed_dim, + num_heads=num_heads, + ff_dim=embed_dim * 4, + prenorm=True, + act=lambda x: x.gelu(), + ) + for i in range(layers) + ] + self.encoder_norm = (Tensor.uniform(embed_dim), Tensor.zeros(embed_dim)) + self.head = (Tensor.uniform(embed_dim, 1000), Tensor.zeros(1000)) - def patch_embed(self, x): - x = x.conv2d(*self.embedding, stride=16) - x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).permute(order=(0,2,1)) - return x + def patch_embed(self, x): + x = x.conv2d(*self.embedding, stride=16) + x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).permute(order=(0, 2, 1)) + return x - def forward(self, x): - ce = self.cls.add(Tensor.zeros(x.shape[0],1,1)) - pe = self.patch_embed(x) - x = ce.cat(pe, dim=1) - x = x.add(self.pos_embedding).sequential(self.tbs) - x = x.layernorm().linear(*self.encoder_norm) - return x[:, 0].linear(*self.head) + def forward(self, x): + ce = self.cls.add(Tensor.zeros(x.shape[0], 1, 1)) + pe = self.patch_embed(x) + x = ce.cat(pe, dim=1) + x = x.add(self.pos_embedding).sequential(self.tbs) + x = x.layernorm().linear(*self.encoder_norm) + return x[:, 0].linear(*self.head) - def load_from_pretrained(m): - # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py - if m.embed_dim == 192: - url = "https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz" - elif m.embed_dim == 768: - url = "https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz" - else: - raise Exception("no pretrained weights for configuration") - dat = np.load(fetch(url)) + def load_from_pretrained(m): + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + if m.embed_dim == 192: + url = "https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz" + elif m.embed_dim == 768: + url = "https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz" + else: + raise Exception("no pretrained weights for configuration") + dat = np.load(fetch(url)) - #for x in dat.keys(): - # print(x, dat[x].shape, dat[x].dtype) + # for x in dat.keys(): + # print(x, dat[x].shape, dat[x].dtype) - m.embedding[0].assign(np.transpose(dat['embedding/kernel'], (3,2,0,1))) - m.embedding[1].assign(dat['embedding/bias']) + m.embedding[0].assign(np.transpose(dat["embedding/kernel"], (3, 2, 0, 1))) + m.embedding[1].assign(dat["embedding/bias"]) - m.cls.assign(dat['cls']) + m.cls.assign(dat["cls"]) - m.head[0].assign(dat['head/kernel']) - m.head[1].assign(dat['head/bias']) + m.head[0].assign(dat["head/kernel"]) + m.head[1].assign(dat["head/bias"]) - m.pos_embedding.assign(dat['Transformer/posembed_input/pos_embedding']) - m.encoder_norm[0].assign(dat['Transformer/encoder_norm/scale']) - m.encoder_norm[1].assign(dat['Transformer/encoder_norm/bias']) + m.pos_embedding.assign(dat["Transformer/posembed_input/pos_embedding"]) + m.encoder_norm[0].assign(dat["Transformer/encoder_norm/scale"]) + m.encoder_norm[1].assign(dat["Transformer/encoder_norm/bias"]) - for i in range(12): - m.tbs[i].query[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/kernel'].reshape(m.embed_dim, m.embed_dim)) - m.tbs[i].query[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/bias'].reshape(m.embed_dim)) - m.tbs[i].key[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/kernel'].reshape(m.embed_dim, m.embed_dim)) - m.tbs[i].key[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/bias'].reshape(m.embed_dim)) - m.tbs[i].value[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/kernel'].reshape(m.embed_dim, m.embed_dim)) - m.tbs[i].value[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/bias'].reshape(m.embed_dim)) - m.tbs[i].out[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/kernel'].reshape(m.embed_dim, m.embed_dim)) - m.tbs[i].out[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/bias'].reshape(m.embed_dim)) - m.tbs[i].ff1[0].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_0/kernel']) - m.tbs[i].ff1[1].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_0/bias']) - m.tbs[i].ff2[0].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_1/kernel']) - m.tbs[i].ff2[1].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_1/bias']) - m.tbs[i].ln1[0].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_0/scale']) - m.tbs[i].ln1[1].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_0/bias']) - m.tbs[i].ln2[0].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_2/scale']) - m.tbs[i].ln2[1].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_2/bias']) + for i in range(12): + m.tbs[i].query[0].assign( + dat[ + f"Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/kernel" + ].reshape(m.embed_dim, m.embed_dim) + ) + m.tbs[i].query[1].assign( + dat[ + f"Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/bias" + ].reshape(m.embed_dim) + ) + m.tbs[i].key[0].assign( + dat[ + f"Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/kernel" + ].reshape(m.embed_dim, m.embed_dim) + ) + m.tbs[i].key[1].assign( + dat[ + f"Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/bias" + ].reshape(m.embed_dim) + ) + m.tbs[i].value[0].assign( + dat[ + f"Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/kernel" + ].reshape(m.embed_dim, m.embed_dim) + ) + m.tbs[i].value[1].assign( + dat[ + f"Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/bias" + ].reshape(m.embed_dim) + ) + m.tbs[i].out[0].assign( + dat[ + f"Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/kernel" + ].reshape(m.embed_dim, m.embed_dim) + ) + m.tbs[i].out[1].assign( + dat[ + f"Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/bias" + ].reshape(m.embed_dim) + ) + m.tbs[i].ff1[0].assign( + dat[f"Transformer/encoderblock_{i}/MlpBlock_3/Dense_0/kernel"] + ) + m.tbs[i].ff1[1].assign( + dat[f"Transformer/encoderblock_{i}/MlpBlock_3/Dense_0/bias"] + ) + m.tbs[i].ff2[0].assign( + dat[f"Transformer/encoderblock_{i}/MlpBlock_3/Dense_1/kernel"] + ) + m.tbs[i].ff2[1].assign( + dat[f"Transformer/encoderblock_{i}/MlpBlock_3/Dense_1/bias"] + ) + m.tbs[i].ln1[0].assign( + dat[f"Transformer/encoderblock_{i}/LayerNorm_0/scale"] + ) + m.tbs[i].ln1[1].assign( + dat[f"Transformer/encoderblock_{i}/LayerNorm_0/bias"] + ) + m.tbs[i].ln2[0].assign( + dat[f"Transformer/encoderblock_{i}/LayerNorm_2/scale"] + ) + m.tbs[i].ln2[1].assign( + dat[f"Transformer/encoderblock_{i}/LayerNorm_2/bias"] + ) diff --git a/extra/onnx.py b/extra/onnx.py index 72fbea773..82a4ec9fd 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -5,216 +5,340 @@ import numpy as np from tinygrad.tensor import Tensor from tinygrad.helpers import getenv, DEBUG, dtypes from typing import List, Dict -from onnx import AttributeProto, ModelProto, TensorProto, TypeProto # onnx 1.50 uses serialized file (see onnx/onnx-ml.proto) as descriptors +from onnx import ( + AttributeProto, + ModelProto, + TensorProto, + TypeProto, +) # onnx 1.50 uses serialized file (see onnx/onnx-ml.proto) as descriptors + try: - from onnx.helper import tensor_dtype_to_np_dtype + from onnx.helper import tensor_dtype_to_np_dtype except ImportError: - # for onnx < 1.13 - from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE - tensor_dtype_to_np_dtype = lambda x: TENSOR_TYPE_TO_NP_TYPE[x] + # for onnx < 1.13 + from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE + + tensor_dtype_to_np_dtype = lambda x: TENSOR_TYPE_TO_NP_TYPE[x] # global numpy cache for parameters numpy_cache = {} -def safe_numpy(t) -> np.ndarray: - if not isinstance(t, Tensor): return t - global numpy_cache - if t not in numpy_cache: - if DEBUG >= 3: print("numpy cache miss", t) - tmp = t.numpy() - numpy_cache[t] = tmp if len(tmp.shape) else tmp.reshape(1) - assert len(numpy_cache[t].shape) > 0 - return numpy_cache[t] -onnx_ops = importlib.import_module('extra.onnx_ops') + +def safe_numpy(t) -> np.ndarray: + if not isinstance(t, Tensor): + return t + global numpy_cache + if t not in numpy_cache: + if DEBUG >= 3: + print("numpy cache miss", t) + tmp = t.numpy() + numpy_cache[t] = tmp if len(tmp.shape) else tmp.reshape(1) + assert len(numpy_cache[t].shape) > 0 + return numpy_cache[t] + + +onnx_ops = importlib.import_module("extra.onnx_ops") ONNXLIMIT = getenv("ONNXLIMIT", -1) + def get_run_onnx(onnx_model: ModelProto): - def type_parse(type_proto: TypeProto): - ret = [] - while True: - attr = type_proto.WhichOneof('value') - if attr == 'tensor_type': - if "dim_value" not in getattr(type_proto, attr).shape.dim.__dir__(): return () # variable type, unable to determine shape - elif not ret: - return tuple([x.dim_value for x in getattr(type_proto, attr).shape.dim]) + def type_parse(type_proto: TypeProto): + ret = [] + while True: + attr = type_proto.WhichOneof("value") + if attr == "tensor_type": + if "dim_value" not in getattr(type_proto, attr).shape.dim.__dir__(): + return () # variable type, unable to determine shape + elif not ret: + return tuple( + [x.dim_value for x in getattr(type_proto, attr).shape.dim] + ) + else: + ret.extend( + [(x.dim_value,) for x in getattr(type_proto, attr).shape.dim] + ) + return tuple(ret) + elif attr == "sequence_type": + type_proto = getattr(type_proto, attr).elem_type + ret.append(1) + elif attr == "map_type": + raise NotImplementedError(f"map_type is not implemented: {type_proto}") + elif attr == "opaque_type": + raise NotImplementedError( + f"opaque_type is not implemented: {type_proto}" + ) + elif attr == "sparse_tensor_type": + raise NotImplementedError( + f"sparse_tensor_type is not implemented: {type_proto}" + ) + elif attr == "optional_type": + type_proto = getattr(type_proto, attr).elem_type + else: + raise Exception(f"unknown attr: {attr}, {type_proto}") + + def buffer_parse(inp: TensorProto) -> Tensor: + if inp.data_type in (1, 10, 6, 7, 5): + # TODO: this is shared with below + if len(inp.float_data) > 0: + ret = Tensor( + np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), + requires_grad=False, + ) + elif len(inp.int64_data) > 0: + ret = Tensor( + np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), + requires_grad=False, + ) + elif len(inp.int32_data) > 0: + ret = Tensor( + np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), + requires_grad=False, + ) + else: + ret = Tensor( + np.frombuffer( + inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type) + ) + .reshape(inp.dims) + .astype(np.float32) + .copy(), + requires_grad=False, + ) else: - ret.extend([(x.dim_value,) for x in getattr(type_proto, attr).shape.dim]) - return tuple(ret) - elif attr == 'sequence_type': - type_proto = getattr(type_proto, attr).elem_type - ret.append(1) - elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}") - elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}") - elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}") - elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type - else: raise Exception(f"unknown attr: {attr}, {type_proto}") + raise Exception(f"bad data type {inp.name} {inp.dims} {inp.data_type}") + return ret - def buffer_parse(inp: TensorProto) -> Tensor: - if inp.data_type in (1,10,6,7,5): - # TODO: this is shared with below - if len(inp.float_data) > 0: - ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False) - elif len(inp.int64_data) > 0: - ret = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False) - elif len(inp.int32_data) > 0: - ret = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False) - else: - ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False) - else: - raise Exception(f"bad data type {inp.name} {inp.dims} {inp.data_type}") - return ret - - def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]: - # TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list - if a.type == AttributeProto.FLOAT: return float(a.f) - elif a.type == AttributeProto.INT: return int(a.i) - elif a.type == AttributeProto.STRING: return a.s.decode("utf-8") - elif a.type == AttributeProto.TENSOR: return buffer_parse(a.t) # TENSOR - elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats) - elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints) - elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings) - elif a.type == AttributeProto.GRAPH: raise Exception(f"graph not implemented: {a.g}\n likely an OP requiring control flow") - else: raise Exception(f"can't parse {a.type} {a}") - def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]): return {x.name:attribute_parse(x) for x in a} - - tensors: Dict[str, Tensor] = {} - - # get weights and biases - for inp in onnx_model.graph.initializer: - if len(inp.raw_data) > 0: - tensors[inp.name] = buffer_parse(inp) - elif len(inp.float_data) > 0: - tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False) - elif len(inp.int64_data) > 0: - tensors[inp.name] = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False) - elif len(inp.raw_data) == 0: - tensors[inp.name] = Tensor(np.array([], dtype=np.float32), requires_grad=False) - else: - print(inp.name, inp.dims, inp.data_type, len(inp.raw_data)) - print(inp) - raise Exception("no data") - - # preparse the attributes - attribute_dict = {} - domain = "" - for num,n in enumerate(onnx_model.graph.node): - attribute_dict[num] = attribute_to_dict(n.attribute) - if n.domain: domain = n.domain - - onnx_model_version = onnx_model.opset_import[0].version - - def run_onnx(inputs={}, debug=0): - debug = getenv("DEBUGONNX") or debug - input_tensors: Dict[str,Tensor] = {} - intermediate_tensors: Dict[str,Tensor] = {} - output_tensor_names = [x.name for x in onnx_model.graph.output] - - # get inputs - for inp in onnx_model.graph.input: - if inp.name in tensors: continue - shape = type_parse(inp.type) - if inp.name in inputs: - if isinstance(inputs[inp.name], Tensor): - input_tensors[inp.name] = inputs[inp.name] - elif isinstance(inputs[inp.name], list): - input_tensors[inp.name] = [Tensor(i, requires_grad=False) for i in inputs[inp.name]] - elif domain == "ai.onnx.preview.training": # not sure if in real use the domain is "ai.onnx.preview.training" - input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=True) # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops + def attribute_parse( + a: AttributeProto, + ) -> float | int | str | Tensor | tuple[float] | tuple[int]: + # TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list + if a.type == AttributeProto.FLOAT: + return float(a.f) + elif a.type == AttributeProto.INT: + return int(a.i) + elif a.type == AttributeProto.STRING: + return a.s.decode("utf-8") + elif a.type == AttributeProto.TENSOR: + return buffer_parse(a.t) # TENSOR + elif a.type == AttributeProto.FLOATS: + return tuple(float(x) for x in a.floats) + elif a.type == AttributeProto.INTS: + return tuple(int(x) for x in a.ints) + elif a.type == AttributeProto.STRINGS: + return tuple(x.decode("utf-8") for x in a.strings) + elif a.type == AttributeProto.GRAPH: + raise Exception( + f"graph not implemented: {a.g}\n likely an OP requiring control flow" + ) else: - input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=False) - if shape: # if only input_tensor is not variable type - input_shape = input_tensors[inp.name].shape if isinstance(input_tensors[inp.name], Tensor) else (1, *[i.shape for i in input_tensors[inp.name]]) - assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}" - else: - raise Exception(f"no data for {inp.name} with shape {shape}") + raise Exception(f"can't parse {a.type} {a}") - def fetch_tensor(x: str): - if x in tensors: return tensors[x] - if x in intermediate_tensors: return intermediate_tensors[x] - if x != str(): return input_tensors[x] - return None + def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]): + return {x.name: attribute_parse(x) for x in a} - for num,n in enumerate(onnx_model.graph.node): - inp: List[Tensor] = [] - if debug >= 3: print("inputs:") - for x in n.input: - t = fetch_tensor(x) - if debug >= 3: print(f"\t{x} - {t}") - inp.append(t) - opt: Dict = attribute_dict[num] - if debug >= 1: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}") + tensors: Dict[str, Tensor] = {} - # NOTE some ops live here because they require access to some local variables - # have to use n.output for cases when num_outputs is absent - if n.op_type in onnx_ops.tensor_methods: - ret = getattr(Tensor, n.op_type.lower())(*inp, **opt) - elif n.op_type == "Split": - axis = opt.get("axis", 0) - split = None if len(inp) == 1 else [int(x) for x in safe_numpy(inp[1])] - if split is None: - split = [inp[0].shape[axis] // len(n.output)] * len(n.output) - for i in range(inp[0].shape[axis] % len(n.output)): - split[i] += 1 - i, ret = 0, [] - arg = [(0,x) for x in inp[0].shape] - for s in split: - arg[axis] = (i,i+s) - ret.append(inp[0].shrink(arg=tuple(arg))) - i = i+s - ret = tuple(ret) - - # need to check onnx_model_version - elif n.op_type == "Slice": - if onnx_model_version < 10: - axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim + # get weights and biases + for inp in onnx_model.graph.initializer: + if len(inp.raw_data) > 0: + tensors[inp.name] = buffer_parse(inp) + elif len(inp.float_data) > 0: + tensors[inp.name] = Tensor( + np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), + requires_grad=False, + ) + elif len(inp.int64_data) > 0: + tensors[inp.name] = Tensor( + np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), + requires_grad=False, + ) + elif len(inp.raw_data) == 0: + tensors[inp.name] = Tensor( + np.array([], dtype=np.float32), requires_grad=False + ) else: - starts, ends = inp[1:3] - axes = safe_numpy(Tensor.arange(inp[0].ndim, dtype=dtypes.int32) if len(inp) <= 3 else inp[3]).tolist() - steps = safe_numpy(inp[4]) if len(inp) > 4 else [1]*inp[0].ndim - starts, ends = safe_numpy(starts.ceil().cast(dtypes.int32)).tolist(), safe_numpy(ends.ceil().cast(dtypes.int32)).tolist() - arg = [(0,x,1) for x in inp[0].shape] - for i, axis in enumerate(axes): - axis = int(axis) + inp[0].ndim if axis < 0 else int(axis) - starts[i], ends[i] = starts[i] + inp[0].shape[axis] if starts[i] < 0 else starts[i], ends[i] + inp[0].shape[axis] if ends[i] < 0 else ends[i] - starts[i], ends[i] = max(0, min(starts[i], inp[0].shape[axis])), max(0, min(ends[i], inp[0].shape[axis])) - if starts[i] > ends[i] and steps[i] >= 0: steps[i] = -steps[i] - arg[axis] = (starts[i], ends[i], steps[i]) - new_shape = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in arg) - if any(s==e for s,e in new_shape): ret = inp[0].shrink(new_shape) - else: ret = inp[0].__getitem__(tuple([slice(s,e,st) for s,e,st in arg])) + print(inp.name, inp.dims, inp.data_type, len(inp.raw_data)) + print(inp) + raise Exception("no data") - # need to call backward on intermediate_tensors - elif n.op_type == "Gradient": - assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)} output and input has to match" - y = opt["y"] - intermediate_tensors[y].backward() - ret = tuple([t.grad for t in inp]) + # preparse the attributes + attribute_dict = {} + domain = "" + for num, n in enumerate(onnx_model.graph.node): + attribute_dict[num] = attribute_to_dict(n.attribute) + if n.domain: + domain = n.domain - # onnx_ops.py - elif hasattr(onnx_ops, n.op_type): - fxn = getattr(onnx_ops, n.op_type) - if isinstance(fxn, dict): - for k in sorted(fxn.keys()): - if k <= onnx_model_version: - real_fxn = fxn[k] - else: - real_fxn = fxn - ret = real_fxn(*inp, **opt) - else: - print("UNSUPPORTED", n.op_type, n.input, n.output) - raise Exception(f"op_type {n.op_type} not supported") + onnx_model_version = onnx_model.opset_import[0].version - if not isinstance(ret, tuple): ret = (ret, ) - assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}" - if debug >= 2: print([x.shape if isinstance(x, Tensor) else None for x in ret]) - if debug >= 2: print("outputs:") - for i in range(len(n.output)): - if debug >= 2: print(f"\t{n.output[i]} - {ret[i]}") - intermediate_tensors[n.output[i]] = ret[i] - if num == ONNXLIMIT: - output_tensor_names = n.output - break + def run_onnx(inputs={}, debug=0): + debug = getenv("DEBUGONNX") or debug + input_tensors: Dict[str, Tensor] = {} + intermediate_tensors: Dict[str, Tensor] = {} + output_tensor_names = [x.name for x in onnx_model.graph.output] - return {outp:intermediate_tensors[outp] for outp in output_tensor_names} - return run_onnx + # get inputs + for inp in onnx_model.graph.input: + if inp.name in tensors: + continue + shape = type_parse(inp.type) + if inp.name in inputs: + if isinstance(inputs[inp.name], Tensor): + input_tensors[inp.name] = inputs[inp.name] + elif isinstance(inputs[inp.name], list): + input_tensors[inp.name] = [ + Tensor(i, requires_grad=False) for i in inputs[inp.name] + ] + elif ( + domain == "ai.onnx.preview.training" + ): # not sure if in real use the domain is "ai.onnx.preview.training" + input_tensors[inp.name] = Tensor( + inputs[inp.name], requires_grad=True + ) # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops + else: + input_tensors[inp.name] = Tensor( + inputs[inp.name], requires_grad=False + ) + if shape: # if only input_tensor is not variable type + input_shape = ( + input_tensors[inp.name].shape + if isinstance(input_tensors[inp.name], Tensor) + else (1, *[i.shape for i in input_tensors[inp.name]]) + ) + assert ( + input_shape == shape + ), f"wrong shape for input {inp.name}, {input_shape} isn't {shape}" + else: + raise Exception(f"no data for {inp.name} with shape {shape}") + + def fetch_tensor(x: str): + if x in tensors: + return tensors[x] + if x in intermediate_tensors: + return intermediate_tensors[x] + if x != str(): + return input_tensors[x] + return None + + for num, n in enumerate(onnx_model.graph.node): + inp: List[Tensor] = [] + if debug >= 3: + print("inputs:") + for x in n.input: + t = fetch_tensor(x) + if debug >= 3: + print(f"\t{x} - {t}") + inp.append(t) + opt: Dict = attribute_dict[num] + if debug >= 1: + print( + f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}" + ) + + # NOTE some ops live here because they require access to some local variables + # have to use n.output for cases when num_outputs is absent + if n.op_type in onnx_ops.tensor_methods: + ret = getattr(Tensor, n.op_type.lower())(*inp, **opt) + elif n.op_type == "Split": + axis = opt.get("axis", 0) + split = None if len(inp) == 1 else [int(x) for x in safe_numpy(inp[1])] + if split is None: + split = [inp[0].shape[axis] // len(n.output)] * len(n.output) + for i in range(inp[0].shape[axis] % len(n.output)): + split[i] += 1 + i, ret = 0, [] + arg = [(0, x) for x in inp[0].shape] + for s in split: + arg[axis] = (i, i + s) + ret.append(inp[0].shrink(arg=tuple(arg))) + i = i + s + ret = tuple(ret) + + # need to check onnx_model_version + elif n.op_type == "Slice": + if onnx_model_version < 10: + axes, ends, starts, steps = ( + list(opt.get("axes", range(inp[0].ndim))), + list(opt["ends"]), + list(opt["starts"]), + [1] * inp[0].ndim, + ) + else: + starts, ends = inp[1:3] + axes = safe_numpy( + Tensor.arange(inp[0].ndim, dtype=dtypes.int32) + if len(inp) <= 3 + else inp[3] + ).tolist() + steps = safe_numpy(inp[4]) if len(inp) > 4 else [1] * inp[0].ndim + starts, ends = ( + safe_numpy(starts.ceil().cast(dtypes.int32)).tolist(), + safe_numpy(ends.ceil().cast(dtypes.int32)).tolist(), + ) + arg = [(0, x, 1) for x in inp[0].shape] + for i, axis in enumerate(axes): + axis = int(axis) + inp[0].ndim if axis < 0 else int(axis) + starts[i], ends[i] = ( + starts[i] + inp[0].shape[axis] if starts[i] < 0 else starts[i], + ends[i] + inp[0].shape[axis] if ends[i] < 0 else ends[i], + ) + starts[i], ends[i] = max( + 0, min(starts[i], inp[0].shape[axis]) + ), max(0, min(ends[i], inp[0].shape[axis])) + if starts[i] > ends[i] and steps[i] >= 0: + steps[i] = -steps[i] + arg[axis] = (starts[i], ends[i], steps[i]) + new_shape = tuple( + (s, e) if st > 0 else (e + 1, s + 1) for s, e, st in arg + ) + if any(s == e for s, e in new_shape): + ret = inp[0].shrink(new_shape) + else: + ret = inp[0].__getitem__( + tuple([slice(s, e, st) for s, e, st in arg]) + ) + + # need to call backward on intermediate_tensors + elif n.op_type == "Gradient": + assert len(opt["xs"]) == len( + inp + ), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)} output and input has to match" + y = opt["y"] + intermediate_tensors[y].backward() + ret = tuple([t.grad for t in inp]) + + # onnx_ops.py + elif hasattr(onnx_ops, n.op_type): + fxn = getattr(onnx_ops, n.op_type) + if isinstance(fxn, dict): + for k in sorted(fxn.keys()): + if k <= onnx_model_version: + real_fxn = fxn[k] + else: + real_fxn = fxn + ret = real_fxn(*inp, **opt) + else: + print("UNSUPPORTED", n.op_type, n.input, n.output) + raise Exception(f"op_type {n.op_type} not supported") + + if not isinstance(ret, tuple): + ret = (ret,) + assert len(n.output) <= len( + ret + ), f"expected output size must be less than {len(ret)}, it's {n.output}" + if debug >= 2: + print([x.shape if isinstance(x, Tensor) else None for x in ret]) + if debug >= 2: + print("outputs:") + for i in range(len(n.output)): + if debug >= 2: + print(f"\t{n.output[i]} - {ret[i]}") + intermediate_tensors[n.output[i]] = ret[i] + if num == ONNXLIMIT: + output_tensor_names = n.output + break + + return {outp: intermediate_tensors[outp] for outp in output_tensor_names} + + return run_onnx diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 22a2d61a7..e05b111d5 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -10,745 +10,1825 @@ import functools from typing import Union, Tuple, Optional, List, Any import math -tensor_methods = {"Neg", "Reciprocal", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "Tanh", "MatMul", - "Floor", "Ceil", "Tanh", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Softsign", "Asinh", "Acosh", "Atanh"} +tensor_methods = { + "Neg", + "Reciprocal", + "Sqrt", + "Sign", + "Abs", + "Exp", + "Log", + "Mish", + "Sin", + "Cos", + "Tan", + "Relu", + "Sigmoid", + "Tanh", + "MatMul", + "Floor", + "Ceil", + "Tanh", + "Softplus", + "HardSwish", + "Where", + "Mul", + "Sinh", + "Cosh", + "Softsign", + "Asinh", + "Acosh", + "Atanh", +} # **************** Free Ops **************** -def Identity(input: Tensor): return input -def Add(input: Tensor, other: Tensor, broadcast=None): return input + other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else (input + other).cast(input.dtype) -def Sub(input: Union[Tensor, Any], other: Tensor): return input - other # some test has input as int -def Div(input: Tensor, other: Tensor): return input / other if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) else input.div(other).floor() # TODO: this has dtype issues -def Pow(input: Tensor, other: Tensor): return (input.float() ** other.float()).cast(input.dtype) # TODO: this has dtype issues -def Less(x:Tensor,y:Tensor): return (xy).cast(dtypes.bool) -def GreaterOrEqual(x:Tensor,y:Tensor): return (x>=y).cast(dtypes.bool) -def Equal(x:Tensor,y:Tensor): return (x==y).cast(dtypes.bool) -def Max(*data_0): return functools.reduce(Tensor.maximum, data_0) -def Min(*data_0): return functools.reduce(Tensor.minimum, data_0) -def Sum(*data_0): return functools.reduce(Tensor.__add__, data_0) -def Mean(*data_0): return functools.reduce(Tensor.__add__, data_0) / len(data_0) -def Cast(input: Tensor, to): return input.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to))) + +def Identity(input: Tensor): + return input + + +def Add(input: Tensor, other: Tensor, broadcast=None): + return ( + input + other + if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) + else (input + other).cast(input.dtype) + ) + + +def Sub(input: Union[Tensor, Any], other: Tensor): + return input - other # some test has input as int + + +def Div(input: Tensor, other: Tensor): + return ( + input / other + if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType) + else input.div(other).floor() + ) # TODO: this has dtype issues + + +def Pow(input: Tensor, other: Tensor): + return (input.float() ** other.float()).cast( + input.dtype + ) # TODO: this has dtype issues + + +def Less(x: Tensor, y: Tensor): + return (x < y).cast(dtypes.bool) + + +def LessOrEqual(x: Tensor, y: Tensor): + return (x <= y).cast(dtypes.bool) + + +def Greater(x: Tensor, y: Tensor): + return (x > y).cast(dtypes.bool) + + +def GreaterOrEqual(x: Tensor, y: Tensor): + return (x >= y).cast(dtypes.bool) + + +def Equal(x: Tensor, y: Tensor): + return (x == y).cast(dtypes.bool) + + +def Max(*data_0): + return functools.reduce(Tensor.maximum, data_0) + + +def Min(*data_0): + return functools.reduce(Tensor.minimum, data_0) + + +def Sum(*data_0): + return functools.reduce(Tensor.__add__, data_0) + + +def Mean(*data_0): + return functools.reduce(Tensor.__add__, data_0) / len(data_0) + + +def Cast(input: Tensor, to): + return input.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to))) + # **************** Simple Ops **************** -def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None): - if value: return value - elif value_float: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False) - elif value_floats: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False) - elif value_int: return Tensor(value_int, dtype=dtypes.int64, requires_grad=False) - elif value_ints: return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False) - elif value_string or value_strings: raise NotImplementedError(f'value_string or value_strings not implemented for Constant op') -def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5): return (alpha*input + beta).clip(0, 1) -def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + Erf(x/math.sqrt(2))) -def Celu(x:Tensor, alpha=1.0): return x.celu(alpha) -def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu()) -def PRelu(X:Tensor, slope:Tensor): - slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE - return X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope -def LeakyRelu(X: Tensor, alpha=0.01): return X.leakyrelu(alpha) -def ThresholdedRelu(X: Tensor, alpha=1.0): return (X-alpha).relu() + (X-alpha).relu().sign() * alpha -def Softmax_1(input: Tensor, axis=1): return input.softmax(axis) -def Softmax_13(input: Tensor, axis=-1): return input.softmax(axis) -Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed -def LogSoftmax(input: Tensor, axis=-1): return input.log_softmax(axis) -def Clip(input: Tensor, min=None, max=None): return input.clip(float('-inf') if min is None else min, float('inf') if max is None else max) +def Constant( + value: Tensor = None, + value_float=None, + value_floats=None, + value_int=None, + value_ints=None, + value_string=None, + value_strings=None, +): + if value: + return value + elif value_float: + return Tensor(value_float, dtype=dtypes.float32, requires_grad=False) + elif value_floats: + return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False) + elif value_int: + return Tensor(value_int, dtype=dtypes.int64, requires_grad=False) + elif value_ints: + return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False) + elif value_string or value_strings: + raise NotImplementedError( + f"value_string or value_strings not implemented for Constant op" + ) + + +def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5): + return (alpha * input + beta).clip(0, 1) + + +def Gelu(x: Tensor, approximate=None): + return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + Erf(x / math.sqrt(2))) + + +def Celu(x: Tensor, alpha=1.0): + return x.celu(alpha) + + +def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): + return gamma * (X.relu() - (-alpha * X.exp() + alpha).relu()) + + +def PRelu(X: Tensor, slope: Tensor): + slope = ( + slope[0] if slope.shape[-1] != X.shape[-1] else slope + ) # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE + return X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope + + +def LeakyRelu(X: Tensor, alpha=0.01): + return X.leakyrelu(alpha) + + +def ThresholdedRelu(X: Tensor, alpha=1.0): + return (X - alpha).relu() + (X - alpha).relu().sign() * alpha + + +def Softmax_1(input: Tensor, axis=1): + return input.softmax(axis) + + +def Softmax_13(input: Tensor, axis=-1): + return input.softmax(axis) + + +Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed + + +def LogSoftmax(input: Tensor, axis=-1): + return input.log_softmax(axis) + + +def Clip(input: Tensor, min=None, max=None): + return input.clip( + float("-inf") if min is None else min, float("inf") if max is None else max + ) + # NOTE ReduceProd would require a new llop -def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)) else ([] if noop_with_empty_axes else None) -def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims) -def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims) -def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) -def ReduceMean(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims) -def ReduceSumSquare(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) -def ReduceL1(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.abs().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) -def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt() -def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log() -def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): return data.exp().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log() +def _axes(axes, noop_with_empty_axes): + return ( + [int(x) for x in safe_numpy(axes)] + if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,)) + else ([] if noop_with_empty_axes else None) + ) -def GlobalAveragePool(X: Tensor): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True) -def GlobalMaxPool(X: Tensor): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True) -def OptionalHasElement(x: Tensor=None): return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool) -def OptionalGetElement(x: Tensor=None): return x if x is not None else Tensor([], dtype=dtypes.float32) -def Tile(input: Tensor, repeats): return input.repeat([int(x) for x in safe_numpy(repeats)]) -def Range(start: Tensor, limit, delta): return Tensor.arange(start=int(safe_numpy(start).item()), stop=int(safe_numpy(limit).item()), step=int(safe_numpy(delta).item())).cast(dtype=start.dtype) -def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64) # TODO: really? -def Size(data: Tensor): return prod(data if isinstance(data, list) else data.shape) -def Flatten(input: Tensor, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1) -def Reshape(data: Tensor, shape: Tensor, allowzero=None): return data.reshape([int(x) if x != 0 else data.shape[i] for i,x in enumerate(safe_numpy(shape))]) -def Shrink(input: Tensor, bias=0.0, lambd=0.5): return (input < -lambd)*(input+bias) + (input > lambd)*(input-bias) -def And(x:Tensor, y:Tensor): return (x==y).where(x, Tensor.zeros(*x.shape)).cast(dtypes.bool) -def Or(x:Tensor, y:Tensor): return (x==y).where(x, Tensor.ones(*x.shape)).cast(dtypes.bool) -def Xor(x:Tensor, y:Tensor): return (x==y).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool) -def Not(x:Tensor): return (x==1).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool) +def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): + return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + + +def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): + return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + + +def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): + return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + + +def ReduceMean(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): + return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + + +def ReduceSumSquare(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): + return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + + +def ReduceL1(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): + return data.abs().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) + + +def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): + return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt() + + +def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): + return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log() + + +def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0): + return data.exp().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log() + + +def GlobalAveragePool(X: Tensor): + return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True) + + +def GlobalMaxPool(X: Tensor): + return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True) + + +def OptionalHasElement(x: Tensor = None): + return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool) + + +def OptionalGetElement(x: Tensor = None): + return x if x is not None else Tensor([], dtype=dtypes.float32) + + +def Tile(input: Tensor, repeats): + return input.repeat([int(x) for x in safe_numpy(repeats)]) + + +def Range(start: Tensor, limit, delta): + return Tensor.arange( + start=int(safe_numpy(start).item()), + stop=int(safe_numpy(limit).item()), + step=int(safe_numpy(delta).item()), + ).cast(dtype=start.dtype) + + +def Shape(data: Tensor, end=None, start=0): + return Tensor( + list(data.shape)[start:end], + dtype=dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64, + ) # TODO: really? + + +def Size(data: Tensor): + return prod(data if isinstance(data, list) else data.shape) + + +def Flatten(input: Tensor, axis=1): + return input.reshape(prod((1,) + input.shape[0:axis]), -1) + + +def Reshape(data: Tensor, shape: Tensor, allowzero=None): + return data.reshape( + [int(x) if x != 0 else data.shape[i] for i, x in enumerate(safe_numpy(shape))] + ) + + +def Shrink(input: Tensor, bias=0.0, lambd=0.5): + return (input < -lambd) * (input + bias) + (input > lambd) * (input - bias) + + +def And(x: Tensor, y: Tensor): + return (x == y).where(x, Tensor.zeros(*x.shape)).cast(dtypes.bool) + + +def Or(x: Tensor, y: Tensor): + return (x == y).where(x, Tensor.ones(*x.shape)).cast(dtypes.bool) + + +def Xor(x: Tensor, y: Tensor): + return ( + (x == y).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool) + ) + + +def Not(x: Tensor): + return ( + (x == 1).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool) + ) + + +def Asin(x): + return Atan(x / Tensor.sqrt(1 - x * x)) + -def Asin(x): return Atan(x / Tensor.sqrt(1 - x * x)) def Acos(x: Tensor): - negate = (x < 0) - x = x.abs() - ret = ((((-0.0187293 * x) + 0.0742610)*x - 0.2121144) * x + 1.5707288) * Tensor.sqrt(1.0 - x) - ret = ret - 2 * negate * ret - return negate * 3.14159265358979 + ret -def Atan(y: Tensor): - x = Tensor.ones(y.shape) - t3 = x - t1 = y.abs() - t0 = (t3 > t1).where(t3, t1) - t1 = (t3 < t1).where(t3, t1) - t3 = t1 / t0 - t4 = t3 * t3 - t0 = ((((-0.013480470 * t4 + 0.057477314) * t4 - 0.121239071) * t4 + 0.195635925) * t4 - 0.332994597) * t4 + 0.999995630 - t3 = t0 * t3 - t3 = (y.abs() > x.abs()).where(1.570796327 - t3, t3) - return (y < 0).where(-t3, t3) + negate = x < 0 + x = x.abs() + ret = ( + (((-0.0187293 * x) + 0.0742610) * x - 0.2121144) * x + 1.5707288 + ) * Tensor.sqrt(1.0 - x) + ret = ret - 2 * negate * ret + return negate * 3.14159265358979 + ret + + +def Atan(y: Tensor): + x = Tensor.ones(y.shape) + t3 = x + t1 = y.abs() + t0 = (t3 > t1).where(t3, t1) + t1 = (t3 < t1).where(t3, t1) + t3 = t1 / t0 + t4 = t3 * t3 + t0 = ( + (((-0.013480470 * t4 + 0.057477314) * t4 - 0.121239071) * t4 + 0.195635925) * t4 + - 0.332994597 + ) * t4 + 0.999995630 + t3 = t0 * t3 + t3 = (y.abs() > x.abs()).where(1.570796327 - t3, t3) + return (y < 0).where(-t3, t3) + + +def Trilu(x: Tensor, k: Union[Tensor, int] = 0, upper=1): + k = ( + int(k.numpy().item()) if k != 0 else 0 + ) # onnx passes k as a tensor int64 with one element, default is 0 + return x.triu(k) if upper else x.tril(k) -def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1): - k = int(k.numpy().item()) if k != 0 else 0 # onnx passes k as a tensor int64 with one element, default is 0 - return x.triu(k) if upper else x.tril(k) def Squeeze(input: Tensor, axes): - if isinstance(axes, Tensor): axes = safe_numpy(axes) - axes = [int(x) if x >= 0 else int(x+input.ndim) for x in axes] - return input.reshape([s for i,s in enumerate(input.shape) if i not in axes]) -def Unsqueeze(data: Tensor, axes): - axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)] - new_shape = [1] * (len(data.shape) + len(axes)) - ptr = iter(data.shape) - for i in range(len(new_shape)): - if i not in axes: - new_shape[i] = next(ptr) - return data.reshape(new_shape) + if isinstance(axes, Tensor): + axes = safe_numpy(axes) + axes = [int(x) if x >= 0 else int(x + input.ndim) for x in axes] + return input.reshape([s for i, s in enumerate(input.shape) if i not in axes]) + + +def Unsqueeze(data: Tensor, axes): + axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)] + new_shape = [1] * (len(data.shape) + len(axes)) + ptr = iter(data.shape) + for i in range(len(new_shape)): + if i not in axes: + new_shape[i] = next(ptr) + return data.reshape(new_shape) + + +def Binarizer(input, threshold=0.0): + return input > threshold -def Binarizer(input, threshold=0.0): return input > threshold def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0): - axis = axis + x.ndim if axis < 0 else axis - m = x == (x.max(axis=axis, keepdim=keepdims) if keepdims else x.max(axis=axis, keepdim=keepdims).unsqueeze(axis)) - c = Tensor.arange(x.shape[axis]).reshape(*[1]*(axis), x.shape[axis], *[1]*(x.ndim - axis-1)) * m - return c.max(axis=axis,keepdim=keepdims).cast(dtypes.int64) -def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) + axis = axis + x.ndim if axis < 0 else axis + m = x == ( + x.max(axis=axis, keepdim=keepdims) + if keepdims + else x.max(axis=axis, keepdim=keepdims).unsqueeze(axis) + ) + c = ( + Tensor.arange(x.shape[axis]).reshape( + *[1] * (axis), x.shape[axis], *[1] * (x.ndim - axis - 1) + ) + * m + ) + return c.max(axis=axis, keepdim=keepdims).cast(dtypes.int64) + + +def ArgMin(x, axis=0, keepdims=1, select_last_index=0): + return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) + + +def Elu(input: Tensor, alpha=1.0): + return input.elu(alpha=alpha) + + +def Concat(*inputs: List[Tensor], axis): + return inputs[0].cat(*inputs[1:], dim=axis) + + +def Transpose(input: Tensor, perm=None): + return input.permute( + order=list(range(len(input.shape))[::-1]) if perm is None else perm + ) -def Elu(input: Tensor, alpha=1.0): return input.elu(alpha=alpha) -def Concat(*inputs: List[Tensor], axis): return inputs[0].cat(*inputs[1:], dim=axis) -def Transpose(input: Tensor, perm=None): return input.permute(order=list(range(len(input.shape))[::-1]) if perm is None else perm) # NOTE: since we only have one type, this is valid! def CastLike(input, target_type): - assert isinstance(target_type, Tensor), "can only CastLike Tensor" - return input + assert isinstance(target_type, Tensor), "can only CastLike Tensor" + return input + + +def ConstantOfShape(input, value: Tensor = None): + if value is None: + value = Tensor([0.0]) + shape = [int(x) for x in safe_numpy(input)] + return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0] != 0 else 1) -def ConstantOfShape(input, value:Tensor=None): - if value is None: value=Tensor([0.0]) - shape = [int(x) for x in safe_numpy(input)] - return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1) # TODO: abstract out the broadcast logic in tensor def Expand(input: Tensor, shape): - x_shape, y_shape = input.shape, [int(x) for x in safe_numpy(shape)] - # copied from _broadcasted - x_shape, y_shape = [([1]*(max(len(x_shape), len(y_shape))-len(t_shape)) + list(t_shape)) for t_shape in [x_shape, y_shape]] - shape_ret = tuple(max(sx, sy) for sx,sy in zip(x_shape, y_shape)) - return input.reshape(x_shape).expand(shape_ret) + x_shape, y_shape = input.shape, [int(x) for x in safe_numpy(shape)] + # copied from _broadcasted + x_shape, y_shape = [ + ([1] * (max(len(x_shape), len(y_shape)) - len(t_shape)) + list(t_shape)) + for t_shape in [x_shape, y_shape] + ] + shape_ret = tuple(max(sx, sy) for sx, sy in zip(x_shape, y_shape)) + return input.reshape(x_shape).expand(shape_ret) + # **************** Complex Ops **************** -def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0): - ret = alpha * (A.transpose(transA) @ B.transpose(transB)) - if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1])) - return ret + +def Gemm( + A: Tensor, + B: Tensor, + C: Tensor = None, + alpha=1.0, + beta=1.0, + transA=0, + transB=0, + broadcast=0, +): + ret = alpha * (A.transpose(transA) @ B.transpose(transB)) + if C is not None: + ret += beta * ( + C + if broadcast == 0 + else C.reshape( + [-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1] + ) + ) + return ret + # works with Tensors.ndim != 4 -def _batchnorm(self:Tensor, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor): - shape = [1, -1] + [1] * (self.ndim-2) - x = (self - mean.reshape(shape=shape)) - if weight: x = x * weight.reshape(shape=shape) - ret = x.mul(invstd.reshape(shape=shape) if len(invstd.shape) == 1 else invstd) - return (ret + bias.reshape(shape=shape)) if bias else ret +def _batchnorm( + self: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + mean: Tensor, + invstd: Tensor, +): + shape = [1, -1] + [1] * (self.ndim - 2) + x = self - mean.reshape(shape=shape) + if weight: + x = x * weight.reshape(shape=shape) + ret = x.mul(invstd.reshape(shape=shape) if len(invstd.shape) == 1 else invstd) + return (ret + bias.reshape(shape=shape)) if bias else ret + # TODO: this is copied from tinygrad/nn/__init__.py # spatial is from opset 7 and has since been removed -def BatchNormalization(X: Tensor, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0, spatial=1, is_test=0): - if training_mode: - x_detached = X.detach() - current_mean = x_detached.mean(axis=(0,2,3)) - y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1])) - current_var = (y*y).mean(axis=(0,2,3)) - current_invstd = current_var.add(epsilon).pow(-0.5) +def BatchNormalization( + X: Tensor, + scale, + B, + input_mean, + input_var, + epsilon=1e-05, + momentum=0.9, + training_mode=0, + spatial=1, + is_test=0, +): + if training_mode: + x_detached = X.detach() + current_mean = x_detached.mean(axis=(0, 2, 3)) + y = x_detached - current_mean.reshape(shape=[1, -1, 1, 1]) + current_var = (y * y).mean(axis=(0, 2, 3)) + current_invstd = current_var.add(epsilon).pow(-0.5) - running_mean = input_mean * momentum + current_mean * (1 - momentum) - running_var = input_var * momentum + current_var * (1 - momentum) + running_mean = input_mean * momentum + current_mean * (1 - momentum) + running_var = input_var * momentum + current_var * (1 - momentum) + + return ( + _batchnorm(X, scale, B, current_mean, current_invstd), + running_mean, + running_var, + ) + else: + invstd = (input_var + epsilon) ** -0.5 + return _batchnorm(X, scale, B, input_mean, invstd) - return _batchnorm(X, scale, B, current_mean, current_invstd), running_mean, running_var - else: - invstd = (input_var + epsilon)**-0.5 - return _batchnorm(X, scale, B, input_mean, invstd) def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05): - axis = tuple(range(2, len(x.shape))) - mean = x.mean(axis=axis, keepdim=True) - invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5) - return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1])) + axis = tuple(range(2, len(x.shape))) + mean = x.mean(axis=axis, keepdim=True) + invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5) + return ( + x.sub(mean) + .mul(scale.reshape(shape=[-1, 1, 1])) + .mul(invstd) + .add(bias.reshape(shape=[-1, 1, 1])) + ) + def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1): - assert stash_type == 1, "only float32 is supported" - axis = tuple(i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape))) - mean = x.mean(axis=axis, keepdim=True) - return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).sqrt().reciprocal() + assert stash_type == 1, "only float32 is supported" + axis = tuple( + i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape)) + ) + mean = x.mean(axis=axis, keepdim=True) + return ( + x.layernorm(axis, epsilon).mul(scale).add(bias), + mean, + (x.sub(mean)) + .pow(2) + .mean(axis=axis, keepdim=True) + .add(epsilon) + .sqrt() + .reciprocal(), + ) + + +def GroupNormalization( + x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05 +): + return ( + x.reshape(x.shape[0], num_groups, -1) + .layernorm(axis=-1, eps=epsilon) + .mul(scale.unsqueeze(-1)) + .add(bias.unsqueeze(-1)) + .reshape(x.shape) + ) -def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05): - return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape) # onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...] # numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...) def _format_padding(onnx_pads, ndims=None, axes=None): - if ndims and len(onnx_pads)//2 != ndims: onnx_pads = onnx_pads * ndims # for OnnxBackendPyTorchConvertedModelTest the len(onnx_pads) == 2 - if ndims is None: ndims = len(onnx_pads) // 2 - if axes is None: axes = list(range(ndims)) - num_axes = len(axes) - np_pads = [(0,0)] * ndims - for i in range(num_axes): - np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes]) - return np_pads + if ndims and len(onnx_pads) // 2 != ndims: + onnx_pads = ( + onnx_pads * ndims + ) # for OnnxBackendPyTorchConvertedModelTest the len(onnx_pads) == 2 + if ndims is None: + ndims = len(onnx_pads) // 2 + if axes is None: + axes = list(range(ndims)) + num_axes = len(axes) + np_pads = [(0, 0)] * ndims + for i in range(num_axes): + np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes]) + return np_pads + + +def _padding( + X: Tensor, + pads=None, + auto_pad="NOTSET", + axes=None, + constant_value=0.0, + strides=None, + kernel_shape=None, + dilations=None, + ceil_mode=0, +): + if auto_pad != "NOTSET": + pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations) + elif ceil_mode and auto_pad == "NOTSET": # stupid ceil_mode case + if strides is not None: + strides = ( + [strides] * len(kernel_shape) + if isinstance(strides, int) + else strides + if strides + else [1] * len(kernel_shape) + ) + if dilations is not None: + dilations = [1] * len(kernel_shape) if dilations == 1 else dilations + out_spatial_shape = [ + math.ceil((sh - dil * (ker - 1) - 1) / st + 1) + if ceil_mode + else math.floor((sh - dil * (ker - 1) - 1) / st + 1) + for sh, st, ker, dil in zip( + X.shape[-len(kernel_shape) :], strides, kernel_shape, dilations + ) + ] + pad_shape = [ + (osh - 1) * st + ((ks - 1) * dil + 1) - ish + for osh, st, ks, dil, ish in zip( + out_spatial_shape, + strides, + kernel_shape, + dilations, + X.shape[-len(kernel_shape) :], + ) + ] + pad_shape = flatten([[sh // 2, sh - sh // 2] for sh in pad_shape]) + pads = pad_shape[::2] + pad_shape[1::2] + if pads is None: + return X + pads = _format_padding(pads, ndims=len(X.shape), axes=axes) + return X.pad(tuple(pads), value=constant_value) -def _padding(X: Tensor, pads=None, auto_pad="NOTSET", axes=None, constant_value=0., strides=None, kernel_shape=None, dilations=None, ceil_mode=0): - if auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations) - elif ceil_mode and auto_pad=="NOTSET": # stupid ceil_mode case - if strides is not None: strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape) - if dilations is not None: dilations = [1]*len(kernel_shape) if dilations == 1 else dilations - out_spatial_shape = [math.ceil((sh - dil * (ker-1)-1)/st + 1) if ceil_mode else math.floor((sh - dil * (ker-1)-1)/st + 1) for sh, st, ker, dil in zip(X.shape[-len(kernel_shape):], strides, kernel_shape, dilations)] - pad_shape = [(osh-1)*st+((ks-1)*dil+1)-ish for osh, st, ks, dil, ish in zip(out_spatial_shape, strides, kernel_shape, dilations, X.shape[-len(kernel_shape):])] - pad_shape = flatten([[sh//2, sh-sh//2] for sh in pad_shape]) - pads = pad_shape[::2] + pad_shape[1::2] - if pads is None: return X - pads = _format_padding(pads, ndims=len(X.shape), axes=axes) - return X.pad(tuple(pads), value=constant_value) def _auto_pad(X: Tensor, auto_pad, strides, kernel_shape, dilations): - strides = [strides]*len(kernel_shape) if isinstance(strides, int) else strides if strides else [1]*len(kernel_shape) - dilations = [1]*len(kernel_shape) if dilations == 1 else dilations - if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": - pad_shape = [(math.ceil(sh/st)-1)*st+((ks-1)*di+1)-sh for sh, st, ks, di in zip(X.shape[-len(kernel_shape):], strides, kernel_shape, dilations)] - pad_shape = flatten([[sh//2, sh-sh//2] for sh in pad_shape]) - return pad_shape[::2] + pad_shape[1::2] if auto_pad == "SAME_UPPER" else pad_shape[1::2] + pad_shape[::2] - else: raise NotImplementedError(f"auto_pad={auto_pad} not implemented") - -def Pad(x: Tensor, pads: Union[Tensor, Tuple[int, ...]], constant_value: Tensor=None, axes: Tensor=None, mode="constant", value: float=0.): - constant_value = value if constant_value is None else float(safe_numpy(constant_value)[0]) - seq_pads = list(pads) if isinstance(pads, tuple) else safe_numpy(pads) - seq_pads = [math.ceil(i) for i in seq_pads] - seq_axes = safe_numpy(axes).astype(np.int32).tolist() if axes is not None else None - base_shape = x.shape - pads = _format_padding(seq_pads, ndims=len(x.shape), axes=seq_axes) - if mode == "wrap": - repeat_args = [math.ceil(dim[0]/sh) + math.ceil(dim[1]/sh) + 1 for dim, sh in zip(pads, base_shape)] - new_shape = [s*r for s,r in zip(base_shape, repeat_args)] - shrink_args = [(sh-dim[0]%sh if dim[0]%sh != 0 else 0, nsh-(sh-dim[1]%sh) if dim[1]%sh != 0 else nsh) for dim, sh, nsh in zip(pads, base_shape, new_shape)] - return x.repeat(tuple(repeat_args)).shrink(tuple(shrink_args)) - elif mode == "reflect": - for i,s in enumerate(x.shape): - if pads[i] == (0,0): continue - elif pads[i][0] and not pads[i][1]: - x = x.flip(i).shrink(tuple([(0,s_) if i_ != i else (s-pads[i][0]-1, s_-1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (0,s) for i_ in range(x.ndim)])) + \ - x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)])) - elif not pads[i][0] and pads[i][1]: - x = x.flip(i).shrink(tuple([(0,s_) if i_ != i else (1, pads[i][1]+1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (s,0) for i_ in range(x.ndim)])) + \ - x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)])) - else: - x = x.flip(i).shrink(tuple([(0,s_) if i_ != i else (s-pads[i][0]-1, s_-1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (0,s+pads[i][1]) for i_ in range(x.ndim)])) + \ - x.flip(i).shrink(tuple([(0,s_) if i_ != i else (1, pads[i][1]+1) for i_,s_ in enumerate(x.shape)])).pad(tuple([(0,0) if i_ != i else (s+pads[i][0],0) for i_ in range(x.ndim)])) + \ - x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)])) - return x - elif mode == "edge": - for i,s in enumerate(x.shape): - if pads[i] == (0,0): continue - elif pads[i][0] and not pads[i][1]: - x = x.shrink(tuple([(0,s_) if i_ != i else (0,1) for i_,s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (0,s) for i_ in range(x.ndim)])) + \ - x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)])) - elif not pads[i][0] and pads[i][1]: - x = x.shrink(tuple([(0,s_) if i_ != i else (s_-1, s_) for i_,s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (s+pads[i][0],0) for i_ in range(x.ndim)])) + \ - x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)])) - else: - x = x.shrink(tuple([(0,s_) if i_ != i else (0,1) for i_,s_ in enumerate(x.shape)])).expand([pads[i][0] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (0,s+pads[i][1]) for i_ in range(x.ndim)])) + \ - x.shrink(tuple([(0,s_) if i_ != i else (s_-1, s_) for i_,s_ in enumerate(x.shape)])).expand([pads[i][1] if i_ == i else s_ for i_,s_ in enumerate(x.shape)]).pad(tuple([(0,0) if i_ != i else (s+pads[i][0],0) for i_ in range(x.ndim)])) + \ - x.pad(tuple([(0,0) if i_ != i else pads[i] for i_ in range(x.ndim)])) - return x - elif mode == "constant": - return _padding(x, seq_pads, axes=seq_axes, constant_value=constant_value) - -def AveragePool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_pad=0, dilations=1, pads=None, strides=1): - pixel_axes = tuple(range(len(X.shape)))[2:] - ret = _padding(X, pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode).avg_pool2d(kernel_shape, stride=strides, dilation=dilations) - if count_include_pad: - return ret - else: - div = _padding(Tensor.ones(*X.shape), pads, auto_pad, axes=pixel_axes, strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode).avg_pool2d(kernel_shape, stride=strides, dilation=dilations) - return ret / div - -def MaxPool(X: Tensor, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1): - ret = _padding(X, pads, auto_pad, constant_value=float("-inf"), axes=tuple(range(len(X.shape)))[2:], strides=strides, kernel_shape=kernel_shape, dilations=dilations, ceil_mode=ceil_mode) - ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations) - ret_len, X_len = ret.numel(), X.numel() - indices = ((ret.flatten().unsqueeze(1).expand(ret_len, X_len) == X.flatten().reshape(1, X_len).expand(ret_len, X_len)) * Tensor.arange(X_len).reshape(1, X_len).expand(ret_len, X_len)).sum(1).reshape(ret.shape).cast(dtypes.int64) - if storage_order: indices = indices.transpose(indices.ndim-2, indices.ndim-1) - return ret, indices - -def MaxUnpool(xT: Tensor, xI: Tensor, outshape: Tensor=None, kernel_shape=None, pads=None, strides=None): - out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)] - outlength = prod(out_sh) - xI = xI.flatten().unsqueeze(1).expand(prod(xT.shape), outlength) - arange = Tensor.arange(outlength, requires_grad=False).reshape(1, outlength).expand(xI.shape) - xT = xT.flatten().unsqueeze(1).expand(prod(xT.shape), outlength) - ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh) - if outshape is not None: - outshape = safe_numpy(outshape).tolist() - if outshape != ret.shape: - diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]] - pad_args = [diff[0]//2, diff[1]//2, diff[0]-diff[0]//2, diff[1]-diff[1]//2] - ret = ret.pad2d((pad_args[1], pad_args[3], pad_args[0], pad_args[2])) - return ret - -def Conv(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1): - if auto_pad != "NOTSET": padding = _auto_pad(X, auto_pad, strides, kernel_shape, dilations) - else: padding = [p for ps in zip(pads[:len(pads)//2][::-1], pads[len(pads)//2:][::-1]) for p in ps] if pads is not None else 0 # reorder padding - return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding) - -def ConvTranspose(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, output_shape=None, output_padding=0, strides=1): - if kernel_shape is None: kernel_shape = W.shape[2:] - if isinstance(strides, int): strides = [strides]*(W.ndim-2) - if isinstance(dilations, int): dilations = [dilations]*(W.ndim-2) - if isinstance(output_padding, int): output_padding = [output_padding]*(W.ndim-2) - out_sh = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides, X.shape[2:], kernel_shape, dilations))] if output_shape is not None or auto_pad != "NOTSET" else [] - if pads is None: - if output_shape is None: output_shape = [xs*st for xs, st in zip(X.shape[2:], strides)] - if auto_pad == "NOTSET": pads = [0,0] * (X.ndim - 2) + strides = ( + [strides] * len(kernel_shape) + if isinstance(strides, int) + else strides + if strides + else [1] * len(kernel_shape) + ) + dilations = [1] * len(kernel_shape) if dilations == 1 else dilations + if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": + pad_shape = [ + (math.ceil(sh / st) - 1) * st + ((ks - 1) * di + 1) - sh + for sh, st, ks, di in zip( + X.shape[-len(kernel_shape) :], strides, kernel_shape, dilations + ) + ] + pad_shape = flatten([[sh // 2, sh - sh // 2] for sh in pad_shape]) + return ( + pad_shape[::2] + pad_shape[1::2] + if auto_pad == "SAME_UPPER" + else pad_shape[1::2] + pad_shape[::2] + ) else: - total_padding = [st*(ish-1) + pad + ((ks-1)*dil+1)-osh for st, ish, pad, ks, dil, osh in zip(strides, X.shape[2:], output_padding, kernel_shape, dilations, output_shape)] - pad_shape = flatten([[sh//2, sh-sh//2] for sh in total_padding]) - pads = pad_shape[::2] + pad_shape[1::2] if auto_pad == "SAME_UPPER" else pad_shape[1::2] + pad_shape[::2] - else: - if output_shape is None: output_shape = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides, X.shape[2:], kernel_shape, dilations))] - if out_sh: output_padding = [os - rs for os, rs in zip(output_shape, out_sh)] - return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding) + raise NotImplementedError(f"auto_pad={auto_pad} not implemented") + + +def Pad( + x: Tensor, + pads: Union[Tensor, Tuple[int, ...]], + constant_value: Tensor = None, + axes: Tensor = None, + mode="constant", + value: float = 0.0, +): + constant_value = ( + value if constant_value is None else float(safe_numpy(constant_value)[0]) + ) + seq_pads = list(pads) if isinstance(pads, tuple) else safe_numpy(pads) + seq_pads = [math.ceil(i) for i in seq_pads] + seq_axes = safe_numpy(axes).astype(np.int32).tolist() if axes is not None else None + base_shape = x.shape + pads = _format_padding(seq_pads, ndims=len(x.shape), axes=seq_axes) + if mode == "wrap": + repeat_args = [ + math.ceil(dim[0] / sh) + math.ceil(dim[1] / sh) + 1 + for dim, sh in zip(pads, base_shape) + ] + new_shape = [s * r for s, r in zip(base_shape, repeat_args)] + shrink_args = [ + ( + sh - dim[0] % sh if dim[0] % sh != 0 else 0, + nsh - (sh - dim[1] % sh) if dim[1] % sh != 0 else nsh, + ) + for dim, sh, nsh in zip(pads, base_shape, new_shape) + ] + return x.repeat(tuple(repeat_args)).shrink(tuple(shrink_args)) + elif mode == "reflect": + for i, s in enumerate(x.shape): + if pads[i] == (0, 0): + continue + elif pads[i][0] and not pads[i][1]: + x = x.flip(i).shrink( + tuple( + [ + (0, s_) if i_ != i else (s - pads[i][0] - 1, s_ - 1) + for i_, s_ in enumerate(x.shape) + ] + ) + ).pad( + tuple([(0, 0) if i_ != i else (0, s) for i_ in range(x.ndim)]) + ) + x.pad( + tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)]) + ) + elif not pads[i][0] and pads[i][1]: + x = x.flip(i).shrink( + tuple( + [ + (0, s_) if i_ != i else (1, pads[i][1] + 1) + for i_, s_ in enumerate(x.shape) + ] + ) + ).pad( + tuple([(0, 0) if i_ != i else (s, 0) for i_ in range(x.ndim)]) + ) + x.pad( + tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)]) + ) + else: + x = ( + x.flip(i) + .shrink( + tuple( + [ + (0, s_) if i_ != i else (s - pads[i][0] - 1, s_ - 1) + for i_, s_ in enumerate(x.shape) + ] + ) + ) + .pad( + tuple( + [ + (0, 0) if i_ != i else (0, s + pads[i][1]) + for i_ in range(x.ndim) + ] + ) + ) + + x.flip(i) + .shrink( + tuple( + [ + (0, s_) if i_ != i else (1, pads[i][1] + 1) + for i_, s_ in enumerate(x.shape) + ] + ) + ) + .pad( + tuple( + [ + (0, 0) if i_ != i else (s + pads[i][0], 0) + for i_ in range(x.ndim) + ] + ) + ) + + x.pad( + tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)]) + ) + ) + return x + elif mode == "edge": + for i, s in enumerate(x.shape): + if pads[i] == (0, 0): + continue + elif pads[i][0] and not pads[i][1]: + x = x.shrink( + tuple( + [ + (0, s_) if i_ != i else (0, 1) + for i_, s_ in enumerate(x.shape) + ] + ) + ).expand( + [pads[i][0] if i_ == i else s_ for i_, s_ in enumerate(x.shape)] + ).pad( + tuple([(0, 0) if i_ != i else (0, s) for i_ in range(x.ndim)]) + ) + x.pad( + tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)]) + ) + elif not pads[i][0] and pads[i][1]: + x = x.shrink( + tuple( + [ + (0, s_) if i_ != i else (s_ - 1, s_) + for i_, s_ in enumerate(x.shape) + ] + ) + ).expand( + [pads[i][0] if i_ == i else s_ for i_, s_ in enumerate(x.shape)] + ).pad( + tuple( + [ + (0, 0) if i_ != i else (s + pads[i][0], 0) + for i_ in range(x.ndim) + ] + ) + ) + x.pad( + tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)]) + ) + else: + x = ( + x.shrink( + tuple( + [ + (0, s_) if i_ != i else (0, 1) + for i_, s_ in enumerate(x.shape) + ] + ) + ) + .expand( + [pads[i][0] if i_ == i else s_ for i_, s_ in enumerate(x.shape)] + ) + .pad( + tuple( + [ + (0, 0) if i_ != i else (0, s + pads[i][1]) + for i_ in range(x.ndim) + ] + ) + ) + + x.shrink( + tuple( + [ + (0, s_) if i_ != i else (s_ - 1, s_) + for i_, s_ in enumerate(x.shape) + ] + ) + ) + .expand( + [pads[i][1] if i_ == i else s_ for i_, s_ in enumerate(x.shape)] + ) + .pad( + tuple( + [ + (0, 0) if i_ != i else (s + pads[i][0], 0) + for i_ in range(x.ndim) + ] + ) + ) + + x.pad( + tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)]) + ) + ) + return x + elif mode == "constant": + return _padding(x, seq_pads, axes=seq_axes, constant_value=constant_value) + + +def AveragePool( + X: Tensor, + kernel_shape, + auto_pad="NOTSET", + ceil_mode=0, + count_include_pad=0, + dilations=1, + pads=None, + strides=1, +): + pixel_axes = tuple(range(len(X.shape)))[2:] + ret = _padding( + X, + pads, + auto_pad, + axes=pixel_axes, + strides=strides, + kernel_shape=kernel_shape, + dilations=dilations, + ceil_mode=ceil_mode, + ).avg_pool2d(kernel_shape, stride=strides, dilation=dilations) + if count_include_pad: + return ret + else: + div = _padding( + Tensor.ones(*X.shape), + pads, + auto_pad, + axes=pixel_axes, + strides=strides, + kernel_shape=kernel_shape, + dilations=dilations, + ceil_mode=ceil_mode, + ).avg_pool2d(kernel_shape, stride=strides, dilation=dilations) + return ret / div + + +def MaxPool( + X: Tensor, + kernel_shape, + auto_pad="NOTSET", + ceil_mode=0, + dilations=1, + pads=None, + storage_order=0, + strides=1, +): + ret = _padding( + X, + pads, + auto_pad, + constant_value=float("-inf"), + axes=tuple(range(len(X.shape)))[2:], + strides=strides, + kernel_shape=kernel_shape, + dilations=dilations, + ceil_mode=ceil_mode, + ) + ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations) + ret_len, X_len = ret.numel(), X.numel() + indices = ( + ( + ( + ret.flatten().unsqueeze(1).expand(ret_len, X_len) + == X.flatten().reshape(1, X_len).expand(ret_len, X_len) + ) + * Tensor.arange(X_len).reshape(1, X_len).expand(ret_len, X_len) + ) + .sum(1) + .reshape(ret.shape) + .cast(dtypes.int64) + ) + if storage_order: + indices = indices.transpose(indices.ndim - 2, indices.ndim - 1) + return ret, indices + + +def MaxUnpool( + xT: Tensor, + xI: Tensor, + outshape: Tensor = None, + kernel_shape=None, + pads=None, + strides=None, +): + out_sh = [ + (ks // 2) * 2 + st * inps + for inps, st, ks in zip(xI.shape, strides, kernel_shape) + ] + outlength = prod(out_sh) + xI = xI.flatten().unsqueeze(1).expand(prod(xT.shape), outlength) + arange = ( + Tensor.arange(outlength, requires_grad=False) + .reshape(1, outlength) + .expand(xI.shape) + ) + xT = xT.flatten().unsqueeze(1).expand(prod(xT.shape), outlength) + ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh) + if outshape is not None: + outshape = safe_numpy(outshape).tolist() + if outshape != ret.shape: + diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]] + pad_args = [ + diff[0] // 2, + diff[1] // 2, + diff[0] - diff[0] // 2, + diff[1] - diff[1] // 2, + ] + ret = ret.pad2d((pad_args[1], pad_args[3], pad_args[0], pad_args[2])) + return ret + + +def Conv( + X: Tensor, + W: Tensor, + B=None, + auto_pad="NOTSET", + dilations=1, + group=1, + kernel_shape=None, + pads=None, + strides=1, +): + if auto_pad != "NOTSET": + padding = _auto_pad(X, auto_pad, strides, kernel_shape, dilations) + else: + padding = ( + [ + p + for ps in zip( + pads[: len(pads) // 2][::-1], pads[len(pads) // 2 :][::-1] + ) + for p in ps + ] + if pads is not None + else 0 + ) # reorder padding + return X.conv2d( + W, B, stride=strides, groups=group, dilation=dilations, padding=padding + ) + + +def ConvTranspose( + X: Tensor, + W: Tensor, + B=None, + auto_pad="NOTSET", + dilations=1, + group=1, + kernel_shape=None, + pads=None, + output_shape=None, + output_padding=0, + strides=1, +): + if kernel_shape is None: + kernel_shape = W.shape[2:] + if isinstance(strides, int): + strides = [strides] * (W.ndim - 2) + if isinstance(dilations, int): + dilations = [dilations] * (W.ndim - 2) + if isinstance(output_padding, int): + output_padding = [output_padding] * (W.ndim - 2) + out_sh = ( + [ + st * (xs - 1) + (ks - 1) * di + 1 + if n < 2 + else st * (xs - 1) + (ks - 1) * di + 1 - pads[n - 2] - pads[n - 1] + for n, (st, xs, ks, di) in enumerate( + zip(strides, X.shape[2:], kernel_shape, dilations) + ) + ] + if output_shape is not None or auto_pad != "NOTSET" + else [] + ) + if pads is None: + if output_shape is None: + output_shape = [xs * st for xs, st in zip(X.shape[2:], strides)] + if auto_pad == "NOTSET": + pads = [0, 0] * (X.ndim - 2) + else: + total_padding = [ + st * (ish - 1) + pad + ((ks - 1) * dil + 1) - osh + for st, ish, pad, ks, dil, osh in zip( + strides, + X.shape[2:], + output_padding, + kernel_shape, + dilations, + output_shape, + ) + ] + pad_shape = flatten([[sh // 2, sh - sh // 2] for sh in total_padding]) + pads = ( + pad_shape[::2] + pad_shape[1::2] + if auto_pad == "SAME_UPPER" + else pad_shape[1::2] + pad_shape[::2] + ) + else: + if output_shape is None: + output_shape = [ + st * (xs - 1) + (ks - 1) * di + 1 + if n < 2 + else st * (xs - 1) + (ks - 1) * di + 1 - pads[n - 2] - pads[n - 1] + for n, (st, xs, ks, di) in enumerate( + zip(strides, X.shape[2:], kernel_shape, dilations) + ) + ] + if out_sh: + output_padding = [os - rs for os, rs in zip(output_shape, out_sh)] + return X.conv_transpose2d( + W, + B, + stride=strides, + groups=group, + dilation=dilations, + padding=pads if pads is not None else 0, + output_padding=output_padding, + ) + # Reimplemented here because you need legacy RNG for passing ONNX tests. def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None): - if isinstance(ratio, Tensor) and not ratio.shape: ratio = safe_numpy(ratio) # ratio and tensor is passed in as Tensor with shape: () - if isinstance(training_mode, Tensor) and not training_mode.shape: training_mode = safe_numpy(training_mode) - if not training_mode: return data, Tensor.ones(*data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's. - rng = np.random.RandomState(seed) - ratio = ratio.lazydata.realize().toCPU()[0] if isinstance(ratio, Tensor) else ratio - mask = Tensor((rng.random(data.shape) >= ratio), requires_grad=False, device=data.device) - return data * mask * (1/(1.0 - ratio)), mask + if isinstance(ratio, Tensor) and not ratio.shape: + ratio = safe_numpy( + ratio + ) # ratio and tensor is passed in as Tensor with shape: () + if isinstance(training_mode, Tensor) and not training_mode.shape: + training_mode = safe_numpy(training_mode) + if not training_mode: + return data, Tensor.ones( + *data.shape, dtype=dtypes.bool + ) # if mask is requested as output it will contain all True's. + rng = np.random.RandomState(seed) + ratio = ratio.lazydata.realize().toCPU()[0] if isinstance(ratio, Tensor) else ratio + mask = Tensor( + (rng.random(data.shape) >= ratio), requires_grad=False, device=data.device + ) + return data * mask * (1 / (1.0 - ratio)), mask + def LRN(input: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0): - bs, c, iy, ix = input.shape - return input / input.mul(input).reshape(bs,1,c,iy*ix).pad2d((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1).reshape(bs,c,iy,ix).mul(alpha).add(bias).pow(beta) + bs, c, iy, ix = input.shape + return input / input.mul(input).reshape(bs, 1, c, iy * ix).pad2d( + (0, 0, (size - 1) // 2, size // 2) + ).avg_pool2d((size, 1), 1).reshape(bs, c, iy, ix).mul(alpha).add(bias).pow(beta) + def MeanVarianceNormalization(input: Tensor, axis=(0, 2, 3)): - data_mean = input.mean(axis=axis, keepdim=True) - std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt() - return (input - data_mean) / (std + 1e-9) + data_mean = input.mean(axis=axis, keepdim=True) + std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt() + return (input - data_mean) / (std + 1e-9) -def NegativeLogLikelihoodLoss(input: Tensor, target: Tensor, weight=None, ignore_index=None, reduction="mean"): - target = target.cast(dtypes.float32) - N, C, i_shape = input.shape[0], input.shape[1], input.shape - t_shape = target.shape - if len(input.shape) != 3: - input = input.reshape((N, C, -1)) - target = target.reshape((N, -1)) - if weight is not None: - mask = target.unsqueeze(-1) == Tensor.arange(C).repeat((N, 1, 1)) - weight = (mask * weight).sum(axis=-1) - if ignore_index is not None: - cond = target == ignore_index - weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1) - mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(input.shape) -2)) - loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight) - if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum() - elif reduction == "sum": return loss.sum() - return loss.reshape(t_shape) if len(i_shape) != 3 else loss -def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"): - N, C, *s_dimensions = scores.shape - if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels) - mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1]*len(s_dimensions)) - y = scores.log_softmax(axis=1) - if weights is not None: weights = weights.__getitem__(tuple([labels, *[slice(None)]*(weights.ndim-1)])) - loss = (mask * -y).sum(1) if weights is None else (mask * -y).sum(1) * weights - if reduction == "mean": loss = loss.sum() / (loss == 0).where(0, 1).sum() if weights is None else loss.sum() / weights.sum() - elif reduction == "sum": loss = loss.sum() - return loss, y +def NegativeLogLikelihoodLoss( + input: Tensor, target: Tensor, weight=None, ignore_index=None, reduction="mean" +): + target = target.cast(dtypes.float32) + N, C, i_shape = input.shape[0], input.shape[1], input.shape + t_shape = target.shape + if len(input.shape) != 3: + input = input.reshape((N, C, -1)) + target = target.reshape((N, -1)) + if weight is not None: + mask = target.unsqueeze(-1) == Tensor.arange(C).repeat((N, 1, 1)) + weight = (mask * weight).sum(axis=-1) + if ignore_index is not None: + cond = target == ignore_index + weight = ( + cond.where(0, weight) + if weight is not None + else cond.where(Tensor.zeros(*target.shape), 1) + ) + mask = target[:, None, :] == Tensor.arange(C).reshape( + [1, C] + [1] * (len(input.shape) - 2) + ) + loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight) + if reduction == "mean": + return loss.mean() if weight is None else loss.sum() / weight.sum() + elif reduction == "sum": + return loss.sum() + return loss.reshape(t_shape) if len(i_shape) != 3 else loss + + +def SoftmaxCrossEntropyLoss( + scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean" +): + N, C, *s_dimensions = scores.shape + if ignore_index is not None: + labels = (labels == ignore_index).where(C + 1, labels) + mask = labels.unsqueeze(1) == Tensor.arange(C).reshape( + 1, C, *[1] * len(s_dimensions) + ) + y = scores.log_softmax(axis=1) + if weights is not None: + weights = weights.__getitem__( + tuple([labels, *[slice(None)] * (weights.ndim - 1)]) + ) + loss = (mask * -y).sum(1) if weights is None else (mask * -y).sum(1) * weights + if reduction == "mean": + loss = ( + loss.sum() / (loss == 0).where(0, 1).sum() + if weights is None + else loss.sum() / weights.sum() + ) + elif reduction == "sum": + loss = loss.sum() + return loss, y + + +def ArrayFeatureExtractor(input: Tensor, indices: Tensor): + return input.__getitem__( + tuple( + [ + slice(None) if i != (input.ndim - 1) else indices + for i in range(input.ndim) + ] + ) + ) + -def ArrayFeatureExtractor(input: Tensor, indices: Tensor): return input.__getitem__(tuple([slice(None) if i != (input.ndim-1) else indices for i in range(input.ndim)])) def Gather(input: Tensor, indices: Tensor, axis=0): - if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices - input_sh = list(input.shape) - ret_shape = input_sh[:axis] + list(indices.shape) + input_sh[axis+1:] - if indices.ndim > 1: indices = indices.flatten() - indices = [int(safe_numpy(indices))] if indices.shape == () else [input_sh[axis]+int(x) if x<0 else int(x) for x in safe_numpy(indices)] - args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(input_sh)] for i in indices] - return input.shrink(arg=tuple(args[0])).cat(*[input.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) - else: # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot - return input.__getitem__(tuple([slice(None) if i != axis else indices for i in range(input.ndim)])) + if ( + indices.numel() < 9 + ): # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices + input_sh = list(input.shape) + ret_shape = input_sh[:axis] + list(indices.shape) + input_sh[axis + 1 :] + if indices.ndim > 1: + indices = indices.flatten() + indices = ( + [int(safe_numpy(indices))] + if indices.shape == () + else [ + input_sh[axis] + int(x) if x < 0 else int(x) + for x in safe_numpy(indices) + ] + ) + args = [ + [(0, x) if j != axis else (i, i + 1) for j, x in enumerate(input_sh)] + for i in indices + ] + return ( + input.shrink(arg=tuple(args[0])) + .cat(*[input.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis) + .reshape(ret_shape) + ) + else: # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot + return input.__getitem__( + tuple([slice(None) if i != axis else indices for i in range(input.ndim)]) + ) + def GatherElements(input: Tensor, indices: Tensor, axis): - indices = indices.sign().contiguous().__neg__().contiguous().relu() * input.shape[axis] + indices - return input.gather(indices, axis) + indices = ( + indices.sign().contiguous().__neg__().contiguous().relu() * input.shape[axis] + + indices + ) + return input.gather(indices, axis) -def _round(x:Tensor, n:float, equidistant_case = "round_down") -> Tensor: - def _and(cond1, cond2): return ((cond1 + cond2) == 2).where(1, 0) - assert n <= 1, f"n:{n} shouldn't be larger than 1" - b = x.cast(dtypes.int32).contiguous().cast(x.dtype) - b = (b >= 0).where(b+n, b-n) - if equidistant_case == "round_down": - return (x > b).where(b+1-n, b-n) - elif equidistant_case == "round_up": - return (x >= b).where(b+1-n, b-n) - elif equidistant_case == "round_to_even": - x_ceil_fraction = x.ceil()/2 - cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction - x = (_and(x == b, cond_ceil_even)).where(x+1-n, x) - x = (x > b).where(b+1-n, b-n) - return x -def Round(X:Tensor): return _round(X, 0.5, "round_to_even") +def _round(x: Tensor, n: float, equidistant_case="round_down") -> Tensor: + def _and(cond1, cond2): + return ((cond1 + cond2) == 2).where(1, 0) + + assert n <= 1, f"n:{n} shouldn't be larger than 1" + b = x.cast(dtypes.int32).contiguous().cast(x.dtype) + b = (b >= 0).where(b + n, b - n) + if equidistant_case == "round_down": + return (x > b).where(b + 1 - n, b - n) + elif equidistant_case == "round_up": + return (x >= b).where(b + 1 - n, b - n) + elif equidistant_case == "round_to_even": + x_ceil_fraction = x.ceil() / 2 + cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction + x = (_and(x == b, cond_ceil_even)).where(x + 1 - n, x) + x = (x > b).where(b + 1 - n, b - n) + return x + + +def Round(X: Tensor): + return _round(X, 0.5, "round_to_even") + # TODO clean this up, it's taking the longest in CI -def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, coordinate_transformation_mode='half_pixel', cubic_coeff_a=-0.75, exclude_outside=0, extrapolation_value=0.0, keep_aspect_ratio_policy='stretch', mode='nearest', nearest_mode='round_prefer_floor'): - def _nearest_gather(X: Tensor, x_out, y_out): return X[:,:,y_out,:][:,:,:,x_out] - def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len): - if nearest_mode == "round_prefer_floor": ret = _round(x_resized, 0.5, "round_down") - elif nearest_mode == "round_prefer_ceil": ret = _round(x_resized, 0.5, "round_up") - elif nearest_mode == "floor": ret = x_resized.floor() - elif nearest_mode == "ceil": ret = x_resized.ceil() - return ret.clip(0, x_len-1) - def _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi=None): - if coordinate_transformation_mode == "half_pixel": - x_out = (x_out + 0.5)/Tensor(scales_lol[-1]) - 0.5 # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuacy. - y_out = (y_out + 0.5)/Tensor(scales_lol[-2]) - 0.5 - elif coordinate_transformation_mode == "align_corners": - x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1) - y_out = y_out * (X.shape[-2] - 1) / (output_shape[-2] - 1) - elif coordinate_transformation_mode == "asymmetric": - x_out = x_out/scales_lol[-1] - y_out = y_out/scales_lol[-2] - elif coordinate_transformation_mode == "half_pixel_symmetric": - x_out = X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1]) + (x_out + 0.5) / scales_lol[-1] - 0.5 - y_out = X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2]) + (y_out + 0.5) / scales_lol[-2] - 0.5 - elif coordinate_transformation_mode == "pytorch_half_pixel": - x_out = (x_out + 0.5)/scales_lol[-1] - 0.5 if output_shape[-1] > 1 else Tensor([0]) - y_out = (y_out + 0.5)/scales_lol[-2] - 0.5 if output_shape[-2] > 1 else Tensor([0]) - elif coordinate_transformation_mode == "tf_crop_and_resize": - x_out = roi[-1][0] * (X.shape[-1] - 1) + x_out * ((roi[-1][1] - roi[-1][0]) * (X.shape[-1] - 1) / (output_shape[-1] - 1)) if output_shape[-1] > 1 else Tensor([0.5 * (roi[-1][0] + roi[-1][1]) * (X.shape[-1] - 1)]) - y_out = roi[-2][0] * (X.shape[-2] - 1) + y_out * ((roi[-2][1] - roi[-2][0]) * (X.shape[-2] - 1) / (output_shape[-2] - 1)) if output_shape[-2] > 1 else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)]) - return x_out.clip(0, X.shape[-1]-1), y_out.clip(0, X.shape[-2]-1) - if roi is not None: - roi = safe_numpy(roi) - roi = [(st,ed) for st, ed in zip(roi[:len(roi)//2], roi[len(roi)//2:])] - roi_ = [(1,1)] * 4 - if axes is not None: - for a,r in zip(axes, roi): - roi_[a] = r - roi = roi_ - if scales is not None: - scales = safe_numpy(scales).tolist() - if axes is not None: - scales_ = [1]*X.ndim - for a,s in zip(axes, scales): - scales_[a] = s - scales = scales_ - elif sizes is not None: - sizes = [int(i) for i in safe_numpy(sizes)] - scales = [] - if axes is not None: - sizes_ = [1]*X.ndim - for a,s in zip(axes, sizes): - sizes_[a] = s - scales.append(s/X.shape[a]) - sizes = sizes_ - else: scales = [si/xs for xs, si in zip(X.shape, sizes)] - if keep_aspect_ratio_policy == "not_larger": - scale = min(scales) - sizes = _round(Tensor(list(X.shape[-2:]))*scale, 0.5, "round_up") - sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)] - elif keep_aspect_ratio_policy == "not_smaller": - scale = max(scales) - sizes = _round(Tensor(list(X.shape[-2:]))*scale, 0.5, "round_up") - sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)] - output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)] - output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)] - scales_lol = [os/xs for xs, os in zip(X.shape, output_shape)] - x_out = Tensor.arange(output_shape[-1]) - y_out = Tensor.arange(output_shape[-2]) - if mode == "nearest": - x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi) - x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1]) - y_out = _nearest_mode(y_out, nearest_mode, X.shape[-1]) - return _nearest_gather(X, x_out, y_out) - elif mode == "linear": - x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape_, scales, roi) - ret = [] - for y in safe_numpy(y_out): - for x in safe_numpy(x_out): - x_floor, y_floor = int(x), int(y) - y_shrink = (0, X.shape[2]) if X.shape[2] == 1 else (y_floor, y_floor+2) if y != y_floor else (y_floor, y_floor+1) - x_shrink = (x_floor, x_floor+2) if x != x_floor else (x_floor, x_floor+1) - shrink_args = ((0, X.shape[0]), (0, X.shape[1]), y_shrink, x_shrink) - corners = safe_numpy(X.shrink(shrink_args)) - x1, x2, y1, y2 = x_floor, x_floor+1, y_floor, y_floor+1 - if x == x_floor and y == y_floor: # TODO https://en.wikipedia.org/wiki/Bilinear_interpolation#Weighted_mean maybe do weighted mean? - ret.append(corners[0,0,0,0]) - elif x == x_floor: - ret.append((corners[0,0,0,0] * (y2 - y) + corners[0,0,1,0] * (y - y1)) / (y2 - y1)) - elif y == y_floor: - ret.append((corners[0,0,0,0] * (x2 - x) + corners[0,0,0,1] * (x - x1)) / (x2 - x1)) +def Resize( + X: Tensor, + roi=None, + scales=None, + sizes=None, + antialias=0, + axes=None, + coordinate_transformation_mode="half_pixel", + cubic_coeff_a=-0.75, + exclude_outside=0, + extrapolation_value=0.0, + keep_aspect_ratio_policy="stretch", + mode="nearest", + nearest_mode="round_prefer_floor", +): + def _nearest_gather(X: Tensor, x_out, y_out): + return X[:, :, y_out, :][:, :, :, x_out] + + def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len): + if nearest_mode == "round_prefer_floor": + ret = _round(x_resized, 0.5, "round_down") + elif nearest_mode == "round_prefer_ceil": + ret = _round(x_resized, 0.5, "round_up") + elif nearest_mode == "floor": + ret = x_resized.floor() + elif nearest_mode == "ceil": + ret = x_resized.ceil() + return ret.clip(0, x_len - 1) + + def _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi=None): + if coordinate_transformation_mode == "half_pixel": + x_out = (x_out + 0.5) / Tensor( + scales_lol[-1] + ) - 0.5 # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuacy. + y_out = (y_out + 0.5) / Tensor(scales_lol[-2]) - 0.5 + elif coordinate_transformation_mode == "align_corners": + x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1) + y_out = y_out * (X.shape[-2] - 1) / (output_shape[-2] - 1) + elif coordinate_transformation_mode == "asymmetric": + x_out = x_out / scales_lol[-1] + y_out = y_out / scales_lol[-2] + elif coordinate_transformation_mode == "half_pixel_symmetric": + x_out = ( + X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1]) + + (x_out + 0.5) / scales_lol[-1] + - 0.5 + ) + y_out = ( + X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2]) + + (y_out + 0.5) / scales_lol[-2] + - 0.5 + ) + elif coordinate_transformation_mode == "pytorch_half_pixel": + x_out = ( + (x_out + 0.5) / scales_lol[-1] - 0.5 + if output_shape[-1] > 1 + else Tensor([0]) + ) + y_out = ( + (y_out + 0.5) / scales_lol[-2] - 0.5 + if output_shape[-2] > 1 + else Tensor([0]) + ) + elif coordinate_transformation_mode == "tf_crop_and_resize": + x_out = ( + roi[-1][0] * (X.shape[-1] - 1) + + x_out + * ( + (roi[-1][1] - roi[-1][0]) + * (X.shape[-1] - 1) + / (output_shape[-1] - 1) + ) + if output_shape[-1] > 1 + else Tensor([0.5 * (roi[-1][0] + roi[-1][1]) * (X.shape[-1] - 1)]) + ) + y_out = ( + roi[-2][0] * (X.shape[-2] - 1) + + y_out + * ( + (roi[-2][1] - roi[-2][0]) + * (X.shape[-2] - 1) + / (output_shape[-2] - 1) + ) + if output_shape[-2] > 1 + else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)]) + ) + return x_out.clip(0, X.shape[-1] - 1), y_out.clip(0, X.shape[-2] - 1) + + if roi is not None: + roi = safe_numpy(roi) + roi = [(st, ed) for st, ed in zip(roi[: len(roi) // 2], roi[len(roi) // 2 :])] + roi_ = [(1, 1)] * 4 + if axes is not None: + for a, r in zip(axes, roi): + roi_[a] = r + roi = roi_ + if scales is not None: + scales = safe_numpy(scales).tolist() + if axes is not None: + scales_ = [1] * X.ndim + for a, s in zip(axes, scales): + scales_[a] = s + scales = scales_ + elif sizes is not None: + sizes = [int(i) for i in safe_numpy(sizes)] + scales = [] + if axes is not None: + sizes_ = [1] * X.ndim + for a, s in zip(axes, sizes): + sizes_[a] = s + scales.append(s / X.shape[a]) + sizes = sizes_ else: - ret.append((corners[0,0,0,0] * (x2 - x) * (y2 - y) + corners[0,0,0,1] * (x - x1) * (y2 - y) + corners[0,0,1,0] * (x2 - x) * (y - y1) + corners[0,0,1,1] * (x - x1) * (y - y1)) / ((x2 - x1) * (y2 - y1))) - return Tensor(ret).reshape(output_shape) - elif mode == "cubic": - raise Exception("cubic interpolation is not implemented") + scales = [si / xs for xs, si in zip(X.shape, sizes)] + if keep_aspect_ratio_policy == "not_larger": + scale = min(scales) + sizes = _round(Tensor(list(X.shape[-2:])) * scale, 0.5, "round_up") + sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)] + elif keep_aspect_ratio_policy == "not_smaller": + scale = max(scales) + sizes = _round(Tensor(list(X.shape[-2:])) * scale, 0.5, "round_up") + sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)] + output_shape = ( + sizes if sizes else [math.floor(x * s) for x, s in zip(X.shape, scales)] + ) + output_shape_ = sizes if sizes else [x * s for x, s in zip(X.shape, scales)] + scales_lol = [os / xs for xs, os in zip(X.shape, output_shape)] + x_out = Tensor.arange(output_shape[-1]) + y_out = Tensor.arange(output_shape[-2]) + if mode == "nearest": + x_out, y_out = _coordinate_transformation( + x_out, y_out, output_shape, scales_lol, roi + ) + x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1]) + y_out = _nearest_mode(y_out, nearest_mode, X.shape[-1]) + return _nearest_gather(X, x_out, y_out) + elif mode == "linear": + x_out, y_out = _coordinate_transformation( + x_out, y_out, output_shape_, scales, roi + ) + ret = [] + for y in safe_numpy(y_out): + for x in safe_numpy(x_out): + x_floor, y_floor = int(x), int(y) + y_shrink = ( + (0, X.shape[2]) + if X.shape[2] == 1 + else (y_floor, y_floor + 2) + if y != y_floor + else (y_floor, y_floor + 1) + ) + x_shrink = ( + (x_floor, x_floor + 2) if x != x_floor else (x_floor, x_floor + 1) + ) + shrink_args = ((0, X.shape[0]), (0, X.shape[1]), y_shrink, x_shrink) + corners = safe_numpy(X.shrink(shrink_args)) + x1, x2, y1, y2 = x_floor, x_floor + 1, y_floor, y_floor + 1 + if ( + x == x_floor and y == y_floor + ): # TODO https://en.wikipedia.org/wiki/Bilinear_interpolation#Weighted_mean maybe do weighted mean? + ret.append(corners[0, 0, 0, 0]) + elif x == x_floor: + ret.append( + ( + corners[0, 0, 0, 0] * (y2 - y) + + corners[0, 0, 1, 0] * (y - y1) + ) + / (y2 - y1) + ) + elif y == y_floor: + ret.append( + ( + corners[0, 0, 0, 0] * (x2 - x) + + corners[0, 0, 0, 1] * (x - x1) + ) + / (x2 - x1) + ) + else: + ret.append( + ( + corners[0, 0, 0, 0] * (x2 - x) * (y2 - y) + + corners[0, 0, 0, 1] * (x - x1) * (y2 - y) + + corners[0, 0, 1, 0] * (x2 - x) * (y - y1) + + corners[0, 0, 1, 1] * (x - x1) * (y - y1) + ) + / ((x2 - x1) * (y2 - y1)) + ) + return Tensor(ret).reshape(output_shape) + elif mode == "cubic": + raise Exception("cubic interpolation is not implemented") + def CenterCropPad(input: Tensor, shape: Tensor, axes=None): - if not axes: axes = list(range(input.ndim)) - shrink_arg = [(0,i) for i in input.shape] - pad_arg = [(0,0) for _ in range(input.ndim)] - shape = safe_numpy(shape).tolist() - for s, x in zip(shape, axes): - if s < input.shape[x]: shrink_arg[x] = (input.shape[x]//2 - s//2, input.shape[x]//2 + s//2) if s%2 == 0 else (input.shape[x]//2 - s//2 - 1, input.shape[x]//2 + s//2) - elif s > input.shape[x]: pad_arg[x] = ((s - input.shape[x])//2, (s - input.shape[x])//2) if (s - input.shape[x])% 2 == 0 else ((s - input.shape[x])//2, (s - input.shape[x])//2 + 1) - return input.shrink(tuple(shrink_arg)).pad(tuple(pad_arg)) + if not axes: + axes = list(range(input.ndim)) + shrink_arg = [(0, i) for i in input.shape] + pad_arg = [(0, 0) for _ in range(input.ndim)] + shape = safe_numpy(shape).tolist() + for s, x in zip(shape, axes): + if s < input.shape[x]: + shrink_arg[x] = ( + (input.shape[x] // 2 - s // 2, input.shape[x] // 2 + s // 2) + if s % 2 == 0 + else (input.shape[x] // 2 - s // 2 - 1, input.shape[x] // 2 + s // 2) + ) + elif s > input.shape[x]: + pad_arg[x] = ( + ((s - input.shape[x]) // 2, (s - input.shape[x]) // 2) + if (s - input.shape[x]) % 2 == 0 + else ((s - input.shape[x]) // 2, (s - input.shape[x]) // 2 + 1) + ) + return input.shrink(tuple(shrink_arg)).pad(tuple(pad_arg)) + def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1): - depth = int(safe_numpy(depth).item()) - indices, rank = (indices < 0).where(indices+depth, indices), len(indices.shape) - if axis < 0: axis += rank + 1 - ls, rs = indices.shape[0:axis], indices.shape[axis: rank] - cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs)) - return cond.where(values[1], values[0]).cast(values.dtype) + depth = int(safe_numpy(depth).item()) + indices, rank = (indices < 0).where(indices + depth, indices), len(indices.shape) + if axis < 0: + axis += rank + 1 + ls, rs = indices.shape[0:axis], indices.shape[axis:rank] + cond = indices[:, None] == Tensor.arange(depth).reshape( + (1,) * len(ls) + (depth,) + (1,) * len(rs) + ) + return cond.where(values[1], values[0]).cast(values.dtype) + def Erf(x: Tensor): - sign = x.sign() - x = x.abs() - t = 1.0 / (1.0 + 0.3275911 * x) - term1 = 0.254829592 * t - term2 = -0.284496736 * t ** 2 - term3 = 1.421413741 * t ** 3 - term4 = -1.453152027 * t ** 4 - term5 = 1.061405429 * t ** 5 - y = (term1 + term2 + term3 + term4 + term5) - return sign * (1.0 - y * Tensor.exp(-x * x)) + sign = x.sign() + x = x.abs() + t = 1.0 / (1.0 + 0.3275911 * x) + term1 = 0.254829592 * t + term2 = -0.284496736 * t**2 + term3 = 1.421413741 * t**3 + term4 = -1.453152027 * t**4 + term5 = 1.061405429 * t**5 + y = term1 + term2 + term3 + term4 + term5 + return sign * (1.0 - y * Tensor.exp(-x * x)) + def Compress(inp: Tensor, condition: Tensor, axis=None): - if axis is None: - inp = inp.flatten() - axis = 0 + if axis is None: + inp = inp.flatten() + axis = 0 - axis = axis + inp.ndim if axis < 0 else axis + axis = axis + inp.ndim if axis < 0 else axis + + con_np = safe_numpy(condition) + con = Tensor(np.arange(condition.shape[0])[con_np]) # no boolean indexing in Tensor + return inp.__getitem__( + tuple([slice(None) if i != axis else con for i in range(inp.ndim)]) + ) - con_np = safe_numpy(condition) - con = Tensor(np.arange(condition.shape[0])[con_np]) # no boolean indexing in Tensor - return inp.__getitem__(tuple([slice(None) if i != axis else con for i in range(inp.ndim)])) type_map = {TensorProto.DOUBLE: dtypes.double, TensorProto.FLOAT: dtypes.float32} -def EyeLike(x: Tensor, dtype=None, k=0): - if dtype is None: dtype = x.dtype - else: dtype = type_map[dtype] - shape = x.shape - dim = min(x.shape) - if shape[0] == shape[1]: return Tensor.eye(dim=dim, dtype=dtype) - else: - diff = (shape[0]-dim, shape[1]-dim) - padarg = tuple([(d, d) if d == 0 else (k, d-k) for d in diff]) - return Tensor.eye(dim=dim, dtype=dtype).pad(padarg) -def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) + +def EyeLike(x: Tensor, dtype=None, k=0): + if dtype is None: + dtype = x.dtype + else: + dtype = type_map[dtype] + shape = x.shape + dim = min(x.shape) + if shape[0] == shape[1]: + return Tensor.eye(dim=dim, dtype=dtype) + else: + diff = (shape[0] - dim, shape[1] - dim) + padarg = tuple([(d, d) if d == 0 else (k, d - k) for d in diff]) + return Tensor.eye(dim=dim, dtype=dtype).pad(padarg) + + +def Upsample(X, scales, mode): + return Resize(X=X, scales=scales, mode=mode) + # Needs work def IsInf(x: Tensor, detect_negative=1, detect_positive=1): - ret = (x == float("inf"))*detect_positive + (x == float("-inf"))*detect_negative + Tensor.zeros(*x.shape) - return ret.cast(dtypes.bool) + ret = ( + (x == float("inf")) * detect_positive + + (x == float("-inf")) * detect_negative + + Tensor.zeros(*x.shape) + ) + return ret.cast(dtypes.bool) + + +def DequantizeLinear( + x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int] = 0, axis=1 +): + axis = axis + x.ndim if axis < 0 else axis + x = x.cast(dtypes.float) + if x_zero_point.__class__ is Tensor: + x_zero_point.cast(dtypes.float) + x_sc = x_scale.reshape( + *[1] * axis, *x_scale.shape, *[1] * (x.ndim - axis - x_scale.ndim) + ) + x_zer = ( + x_zero_point.reshape( + *[1] * axis, *x_scale.shape, *[1] * (x.ndim - axis - x_scale.ndim) + ) + if isinstance(x_zero_point, Tensor) + else x_zero_point + ) + return ((x - x_zer) * x_sc).cast(x_scale.dtype) -def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int] = 0, axis=1): - axis = axis + x.ndim if axis < 0 else axis - x = x.cast(dtypes.float) - if x_zero_point.__class__ is Tensor: x_zero_point.cast(dtypes.float) - x_sc = x_scale.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim)) - x_zer = x_zero_point.reshape(*[1]*axis, *x_scale.shape, *[1]*(x.ndim - axis - x_scale.ndim)) if isinstance(x_zero_point, Tensor) else x_zero_point - return ((x - x_zer) * x_sc).cast(x_scale.dtype) # Needs work def IsNaN(x: Tensor): - return (x < float("-inf")).cast(dtypes.bool) + return (x < float("-inf")).cast(dtypes.bool) + # copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py # without importing PIL we'll have to manually decode a bunch of image formats like PNG, JPEG, WebP, etc def ImageDecoder(encoded_stream: Tensor, pixel_format="RGB"): - try: - import PIL.Image - except ImportError as e: - raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e - img = PIL.Image.open(io.BytesIO(safe_numpy(encoded_stream).tobytes())) - if pixel_format == "BGR": - return Tensor(np.array(img))[:, :, ::-1] - elif pixel_format == "RGB": - return Tensor(np.array(img)) - elif pixel_format == "Grayscale": - img = img.convert("L") - decoded = Tensor(np.array(img)) - return decoded.unsqueeze(-1) # (H, W) to (H, W, 1) - else: - raise ValueError(f"pixel_format={pixel_format!r} is not supported.") + try: + import PIL.Image + except ImportError as e: + raise ImportError( + "Pillow must be installed to use the reference implementation of the ImageDecoder operator" + ) from e + img = PIL.Image.open(io.BytesIO(safe_numpy(encoded_stream).tobytes())) + if pixel_format == "BGR": + return Tensor(np.array(img))[:, :, ::-1] + elif pixel_format == "RGB": + return Tensor(np.array(img)) + elif pixel_format == "Grayscale": + img = img.convert("L") + decoded = Tensor(np.array(img)) + return decoded.unsqueeze(-1) # (H, W) to (H, W, 1) + else: + raise ValueError(f"pixel_format={pixel_format!r} is not supported.") + def AffineGrid(theta: Tensor, size: Tensor, align_corners=0): - _, _, *data_sz = safe_numpy(size).tolist() - size_zeros, original_grid = Tensor.zeros(data_sz), Tensor.ones(data_sz) - stackable = [original_grid] - for dim, dim_sz in enumerate(data_sz): - a = Tensor.arange(-1, 1.0001, 2/(dim_sz-1)) if align_corners == 1 else Tensor.arange(-1+1/dim_sz, 1, 2/dim_sz) - if dim == 0: stackable = [a.reshape(dim_sz, *[1]*(len(data_sz)-1)) + size_zeros, *stackable] - elif dim == 1: stackable = [a.reshape(1, dim_sz, *[1]*(len(data_sz)-2)) + size_zeros, *stackable] - else: stackable = [a.reshape(1, dim_sz) + size_zeros, *stackable] - original_grid = Tensor.stack(stackable, dim=len(data_sz)) - if original_grid.ndim == 3: - N, dim_2d, dim_homo = theta.shape - assert dim_2d == 2 and dim_homo == 3 - H, W, dim_homo = original_grid.shape - assert dim_homo == 3 - original_grid = original_grid.reshape(H*W, dim_homo).transpose() - return theta.matmul(original_grid).permute(0,2,1).reshape(N, H, W, dim_2d) - else: - assert original_grid.ndim == 4 - N, dim_3d, dim_homo = theta.shape - assert dim_3d == 3 and dim_homo == 4 - D, H, W, dim_homo = original_grid.shape - assert dim_homo == 4 - original_grid = original_grid.reshape(D*H*W, dim_homo).transpose() - return theta.matmul(original_grid).permute(0,2,1).reshape(N, D, H, W, dim_3d) + _, _, *data_sz = safe_numpy(size).tolist() + size_zeros, original_grid = Tensor.zeros(data_sz), Tensor.ones(data_sz) + stackable = [original_grid] + for dim, dim_sz in enumerate(data_sz): + a = ( + Tensor.arange(-1, 1.0001, 2 / (dim_sz - 1)) + if align_corners == 1 + else Tensor.arange(-1 + 1 / dim_sz, 1, 2 / dim_sz) + ) + if dim == 0: + stackable = [ + a.reshape(dim_sz, *[1] * (len(data_sz) - 1)) + size_zeros, + *stackable, + ] + elif dim == 1: + stackable = [ + a.reshape(1, dim_sz, *[1] * (len(data_sz) - 2)) + size_zeros, + *stackable, + ] + else: + stackable = [a.reshape(1, dim_sz) + size_zeros, *stackable] + original_grid = Tensor.stack(stackable, dim=len(data_sz)) + if original_grid.ndim == 3: + N, dim_2d, dim_homo = theta.shape + assert dim_2d == 2 and dim_homo == 3 + H, W, dim_homo = original_grid.shape + assert dim_homo == 3 + original_grid = original_grid.reshape(H * W, dim_homo).transpose() + return theta.matmul(original_grid).permute(0, 2, 1).reshape(N, H, W, dim_2d) + else: + assert original_grid.ndim == 4 + N, dim_3d, dim_homo = theta.shape + assert dim_3d == 3 and dim_homo == 4 + D, H, W, dim_homo = original_grid.shape + assert dim_homo == 4 + original_grid = original_grid.reshape(D * H * W, dim_homo).transpose() + return theta.matmul(original_grid).permute(0, 2, 1).reshape(N, D, H, W, dim_3d) + # **************** com.microsoft Ops **************** -def SkipLayerNormalization(input:Tensor, skip:Tensor, gamma, beta:Optional[Tensor]=None, bias:Optional[Tensor]=None, epsilon=None): - if epsilon is None: epsilon=1e-12 - x = input + skip + bias - return x.layernorm(eps=epsilon) * gamma + beta, None, None, x -def FastGelu(x:Tensor, bias:Optional[Tensor]=None): - x = x + bias - return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x ** 3).tanh()) +def SkipLayerNormalization( + input: Tensor, + skip: Tensor, + gamma, + beta: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + epsilon=None, +): + if epsilon is None: + epsilon = 1e-12 + x = input + skip + bias + return x.layernorm(eps=epsilon) * gamma + beta, None, None, x -def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None, word_embedding:Tensor=None, position_embedding:Tensor=None, segment_embedding:Optional[Tensor]=None, gamma=None, beta=None, mask:Optional[Tensor]=None, position_ids:Optional[Tensor]=None, epsilon=None, mask_index_type=None): - # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization - assert (segment_ids is None) is (segment_embedding is None) - assert (mask is None) is (mask_index_type is None) - assert mask is None, "functionality not supported yet" # TODO - input_shape = input_ids.shape - bsz, seq_length = input_shape[0], input_shape[1] - compute_seg_emb = (segment_embedding is not None and segment_ids is not None) - vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None) - def embedding(x:Tensor, vocab_size, weight:Tensor)->Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor - vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).reshape(1, 1, vocab_size).expand(*x.shape, vocab_size) - return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight +def FastGelu(x: Tensor, bias: Optional[Tensor] = None): + x = x + bias + return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x**3).tanh()) - # bert embedding layer - if epsilon is None: epsilon = 1e-12 - if position_ids is None: position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape) - wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding) - pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding) - seg_embedding_res = embedding(segment_ids, type_vocab_size, segment_embedding) if compute_seg_emb else None - embedding_sum = wrd_embedding_res + pos_embedding_res + seg_embedding_res - out = embedding_sum.layernorm(eps=epsilon) * gamma + beta - return out, None, embedding_sum +def EmbedLayerNormalization( + input_ids: Tensor, + segment_ids: Optional[Tensor] = None, + word_embedding: Tensor = None, + position_embedding: Tensor = None, + segment_embedding: Optional[Tensor] = None, + gamma=None, + beta=None, + mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + epsilon=None, + mask_index_type=None, +): + # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization + assert (segment_ids is None) is (segment_embedding is None) + assert (mask is None) is (mask_index_type is None) + assert mask is None, "functionality not supported yet" # TODO + input_shape = input_ids.shape + bsz, seq_length = input_shape[0], input_shape[1] + compute_seg_emb = segment_embedding is not None and segment_ids is not None + vocab_size, max_position_embeddings, type_vocab_size = ( + word_embedding.shape[0], + position_embedding.shape[0], + (segment_embedding.shape[0] if compute_seg_emb else None), + ) -def Attention(input:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional[Tensor]=None, past:Optional[Tensor]=None, relative_position_bias:Optional[Tensor]=None, past_sequence_length:Optional[Tensor]=None, do_rotary=None, mask_filter_value=None, num_heads=None, past_present_share_buffer=None, qkv_hidden_sizes=None, scale=None, unidirectional=None): - # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention - assert num_heads is not None # required - assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None) - assert relative_position_bias==do_rotary==past_sequence_length==mask_filter_value==past_present_share_buffer==scale==None, "functionality not supported yet" # TODO strange params - hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,) + def embedding( + x: Tensor, vocab_size, weight: Tensor + ) -> Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor + vocab_counter = ( + Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False) + .reshape(1, 1, vocab_size) + .expand(*x.shape, vocab_size) + ) + return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight - if unidirectional: # gpt-style - assert hidden_size == v_hidden_size - xqkv = input.linear(weights, bias) - xq, xk, xv = [xqkv.slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)] - else: # bert-style - wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:] - bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size]) if bias is not None else None - xq, xk, xv = [input.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))] - xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)] + # bert embedding layer + if epsilon is None: + epsilon = 1e-12 + if position_ids is None: + position_ids = ( + Tensor.arange(seq_length, requires_grad=False) + .unsqueeze(0) + .expand(*input_shape) + ) + wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding) + pos_embedding_res = embedding( + position_ids, max_position_embeddings, position_embedding + ) + seg_embedding_res = ( + embedding(segment_ids, type_vocab_size, segment_embedding) + if compute_seg_emb + else None + ) - if past is not None: - xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2) - present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0)) + embedding_sum = wrd_embedding_res + pos_embedding_res + seg_embedding_res + out = embedding_sum.layernorm(eps=epsilon) * gamma + beta + return out, None, embedding_sum - def attn(query, key, value, attn_mask): - query_length, key_length = query.shape[-2], key.shape[-2] - cdim = max(query_length, key_length) + 1 - attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1]) - # This is where Tensor.scaled_dot_product_attention differs: - causal_mask = Tensor.ones((cdim, cdim), requires_grad=False).cast(dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length].cast(dtypes.bool) - return (Tensor.where(causal_mask, attn_weights, -float("inf")) + attn_mask).softmax(-1) @ value - bsz, _, seq_len, _ = xq.shape - out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1) - return out, present +def Attention( + input: Tensor, + weights, + bias: Optional[Tensor] = None, + mask_index: Optional[Tensor] = None, + past: Optional[Tensor] = None, + relative_position_bias: Optional[Tensor] = None, + past_sequence_length: Optional[Tensor] = None, + do_rotary=None, + mask_filter_value=None, + num_heads=None, + past_present_share_buffer=None, + qkv_hidden_sizes=None, + scale=None, + unidirectional=None, +): + # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention + assert num_heads is not None # required + assert (qkv_hidden_sizes is None and past is not None) or ( + qkv_hidden_sizes is not None + ) + assert ( + relative_position_bias + == do_rotary + == past_sequence_length + == mask_filter_value + == past_present_share_buffer + == scale + == None + ), "functionality not supported yet" # TODO strange params + hidden_size, v_hidden_size = ( + qkv_hidden_sizes[1:] + if qkv_hidden_sizes is not None + else 2 * (weights.shape[1] // 3,) + ) + + if unidirectional: # gpt-style + assert hidden_size == v_hidden_size + xqkv = input.linear(weights, bias) + xq, xk, xv = [ + xqkv.slice([None, None, (i * hidden_size, (i + 1) * hidden_size)]) + for i in range(3) + ] + else: # bert-style + wq, wk, wv = ( + weights[:, :hidden_size], + weights[:, hidden_size : hidden_size + v_hidden_size], + weights[:, hidden_size + v_hidden_size :], + ) + bq, bk, bv = ( + ( + bias[:hidden_size], + bias[hidden_size : hidden_size + v_hidden_size], + bias[hidden_size + v_hidden_size], + ) + if bias is not None + else None + ) + xq, xk, xv = [input.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))] + xq, xk, xv = [ + x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) + for x in (xq, xk, xv) + ] + + if past is not None: + xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2) + present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0)) + + def attn(query, key, value, attn_mask): + query_length, key_length = query.shape[-2], key.shape[-2] + cdim = max(query_length, key_length) + 1 + attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1]) + # This is where Tensor.scaled_dot_product_attention differs: + causal_mask = ( + Tensor.ones((cdim, cdim), requires_grad=False) + .cast(dtypes.bool) + .tril(0)[key_length - query_length : key_length, :key_length] + .cast(dtypes.bool) + ) + return ( + Tensor.where(causal_mask, attn_weights, -float("inf")) + attn_mask + ).softmax(-1) @ value + + bsz, _, seq_len, _ = xq.shape + out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1) + return out, present + # **************** ai.onnx.preview.training Ops **************** + # TODO not entirely sure these optimizers are correct def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0): - groups = len(inputs) // 3 - grouped_inputs = [inputs[i::groups] for i in range(groups)] - T, R = safe_numpy(T)[0], safe_numpy(R)[0] - r = R / (1 + T * decay_factor) - ret = [] - for input in grouped_inputs: - X, G, H = input - X.grad = norm_coefficient * X + G - X.grad.requires_grad, H.requires_grad = False, False # TODO manually turning off requires_grad, see TODO under (domain == "ai.onnx.preview.training") in onnx.py - H.assign(H.detach() + X.grad * X.grad).realize() - H_adaptive = H.sqrt() + epsilon - X.assign(X.detach() - r * X.grad / H_adaptive) - ret.extend([X, H]) - ret = ret[::2] + ret[1::2] - return tuple(ret) + groups = len(inputs) // 3 + grouped_inputs = [inputs[i::groups] for i in range(groups)] + T, R = safe_numpy(T)[0], safe_numpy(R)[0] + r = R / (1 + T * decay_factor) + ret = [] + for input in grouped_inputs: + X, G, H = input + X.grad = norm_coefficient * X + G + X.grad.requires_grad, H.requires_grad = ( + False, + False, + ) # TODO manually turning off requires_grad, see TODO under (domain == "ai.onnx.preview.training") in onnx.py + H.assign(H.detach() + X.grad * X.grad).realize() + H_adaptive = H.sqrt() + epsilon + X.assign(X.detach() - r * X.grad / H_adaptive) + ret.extend([X, H]) + ret = ret[::2] + ret[1::2] + return tuple(ret) + def Momentum(R, T, *inputs, alpha, beta, mode, norm_coefficient): - groups = len(inputs) // 3 - grouped_inputs = [inputs[i::groups] for i in range(groups)] - T, R = safe_numpy(T)[0], safe_numpy(R)[0] - beta_adjusted = beta if T > 0 else 1 - ret = [] - for input in grouped_inputs: - X, G, V = input - X.grad = (norm_coefficient * X + G).realize() - X.grad.requires_grad, V.requires_grad = False, False - V.assign(alpha * V + beta_adjusted * X.grad).realize() - if mode == "standard": X.assign(X.detach() - R * V).realize() - elif mode == "nesterov": X.assign(X.detach() - R * (X.grad + alpha + V)).realize() - ret.extend([X, V]) - ret = ret[::2] + ret[1::2] - return tuple(ret) + groups = len(inputs) // 3 + grouped_inputs = [inputs[i::groups] for i in range(groups)] + T, R = safe_numpy(T)[0], safe_numpy(R)[0] + beta_adjusted = beta if T > 0 else 1 + ret = [] + for input in grouped_inputs: + X, G, V = input + X.grad = (norm_coefficient * X + G).realize() + X.grad.requires_grad, V.requires_grad = False, False + V.assign(alpha * V + beta_adjusted * X.grad).realize() + if mode == "standard": + X.assign(X.detach() - R * V).realize() + elif mode == "nesterov": + X.assign(X.detach() - R * (X.grad + alpha + V)).realize() + ret.extend([X, V]) + ret = ret[::2] + ret[1::2] + return tuple(ret) + # copied from tinygrad/nn/optim.py: LAMB with some edits -def Adam(R, T, *inputs, alpha=0.9, beta=0.999, epsilon=0.0, norm_coefficient=0.0, norm_coefficient_post=0.0): - groups = len(inputs) // 4 - grouped_inputs = [inputs[i::groups] for i in range(groups)] - T, R = safe_numpy(T)[0], safe_numpy(R)[0] - ret = [] - for input in grouped_inputs: - X, G, V, H = input - X.grad = (norm_coefficient * X + G).realize() - V.requires_grad, H.requires_grad, X.grad.requires_grad = False, False, False - V.assign(alpha * V + (1.0 - alpha) * X.grad).realize() - H.assign(beta * H + (1.0 - beta) * (X.grad * X.grad)).realize() - up = (V / (1.0 - alpha**T)) / ((H / (1.0 - beta**T)).sqrt() + epsilon) if T > 0 else V / (H.sqrt() + epsilon) - X.assign(X.detach() - R * up).realize() - X = (1 - norm_coefficient_post) * X - ret.extend([X, V, H]) - ret = ret[::3] + ret[1::3] + ret[2::3] - return tuple(ret) +def Adam( + R, + T, + *inputs, + alpha=0.9, + beta=0.999, + epsilon=0.0, + norm_coefficient=0.0, + norm_coefficient_post=0.0, +): + groups = len(inputs) // 4 + grouped_inputs = [inputs[i::groups] for i in range(groups)] + T, R = safe_numpy(T)[0], safe_numpy(R)[0] + ret = [] + for input in grouped_inputs: + X, G, V, H = input + X.grad = (norm_coefficient * X + G).realize() + V.requires_grad, H.requires_grad, X.grad.requires_grad = False, False, False + V.assign(alpha * V + (1.0 - alpha) * X.grad).realize() + H.assign(beta * H + (1.0 - beta) * (X.grad * X.grad)).realize() + up = ( + (V / (1.0 - alpha**T)) / ((H / (1.0 - beta**T)).sqrt() + epsilon) + if T > 0 + else V / (H.sqrt() + epsilon) + ) + X.assign(X.detach() - R * up).realize() + X = (1 - norm_coefficient_post) * X + ret.extend([X, V, H]) + ret = ret[::3] + ret[1::3] + ret[2::3] + return tuple(ret) diff --git a/extra/optimization/extract_policynet.py b/extra/optimization/extract_policynet.py index 149aacf2a..79b91e220 100644 --- a/extra/optimization/extract_policynet.py +++ b/extra/optimization/extract_policynet.py @@ -4,111 +4,148 @@ from copy import deepcopy from tinygrad.nn import Linear from tinygrad.tensor import Tensor from tinygrad.nn.optim import Adam -from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict +from tinygrad.nn.state import ( + get_parameters, + get_state_dict, + safe_save, + safe_load, + load_state_dict, +) from tinygrad.features.search import actions -from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, assert_same_lin +from extra.optimization.helpers import ( + load_worlds, + ast_str_to_lin, + lin_to_feats, + assert_same_lin, +) from tinygrad.codegen.linearizer import Linearizer from tinygrad.helpers import getenv # stuff needed to unpack a kernel -from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer +from tinygrad.ops import ( + LazyOp, + TernaryOps, + BinaryOps, + UnaryOps, + ReduceOps, + BufferOps, + MemBuffer, + ConstBuffer, +) from tinygrad.helpers import dtypes from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.shape.symbolic import Variable -inf, nan = float('inf'), float('nan') + +inf, nan = float("inf"), float("nan") from tinygrad.codegen.kernel import Opt, OptOps INNER = 256 + + class PolicyNet: - def __init__(self): - self.l1 = Linear(1021,INNER) - self.l2 = Linear(INNER,INNER) - self.l3 = Linear(INNER,1+len(actions)) - def __call__(self, x): - x = self.l1(x).relu() - x = self.l2(x).relu().dropout(0.9) - return self.l3(x).log_softmax() + def __init__(self): + self.l1 = Linear(1021, INNER) + self.l2 = Linear(INNER, INNER) + self.l3 = Linear(INNER, 1 + len(actions)) + + def __call__(self, x): + x = self.l1(x).relu() + x = self.l2(x).relu().dropout(0.9) + return self.l3(x).log_softmax() + def dataset_from_cache(fn): - conn = sqlite3.connect(fn) - cur = conn.cursor() - cur.execute("SELECT * FROM beam_search") - X,A = [], [] - for f in tqdm(cur.fetchall()): - Xs,As = [], [] - try: - lin = Linearizer(eval(f[0])) - opts = pickle.loads(f[-1]) - for o in opts: - Xs.append(lin_to_feats(lin, use_sts=True)) - As.append(actions.index(o)) - lin.apply_opt(o) - Xs.append(lin_to_feats(lin, use_sts=True)) - As.append(0) - except Exception: - pass - X += Xs - A += As - return X,A + conn = sqlite3.connect(fn) + cur = conn.cursor() + cur.execute("SELECT * FROM beam_search") + X, A = [], [] + for f in tqdm(cur.fetchall()): + Xs, As = [], [] + try: + lin = Linearizer(eval(f[0])) + opts = pickle.loads(f[-1]) + for o in opts: + Xs.append(lin_to_feats(lin, use_sts=True)) + As.append(actions.index(o)) + lin.apply_opt(o) + Xs.append(lin_to_feats(lin, use_sts=True)) + As.append(0) + except Exception: + pass + X += Xs + A += As + return X, A + if __name__ == "__main__": - if getenv("REGEN"): - X,V = dataset_from_cache(sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache") - safe_save({"X": Tensor(X), "V": Tensor(V)}, "/tmp/dataset_policy") - else: - ld = safe_load("/tmp/dataset_policy") - X,V = ld['X'].numpy(), ld['V'].numpy() + if getenv("REGEN"): + X, V = dataset_from_cache( + sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache" + ) + safe_save({"X": Tensor(X), "V": Tensor(V)}, "/tmp/dataset_policy") + else: + ld = safe_load("/tmp/dataset_policy") + X, V = ld["X"].numpy(), ld["V"].numpy() - print(X.shape, V.shape) - order = list(range(X.shape[0])) - random.shuffle(order) - X, V = X[order], V[order] + print(X.shape, V.shape) + order = list(range(X.shape[0])) + random.shuffle(order) + X, V = X[order], V[order] - ratio = -256 - X_test, V_test = Tensor(X[ratio:]), Tensor(V[ratio:]) - X,V = X[:ratio], V[:ratio] - print(X.shape, V.shape) + ratio = -256 + X_test, V_test = Tensor(X[ratio:]), Tensor(V[ratio:]) + X, V = X[:ratio], V[:ratio] + print(X.shape, V.shape) - net = PolicyNet() - #if os.path.isfile("/tmp/policynet.safetensors"): load_state_dict(net, safe_load("/tmp/policynet.safetensors")) - optim = Adam(get_parameters(net)) + net = PolicyNet() + # if os.path.isfile("/tmp/policynet.safetensors"): load_state_dict(net, safe_load("/tmp/policynet.safetensors")) + optim = Adam(get_parameters(net)) - def get_minibatch(X,Y,bs): - xs, ys = [], [] - for _ in range(bs): - sel = random.randint(0, len(X)-1) - xs.append(X[sel]) - ys.append(Y[sel]) - return Tensor(xs), Tensor(ys) + def get_minibatch(X, Y, bs): + xs, ys = [], [] + for _ in range(bs): + sel = random.randint(0, len(X) - 1) + xs.append(X[sel]) + ys.append(Y[sel]) + return Tensor(xs), Tensor(ys) - Tensor.no_grad, Tensor.training = False, True - losses = [] - test_losses = [] - test_accuracy = 0 - test_loss = float('inf') - for i in (t:=trange(500)): - x,y = get_minibatch(X,V,bs=256) - out = net(x) - loss = out.sparse_categorical_crossentropy(y) - optim.zero_grad() - loss.backward() - optim.step() - cat = out.argmax(axis=-1) - accuracy = (cat == y).mean() - t.set_description(f"loss {loss.numpy():7.2f} accuracy {accuracy.numpy()*100:7.2f}%, test loss {test_loss:7.2f} test accuracy {test_accuracy*100:7.2f}%") + Tensor.no_grad, Tensor.training = False, True + losses = [] + test_losses = [] + test_accuracy = 0 + test_loss = float("inf") + for i in (t := trange(500)): + x, y = get_minibatch(X, V, bs=256) + out = net(x) + loss = out.sparse_categorical_crossentropy(y) + optim.zero_grad() + loss.backward() + optim.step() + cat = out.argmax(axis=-1) + accuracy = (cat == y).mean() + t.set_description( + f"loss {loss.numpy():7.2f} accuracy {accuracy.numpy()*100:7.2f}%, test loss {test_loss:7.2f} test accuracy {test_accuracy*100:7.2f}%" + ) - losses.append(loss.numpy().item()) - test_losses.append(test_loss) - if i % 10: - out = net(X_test) - test_loss = out.sparse_categorical_crossentropy(V_test).square().mean().numpy().item() - cat = out.argmax(axis=-1) - test_accuracy = (cat == y).mean().numpy() + losses.append(loss.numpy().item()) + test_losses.append(test_loss) + if i % 10: + out = net(X_test) + test_loss = ( + out.sparse_categorical_crossentropy(V_test) + .square() + .mean() + .numpy() + .item() + ) + cat = out.argmax(axis=-1) + test_accuracy = (cat == y).mean().numpy() - safe_save(get_state_dict(net), "/tmp/policynet.safetensors") + safe_save(get_state_dict(net), "/tmp/policynet.safetensors") - import matplotlib.pyplot as plt - plt.plot(losses[10:]) - plt.plot(test_losses[10:]) - plt.show() + import matplotlib.pyplot as plt + + plt.plot(losses[10:]) + plt.plot(test_losses[10:]) + plt.show() diff --git a/extra/optimization/extract_sa_pairs.py b/extra/optimization/extract_sa_pairs.py index 5a806e0a0..01ac09c0c 100644 --- a/extra/optimization/extract_sa_pairs.py +++ b/extra/optimization/extract_sa_pairs.py @@ -4,12 +4,22 @@ from tqdm import tqdm, trange import numpy as np # stuff needed to unpack a kernel -from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer +from tinygrad.ops import ( + LazyOp, + TernaryOps, + BinaryOps, + UnaryOps, + ReduceOps, + BufferOps, + MemBuffer, + ConstBuffer, +) from tinygrad.helpers import dtypes from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.shape.symbolic import Variable -inf, nan = float('inf'), float('nan') + +inf, nan = float("inf"), float("nan") from tinygrad.codegen.kernel import Opt, OptOps # more stuff @@ -18,112 +28,132 @@ from tinygrad.features.search import actions from extra.optimization.helpers import lin_to_feats from extra.optimization.pretrain_valuenet import ValueNet from tinygrad.nn.optim import Adam -from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict +from tinygrad.nn.state import ( + get_parameters, + get_state_dict, + safe_save, + safe_load, + load_state_dict, +) import random from tinygrad.tensor import Tensor from tinygrad.helpers import getenv + def dataset_from_cache(fn): - conn = sqlite3.connect(fn) - cur = conn.cursor() - cur.execute("SELECT * FROM time_linearizer") - grouped = defaultdict(dict) - for f in tqdm(cur.fetchall()): grouped[f[0]][f[1:-1]] = pickle.loads(f[-1]) + conn = sqlite3.connect(fn) + cur = conn.cursor() + cur.execute("SELECT * FROM time_linearizer") + grouped = defaultdict(dict) + for f in tqdm(cur.fetchall()): + grouped[f[0]][f[1:-1]] = pickle.loads(f[-1]) - opts_to_outcome = {} + opts_to_outcome = {} - for ast,sk in grouped.items(): - cnts = defaultdict(int) - for sks,tm in sk.items(): - if sks[1] != 1: continue - opts = eval(sks[0]) - cnts[(len(opts), sks[1])] += 1 - opts_to_outcome[(ast, tuple(opts))] = tm - #print(cnts) + for ast, sk in grouped.items(): + cnts = defaultdict(int) + for sks, tm in sk.items(): + if sks[1] != 1: + continue + opts = eval(sks[0]) + cnts[(len(opts), sks[1])] += 1 + opts_to_outcome[(ast, tuple(opts))] = tm + # print(cnts) - S,A,V = [], [], [] - for ast,k in tqdm(opts_to_outcome): - if len(k) == 0: continue - old_tm = min(opts_to_outcome[(ast,k[:-1])]) - new_tm = min(opts_to_outcome[(ast,k)]) - if math.isinf(old_tm) or math.isinf(new_tm) or old_tm < 1e-9 or new_tm < 1e-9: continue - try: - lin = Linearizer(eval(ast)) - except Exception: - continue - for opt in k[:-1]: lin.apply_opt(opt) - act = k[-1] - log_ratio = math.log(old_tm/new_tm) - #print(f"ratio: {old_tm/new_tm:6.2f}x (log {log_ratio:5.2f}) from {str(act):50s} on {lin.colored_shape()}") - S.append(lin_to_feats(lin, use_sts=True)) - A.append(actions.index(act)) - V.append([log_ratio]) # NOTE: i have written the bug many times with this having the wrong dim + S, A, V = [], [], [] + for ast, k in tqdm(opts_to_outcome): + if len(k) == 0: + continue + old_tm = min(opts_to_outcome[(ast, k[:-1])]) + new_tm = min(opts_to_outcome[(ast, k)]) + if math.isinf(old_tm) or math.isinf(new_tm) or old_tm < 1e-9 or new_tm < 1e-9: + continue + try: + lin = Linearizer(eval(ast)) + except Exception: + continue + for opt in k[:-1]: + lin.apply_opt(opt) + act = k[-1] + log_ratio = math.log(old_tm / new_tm) + # print(f"ratio: {old_tm/new_tm:6.2f}x (log {log_ratio:5.2f}) from {str(act):50s} on {lin.colored_shape()}") + S.append(lin_to_feats(lin, use_sts=True)) + A.append(actions.index(act)) + V.append( + [log_ratio] + ) # NOTE: i have written the bug many times with this having the wrong dim - S, A, V = np.array(S), np.array(A), np.array(V, dtype=np.float32) - X = np.zeros((S.shape[0], S.shape[1]+len(actions)), dtype=np.float32) - X[:, :S.shape[1]] = S - X[range(S.shape[0]), S.shape[1]+A] = 1.0 - return X, V + S, A, V = np.array(S), np.array(A), np.array(V, dtype=np.float32) + X = np.zeros((S.shape[0], S.shape[1] + len(actions)), dtype=np.float32) + X[:, : S.shape[1]] = S + X[range(S.shape[0]), S.shape[1] + A] = 1.0 + return X, V + + +def log_likelihood(x: Tensor, mu: Tensor, log_sigma: Tensor): + # print(x.shape, mu.shape, log_sigma.shape) + # return (x-mu).abs() * (-log_sigma).exp() + log_sigma + return (x - mu).square() * (-2 * log_sigma).exp() / 2 + log_sigma -def log_likelihood(x:Tensor, mu:Tensor, log_sigma:Tensor): - #print(x.shape, mu.shape, log_sigma.shape) - #return (x-mu).abs() * (-log_sigma).exp() + log_sigma - return (x-mu).square() * (-2*log_sigma).exp() / 2 + log_sigma if __name__ == "__main__": - if getenv("REGEN"): - X,V = dataset_from_cache(sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache") - safe_save({"X": Tensor(X), "V": Tensor(V)}, "/tmp/dataset") - else: - ld = safe_load("/tmp/dataset") - X,V = ld['X'].numpy(), ld['V'].numpy() + if getenv("REGEN"): + X, V = dataset_from_cache( + sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache" + ) + safe_save({"X": Tensor(X), "V": Tensor(V)}, "/tmp/dataset") + else: + ld = safe_load("/tmp/dataset") + X, V = ld["X"].numpy(), ld["V"].numpy() - print(X.shape, V.shape) - order = list(range(X.shape[0])) - random.shuffle(order) - X, V = X[order], V[order] + print(X.shape, V.shape) + order = list(range(X.shape[0])) + random.shuffle(order) + X, V = X[order], V[order] - ratio = -512 - X_test, V_test = Tensor(X[ratio:]), Tensor(V[ratio:]) - X,V = X[:ratio], V[:ratio] - print(X.shape, V.shape) + ratio = -512 + X_test, V_test = Tensor(X[ratio:]), Tensor(V[ratio:]) + X, V = X[:ratio], V[:ratio] + print(X.shape, V.shape) - #print(X[0], V[0]) - #print(X[-1], V[-1]) - print(X.shape) + # print(X[0], V[0]) + # print(X[-1], V[-1]) + print(X.shape) - net = ValueNet(X.shape[1], 2) - optim = Adam(get_parameters(net)) + net = ValueNet(X.shape[1], 2) + optim = Adam(get_parameters(net)) - def get_minibatch(X,Y,bs): - xs, ys = [], [] - #random.seed(1337) - for _ in range(bs): - sel = random.randint(0, len(X)-1) - xs.append(X[sel]) - ys.append(Y[sel]) - return Tensor(xs), Tensor(ys) + def get_minibatch(X, Y, bs): + xs, ys = [], [] + # random.seed(1337) + for _ in range(bs): + sel = random.randint(0, len(X) - 1) + xs.append(X[sel]) + ys.append(Y[sel]) + return Tensor(xs), Tensor(ys) - Tensor.no_grad, Tensor.training = False, True - losses = [] - test_losses = [] - test_loss = float('inf') - for i in (t:=trange(2000)): - x,y = get_minibatch(X,V,bs=256) - out = net(x) - #loss = (out-y).square().mean() - loss = log_likelihood(y, out[:, 0:1], out[:, 1:2]).mean() - optim.zero_grad() - loss.backward() - optim.step() - t.set_description(f"loss {loss.numpy():7.2f}, test loss {test_loss:7.2f}") - losses.append(loss.numpy().item()) - test_losses.append(test_loss) - if i % 10: test_loss = (net(X_test)[:, 0:1]-V_test).square().mean().numpy().item() + Tensor.no_grad, Tensor.training = False, True + losses = [] + test_losses = [] + test_loss = float("inf") + for i in (t := trange(2000)): + x, y = get_minibatch(X, V, bs=256) + out = net(x) + # loss = (out-y).square().mean() + loss = log_likelihood(y, out[:, 0:1], out[:, 1:2]).mean() + optim.zero_grad() + loss.backward() + optim.step() + t.set_description(f"loss {loss.numpy():7.2f}, test loss {test_loss:7.2f}") + losses.append(loss.numpy().item()) + test_losses.append(test_loss) + if i % 10: + test_loss = (net(X_test)[:, 0:1] - V_test).square().mean().numpy().item() - safe_save(get_state_dict(net), "/tmp/qnet.safetensors") + safe_save(get_state_dict(net), "/tmp/qnet.safetensors") - import matplotlib.pyplot as plt - plt.plot(losses[20:]) - plt.plot(test_losses[20:]) - plt.show() + import matplotlib.pyplot as plt + + plt.plot(losses[20:]) + plt.plot(test_losses[20:]) + plt.show() diff --git a/extra/optimization/get_action_space.py b/extra/optimization/get_action_space.py index 493959454..15be76cf4 100644 --- a/extra/optimization/get_action_space.py +++ b/extra/optimization/get_action_space.py @@ -5,30 +5,33 @@ from tinygrad.features.search import actions from tinygrad.codegen.linearizer import Linearizer tactions = set() -def test_rebuild(lin): - linr = Linearizer(lin.ast) - for o in lin.applied_opts: - assert o in actions, f"{o} is not in actions" - tactions.add(o) - linr.apply_opt(o) - assert len(lin.sts) == len(linr.sts) - for st1,st2 in zip(lin.sts, linr.sts): - assert st1 == st2, f"{st1} != {st2}" + +def test_rebuild(lin): + linr = Linearizer(lin.ast) + for o in lin.applied_opts: + assert o in actions, f"{o} is not in actions" + tactions.add(o) + linr.apply_opt(o) + + assert len(lin.sts) == len(linr.sts) + for st1, st2 in zip(lin.sts, linr.sts): + assert st1 == st2, f"{st1} != {st2}" + if __name__ == "__main__": - ast_strs = load_worlds(False, False, False) - random.shuffle(ast_strs) - ast_strs = ast_strs[:2000] - for ast_str in tqdm(ast_strs): - lin = ast_str_to_lin(ast_str) - #if not lin.apply_tensor_cores(): - lin.hand_coded_optimizations() - test_rebuild(lin) - # confirm linearize can be called twice - uops1 = lin.linearize().uops - uops2 = lin.linearize().uops - assert tuple(uops1) == tuple(uops2), f"uops mismatch {lin.colored_shape()}" + ast_strs = load_worlds(False, False, False) + random.shuffle(ast_strs) + ast_strs = ast_strs[:2000] + for ast_str in tqdm(ast_strs): + lin = ast_str_to_lin(ast_str) + # if not lin.apply_tensor_cores(): + lin.hand_coded_optimizations() + test_rebuild(lin) + # confirm linearize can be called twice + uops1 = lin.linearize().uops + uops2 = lin.linearize().uops + assert tuple(uops1) == tuple(uops2), f"uops mismatch {lin.colored_shape()}" - print(len(tactions), len(actions)) - print(sorted(list(tactions))) + print(len(tactions), len(actions)) + print(sorted(list(tactions))) diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py index 06bdcfb09..8f7136718 100644 --- a/extra/optimization/helpers.py +++ b/extra/optimization/helpers.py @@ -1,10 +1,20 @@ # stuff needed to unpack a kernel -from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer +from tinygrad.ops import ( + LazyOp, + TernaryOps, + BinaryOps, + UnaryOps, + ReduceOps, + BufferOps, + MemBuffer, + ConstBuffer, +) from tinygrad.helpers import dtypes from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.shape.symbolic import Variable -inf, nan = float('inf'), float('nan') + +inf, nan = float("inf"), float("nan") # HACK: it used to be called MEM setattr(BufferOps, "MEM", BufferOps.LOAD) @@ -13,30 +23,44 @@ setattr(UnaryOps, "NOOP", UnaryOps.NEG) # kernel unpacker from tinygrad.codegen.linearizer import Linearizer -def ast_str_to_ast(ast_str:str) -> LazyOp: return eval(ast_str) -def ast_str_to_lin(ast_str:str): - # HACK: it used to not have stores - from test.test_linearizer_failures import helper_add_store - return Linearizer(helper_add_store(ast_str_to_ast(ast_str))) + + +def ast_str_to_ast(ast_str: str) -> LazyOp: + return eval(ast_str) + + +def ast_str_to_lin(ast_str: str): + # HACK: it used to not have stores + from test.test_linearizer_failures import helper_add_store + + return Linearizer(helper_add_store(ast_str_to_ast(ast_str))) + # load worlds, a dataset of about 12k kernels import gzip from pathlib import Path import random from tinygrad.helpers import dedup + + def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True): - fn = Path(__file__).parent.parent / "datasets/sops.gz" - ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n")) - if filter_reduce: ast_strs = [x for x in ast_strs if "ReduceOps" in x] - if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x] - if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x] - random.seed(1337) - random.shuffle(ast_strs) - return ast_strs + fn = Path(__file__).parent.parent / "datasets/sops.gz" + ast_strs = dedup(gzip.open(fn).read().decode("utf-8").strip().split("\n")) + if filter_reduce: + ast_strs = [x for x in ast_strs if "ReduceOps" in x] + if filter_noimage: + ast_strs = [x for x in ast_strs if "dtypes.image" not in x] + if filter_novariable: + ast_strs = [x for x in ast_strs if "Variable" not in x] + random.seed(1337) + random.shuffle(ast_strs) + return ast_strs + def assert_same_lin(l1, l2): - assert l1.colored_shape() == l2.colored_shape() - assert all(x==y for x,y in zip(l1.sts, l2.sts)) + assert l1.colored_shape() == l2.colored_shape() + assert all(x == y for x, y in zip(l1.sts, l2.sts)) + # get features import math @@ -44,57 +68,71 @@ from tinygrad.shape.symbolic import Node MAX_DIMS = 16 MAX_BUFS = 9 -def lin_to_feats(lin:Linearizer, use_sts=True): - assert lin.shape_len < MAX_DIMS, "too many dims" - all_colors = ["blue", "cyan", "white", "green", "red", "magenta", "yellow"] - lc = [all_colors.index(x) for x in lin.colors()] - ret = [] - # before, some generic linearizer stuff - ret.append(lin.upcasted) - ret.append(lin.local_dims) +def lin_to_feats(lin: Linearizer, use_sts=True): + assert lin.shape_len < MAX_DIMS, "too many dims" - # first, the full shape, including the colors - for s,os,c in zip(lin.full_shape,lin.output_shape,lc): - if isinstance(s, Node): - ret.append(False) - ret += [0]*9 + all_colors = ["blue", "cyan", "white", "green", "red", "magenta", "yellow"] + lc = [all_colors.index(x) for x in lin.colors()] + + ret = [] + # before, some generic linearizer stuff + ret.append(lin.upcasted) + ret.append(lin.local_dims) + + # first, the full shape, including the colors + for s, os, c in zip(lin.full_shape, lin.output_shape, lc): + if isinstance(s, Node): + ret.append(False) + ret += [0] * 9 + else: + ret.append(True) + ret.append(math.log2(s)) + ret.append(min(33, s)) + ret.append(math.log2(os)) + ret.append(min(33, os)) + ret.append(s % 2 == 0) + ret.append(s % 3 == 0) + ret.append(s % 4 == 0) + ret.append(s % 8 == 0) + ret.append(s % 16 == 0) + cc = [0] * 7 + cc[c] = 1 + ret += cc + ret += [0] * (17 * (MAX_DIMS - len(lin.full_shape))) + ret = [float(x) for x in ret] + + if use_sts: + my_sts = dedup( + [ + ( + x.shape == lin.full_shape, + x.real_strides(), + any(v.mask is not None for v in x.views), + len(x.views), + ) + for x in lin.sts + ] + ) + assert len(my_sts) < MAX_BUFS + sts_len = 3 + 5 * MAX_DIMS + for s in my_sts: + ret.append(s[0]) # reduce + ret.append(s[2]) # has mask + ret.append(s[3]) # len views + for d in s[1]: + ret.append(d is None) + ret.append(d == 0) + ret.append(d == 1) + ret.append(min(33, d) if d is not None else -1) + if d is not None and d >= 1: + ret.append(math.log2(d)) + else: + ret.append(-1) + ret += [0] * (5 * (MAX_DIMS - len(s[1]))) + ret += [0] * (sts_len * (MAX_BUFS - len(my_sts))) + assert len(ret) == 1021, f"wrong len {len(ret)}" else: - ret.append(True) - ret.append(math.log2(s)) - ret.append(min(33, s)) - ret.append(math.log2(os)) - ret.append(min(33, os)) - ret.append(s%2 == 0) - ret.append(s%3 == 0) - ret.append(s%4 == 0) - ret.append(s%8 == 0) - ret.append(s%16 == 0) - cc = [0]*7 - cc[c] = 1 - ret += cc - ret += [0] * (17*(MAX_DIMS-len(lin.full_shape))) - ret = [float(x) for x in ret] - - if use_sts: - my_sts = dedup([(x.shape == lin.full_shape, x.real_strides(), any(v.mask is not None for v in x.views), len(x.views)) for x in lin.sts]) - assert len(my_sts) < MAX_BUFS - sts_len = 3 + 5*MAX_DIMS - for s in my_sts: - ret.append(s[0]) # reduce - ret.append(s[2]) # has mask - ret.append(s[3]) # len views - for d in s[1]: - ret.append(d is None) - ret.append(d == 0) - ret.append(d == 1) - ret.append(min(33, d) if d is not None else -1) - if d is not None and d >= 1: ret.append(math.log2(d)) - else: ret.append(-1) - ret += [0] * (5*(MAX_DIMS - len(s[1]))) - ret += [0] * (sts_len*(MAX_BUFS - len(my_sts))) - assert len(ret) == 1021, f"wrong len {len(ret)}" - else: - assert len(ret) == 274, f"wrong len {len(ret)}" - return ret \ No newline at end of file + assert len(ret) == 274, f"wrong len {len(ret)}" + return ret diff --git a/extra/optimization/pretrain_valuenet.py b/extra/optimization/pretrain_valuenet.py index dd850def9..16256a6cc 100644 --- a/extra/optimization/pretrain_valuenet.py +++ b/extra/optimization/pretrain_valuenet.py @@ -5,84 +5,109 @@ import random from tinygrad.tensor import Tensor from tinygrad.nn import Linear from tinygrad.nn.optim import Adam -from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict +from tinygrad.nn.state import ( + get_parameters, + get_state_dict, + safe_save, + safe_load, + load_state_dict, +) # stuff needed to unpack a kernel -from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer +from tinygrad.ops import ( + LazyOp, + TernaryOps, + BinaryOps, + UnaryOps, + ReduceOps, + BufferOps, + MemBuffer, + ConstBuffer, +) from tinygrad.helpers import dtypes from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.shape.symbolic import Variable -inf, nan = float('inf'), float('nan') + +inf, nan = float("inf"), float("nan") from tinygrad.codegen.kernel import Opt, OptOps from extra.optimization.helpers import lin_to_feats, MAX_DIMS # NOTE: this is not real value of the state, it's just a prediction of the runtime INNER = 512 + + class ValueNet: - def __init__(self, feats=240, out=1): - self.l1 = Linear(feats,INNER) - self.l2 = Linear(INNER,INNER) - self.l3 = Linear(INNER,INNER) - self.l4 = Linear(INNER,out) - def __call__(self, x): - x = self.l1(x).relu() - x = self.l2(x).relu() - x = self.l3(x).relu().dropout(0.8) - return self.l4(x) + def __init__(self, feats=240, out=1): + self.l1 = Linear(feats, INNER) + self.l2 = Linear(INNER, INNER) + self.l3 = Linear(INNER, INNER) + self.l4 = Linear(INNER, out) + + def __call__(self, x): + x = self.l1(x).relu() + x = self.l2(x).relu() + x = self.l3(x).relu().dropout(0.8) + return self.l4(x) + if __name__ == "__main__": - net = ValueNet() - optim = Adam(get_parameters(net)) + net = ValueNet() + optim = Adam(get_parameters(net)) - TEST_SIZE = 256 + TEST_SIZE = 256 - dset = open("/tmp/logtm").read().strip().split("\n") - random.seed(1337) - random.shuffle(dset) + dset = open("/tmp/logtm").read().strip().split("\n") + random.seed(1337) + random.shuffle(dset) - X,Y = [], [] - for i,x in enumerate(tqdm(dset)): - ast, opts, tms = eval(x) - lin = Linearizer(ast) - for o in opts: lin.apply_opt(o) - if lin.shape_len >= MAX_DIMS: continue - if min(tms) == float('inf'): continue - X.append(lin_to_feats(lin)) - Y.append([math.log(min(tms))]) - print(f"got {len(X)} samples") + X, Y = [], [] + for i, x in enumerate(tqdm(dset)): + ast, opts, tms = eval(x) + lin = Linearizer(ast) + for o in opts: + lin.apply_opt(o) + if lin.shape_len >= MAX_DIMS: + continue + if min(tms) == float("inf"): + continue + X.append(lin_to_feats(lin)) + Y.append([math.log(min(tms))]) + print(f"got {len(X)} samples") - X_test,Y_test = Tensor(X[-TEST_SIZE:]), Tensor(Y[-TEST_SIZE:]) - X,Y = X[:-TEST_SIZE], Y[:-TEST_SIZE] + X_test, Y_test = Tensor(X[-TEST_SIZE:]), Tensor(Y[-TEST_SIZE:]) + X, Y = X[:-TEST_SIZE], Y[:-TEST_SIZE] - def get_minibatch(X,Y,bs): - xs, ys = [], [] - for _ in range(bs): - sel = random.randint(0, len(X)-1) - xs.append(X[sel]) - ys.append(Y[sel]) - return Tensor(xs), Tensor(ys) + def get_minibatch(X, Y, bs): + xs, ys = [], [] + for _ in range(bs): + sel = random.randint(0, len(X) - 1) + xs.append(X[sel]) + ys.append(Y[sel]) + return Tensor(xs), Tensor(ys) - Tensor.no_grad, Tensor.training = False, True - losses = [] - test_losses = [] - test_loss = float('inf') - for i in (t:=trange(2000)): - x,y = get_minibatch(X,Y,bs=256) - out = net(x) - loss = (out-y).square().mean() - optim.zero_grad() - loss.backward() - optim.step() - t.set_description(f"loss {loss.numpy():7.2f}, test loss {test_loss:7.2f}") - losses.append(loss.numpy().item()) - test_losses.append(test_loss) - if i % 10: test_loss = (net(X_test)-Y_test).square().mean().numpy().item() + Tensor.no_grad, Tensor.training = False, True + losses = [] + test_losses = [] + test_loss = float("inf") + for i in (t := trange(2000)): + x, y = get_minibatch(X, Y, bs=256) + out = net(x) + loss = (out - y).square().mean() + optim.zero_grad() + loss.backward() + optim.step() + t.set_description(f"loss {loss.numpy():7.2f}, test loss {test_loss:7.2f}") + losses.append(loss.numpy().item()) + test_losses.append(test_loss) + if i % 10: + test_loss = (net(X_test) - Y_test).square().mean().numpy().item() - safe_save(get_state_dict(net), "/tmp/valuenet.safetensors") + safe_save(get_state_dict(net), "/tmp/valuenet.safetensors") - import matplotlib.pyplot as plt - plt.plot(losses[200:]) - plt.plot(test_losses[200:]) - plt.show() + import matplotlib.pyplot as plt + + plt.plot(losses[200:]) + plt.plot(test_losses[200:]) + plt.show() diff --git a/extra/optimization/rl.py b/extra/optimization/rl.py index a088ae337..0ce9e574a 100644 --- a/extra/optimization/rl.py +++ b/extra/optimization/rl.py @@ -2,75 +2,92 @@ import os import numpy as np import math, random from tinygrad.tensor import Tensor -from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict -from tinygrad.features.search import actions, bufs_from_lin, time_linearizer, get_linearizer_actions +from tinygrad.nn.state import ( + get_parameters, + get_state_dict, + safe_save, + safe_load, + load_state_dict, +) +from tinygrad.features.search import ( + actions, + bufs_from_lin, + time_linearizer, + get_linearizer_actions, +) from tinygrad.nn.optim import Adam from extra.optimization.extract_policynet import PolicyNet from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats if __name__ == "__main__": - net = PolicyNet() - if os.path.isfile("/tmp/policynet.safetensors"): load_state_dict(net, safe_load("/tmp/policynet.safetensors")) - optim = Adam(get_parameters(net)) + net = PolicyNet() + if os.path.isfile("/tmp/policynet.safetensors"): + load_state_dict(net, safe_load("/tmp/policynet.safetensors")) + optim = Adam(get_parameters(net)) - ast_strs = load_worlds() + ast_strs = load_worlds() - # select a world - all_feats, all_acts, all_rews = [], [], [] - while 1: - Tensor.no_grad, Tensor.training = True, False - lin = ast_str_to_lin(random.choice(ast_strs)) - rawbufs = bufs_from_lin(lin) - tm = last_tm = base_tm = time_linearizer(lin, rawbufs) - - # take actions - feats, acts, rews = [], [], [] + # select a world + all_feats, all_acts, all_rews = [], [], [] while 1: - feat = lin_to_feats(lin) - feats.append(feat) - probs = net(Tensor([feat])).exp()[0].numpy() + Tensor.no_grad, Tensor.training = True, False + lin = ast_str_to_lin(random.choice(ast_strs)) + rawbufs = bufs_from_lin(lin) + tm = last_tm = base_tm = time_linearizer(lin, rawbufs) - # mask valid actions - valid_action_mask = np.zeros((len(actions)+1), dtype=np.float32) - for x in get_linearizer_actions(lin): valid_action_mask[x] = 1 - probs *= valid_action_mask - probs /= sum(probs) + # take actions + feats, acts, rews = [], [], [] + while 1: + feat = lin_to_feats(lin) + feats.append(feat) + probs = net(Tensor([feat])).exp()[0].numpy() - act = np.random.choice(len(probs), p=probs) - acts.append(act) - if act == 0: - rews.append(0) - break - try: - lin.apply_opt(actions[act-1]) - tm = time_linearizer(lin, rawbufs) - if math.isinf(tm): raise Exception("failed") - rews.append(((last_tm-tm)/base_tm)) - last_tm = tm - except Exception: - rews.append(-0.5) - break - #print(f"{tm*1e6:10.2f}", lin.colored_shape()) + # mask valid actions + valid_action_mask = np.zeros((len(actions) + 1), dtype=np.float32) + for x in get_linearizer_actions(lin): + valid_action_mask[x] = 1 + probs *= valid_action_mask + probs /= sum(probs) - assert len(feats) == len(acts) and len(acts) == len(rews) - #print(rews) - print(f"***** EPISODE {len(rews)} steps, {sum(rews):5.2f} reward, {base_tm*1e6:12.2f} -> {tm*1e6:12.2f} : {lin.colored_shape()}") - all_feats += feats - all_acts += acts - # rewards to go - for i in range(len(rews)-2, -1, -1): rews[i] += rews[i+1] - all_rews += rews + act = np.random.choice(len(probs), p=probs) + acts.append(act) + if act == 0: + rews.append(0) + break + try: + lin.apply_opt(actions[act - 1]) + tm = time_linearizer(lin, rawbufs) + if math.isinf(tm): + raise Exception("failed") + rews.append(((last_tm - tm) / base_tm)) + last_tm = tm + except Exception: + rews.append(-0.5) + break + # print(f"{tm*1e6:10.2f}", lin.colored_shape()) - BS = 32 - if len(all_feats) >= BS: - Tensor.no_grad, Tensor.training = False, True - x = Tensor(all_feats[:BS]) - mask = np.zeros((BS, len(actions)+1), dtype=np.float32) - mask[range(BS), all_acts[:BS]] = all_rews[:BS] - loss = -(net(x) * Tensor(mask)).mean() - optim.zero_grad() - loss.backward() - optim.step() - all_feats = all_feats[BS:] - all_acts = all_acts[BS:] - all_rews = all_rews[BS:] + assert len(feats) == len(acts) and len(acts) == len(rews) + # print(rews) + print( + f"***** EPISODE {len(rews)} steps, {sum(rews):5.2f} reward, {base_tm*1e6:12.2f} -> {tm*1e6:12.2f} : {lin.colored_shape()}" + ) + all_feats += feats + all_acts += acts + # rewards to go + for i in range(len(rews) - 2, -1, -1): + rews[i] += rews[i + 1] + all_rews += rews + + BS = 32 + if len(all_feats) >= BS: + Tensor.no_grad, Tensor.training = False, True + x = Tensor(all_feats[:BS]) + mask = np.zeros((BS, len(actions) + 1), dtype=np.float32) + mask[range(BS), all_acts[:BS]] = all_rews[:BS] + loss = -(net(x) * Tensor(mask)).mean() + optim.zero_grad() + loss.backward() + optim.step() + all_feats = all_feats[BS:] + all_acts = all_acts[BS:] + all_rews = all_rews[BS:] diff --git a/extra/optimization/run_qnet.py b/extra/optimization/run_qnet.py index f456f8adf..b0eb5ef13 100644 --- a/extra/optimization/run_qnet.py +++ b/extra/optimization/run_qnet.py @@ -3,30 +3,36 @@ from tinygrad.codegen.linearizer import Linearizer from tinygrad.features.search import get_linearizer_actions, actions _net = None -def beam_q_estimate(beam:List[Tuple[Linearizer, float]]) -> List[Tuple[Linearizer, float]]: - global _net - if _net is None: - from tinygrad.nn.state import load_state_dict, safe_load - from extra.optimization.pretrain_valuenet import ValueNet - _net = ValueNet(1021+len(actions), 2) - load_state_dict(_net, safe_load("/tmp/qnet.safetensors"), verbose=False) - from tinygrad.tensor import Tensor - from tinygrad.helpers import Context - from extra.optimization.helpers import lin_to_feats - import numpy as np - feats = [] - lins = [] - base_tms = [] - for lin,tm in beam: - lin_feats = lin_to_feats(lin) - for a,v in get_linearizer_actions(lin, include_0=False).items(): - acts = np.zeros(len(actions)) - acts[a-1] = 1.0 - feats.append(np.concatenate([lin_feats, acts])) - lins.append(v) - base_tms.append(tm) - with Context(BEAM=0): - with Tensor.train(False): - preds = _net(Tensor(feats)).numpy() - pred_time = np.array(base_tms) / np.exp(preds[:, 0]) - return sorted(zip(lins, pred_time), key=lambda x: x[1]) + + +def beam_q_estimate( + beam: List[Tuple[Linearizer, float]] +) -> List[Tuple[Linearizer, float]]: + global _net + if _net is None: + from tinygrad.nn.state import load_state_dict, safe_load + from extra.optimization.pretrain_valuenet import ValueNet + + _net = ValueNet(1021 + len(actions), 2) + load_state_dict(_net, safe_load("/tmp/qnet.safetensors"), verbose=False) + from tinygrad.tensor import Tensor + from tinygrad.helpers import Context + from extra.optimization.helpers import lin_to_feats + import numpy as np + + feats = [] + lins = [] + base_tms = [] + for lin, tm in beam: + lin_feats = lin_to_feats(lin) + for a, v in get_linearizer_actions(lin, include_0=False).items(): + acts = np.zeros(len(actions)) + acts[a - 1] = 1.0 + feats.append(np.concatenate([lin_feats, acts])) + lins.append(v) + base_tms.append(tm) + with Context(BEAM=0): + with Tensor.train(False): + preds = _net(Tensor(feats)).numpy() + pred_time = np.array(base_tms) / np.exp(preds[:, 0]) + return sorted(zip(lins, pred_time), key=lambda x: x[1]) diff --git a/extra/optimization/test_beam_search.py b/extra/optimization/test_beam_search.py index 6d54eef21..570d32c6b 100644 --- a/extra/optimization/test_beam_search.py +++ b/extra/optimization/test_beam_search.py @@ -6,48 +6,55 @@ from tinygrad.shape.symbolic import Variable from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d + class TestBeamSearch(unittest.TestCase): - def setUp(self): - self.old_beam = BEAM.value - BEAM.value = 2 - def tearDown(self): - BEAM.value = self.old_beam + def setUp(self): + self.old_beam = BEAM.value + BEAM.value = 2 - def test_variable_ast_beam(self): - a = Tensor.rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3)) - a = (a+1).realize() + def tearDown(self): + BEAM.value = self.old_beam - def test_big_prime_number(self): - a = Tensor.rand(367, 367) - b = Tensor.rand(367, 367) - c = (a@b).realize() - np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4) + def test_variable_ast_beam(self): + a = Tensor.rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3)) + a = (a + 1).realize() - def test_variable_big_prime_number(self): - v = Variable("v", 1, 400).bind(367) - a = Tensor.rand(367, 367) - b = Tensor.rand(367, 367) - c = (a.reshape(367, v) @ b.reshape(v, 367)).realize() - np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4) + def test_big_prime_number(self): + a = Tensor.rand(367, 367) + b = Tensor.rand(367, 367) + c = (a @ b).realize() + np.testing.assert_allclose( + c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4 + ) - def test_variable_shrink_prime_number(self): - v = Variable("v", 1, 400).bind(367) - a = Tensor.rand(400, 367) - b = (a.shrink(((0,v), None))+1).reshape(367,367).realize() - np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4) + def test_variable_big_prime_number(self): + v = Variable("v", 1, 400).bind(367) + a = Tensor.rand(367, 367) + b = Tensor.rand(367, 367) + c = (a.reshape(367, v) @ b.reshape(v, 367)).realize() + np.testing.assert_allclose( + c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4 + ) - def test_no_mutate_rawbuffers(self): - a = Tensor.rand(3, 3).realize() - desired = a.numpy() + 1 - a.assign(a+1) - actual = a.numpy() - np.testing.assert_allclose(actual, desired) + def test_variable_shrink_prime_number(self): + v = Variable("v", 1, 400).bind(367) + a = Tensor.rand(400, 367) + b = (a.shrink(((0, v), None)) + 1).reshape(367, 367).realize() + np.testing.assert_allclose(b.numpy(), a.numpy()[:367] + 1, atol=1e-4, rtol=1e-4) - def test_conv_beam(self): - c = Conv2d(3, 16, (3,3)) - x = Tensor.rand(1,3,32,32) - with Timing(): - c(x).realize() + def test_no_mutate_rawbuffers(self): + a = Tensor.rand(3, 3).realize() + desired = a.numpy() + 1 + a.assign(a + 1) + actual = a.numpy() + np.testing.assert_allclose(actual, desired) -if __name__ == '__main__': - unittest.main() + def test_conv_beam(self): + c = Conv2d(3, 16, (3, 3)) + x = Tensor.rand(1, 3, 32, 32) + with Timing(): + c(x).realize() + + +if __name__ == "__main__": + unittest.main() diff --git a/extra/optimization/test_net.py b/extra/optimization/test_net.py index 851c2a85b..7c72d78bb 100644 --- a/extra/optimization/test_net.py +++ b/extra/optimization/test_net.py @@ -1,12 +1,24 @@ import numpy as np import math import random + np.set_printoptions(suppress=True) from copy import deepcopy from tinygrad.helpers import getenv, colored from tinygrad.tensor import Tensor -from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict -from tinygrad.features.search import bufs_from_lin, time_linearizer, actions, get_linearizer_actions +from tinygrad.nn.state import ( + get_parameters, + get_state_dict, + safe_save, + safe_load, + load_state_dict, +) +from tinygrad.features.search import ( + bufs_from_lin, + time_linearizer, + actions, + get_linearizer_actions, +) from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats from extra.optimization.extract_policynet import PolicyNet from extra.optimization.pretrain_valuenet import ValueNet @@ -14,53 +26,56 @@ from extra.optimization.pretrain_valuenet import ValueNet VALUE = getenv("VALUE") if __name__ == "__main__": - if VALUE: - net = ValueNet() - load_state_dict(net, safe_load("/tmp/valuenet.safetensors")) - else: - net = PolicyNet() - load_state_dict(net, safe_load("/tmp/policynet.safetensors")) + if VALUE: + net = ValueNet() + load_state_dict(net, safe_load("/tmp/valuenet.safetensors")) + else: + net = PolicyNet() + load_state_dict(net, safe_load("/tmp/policynet.safetensors")) - ast_strs = load_worlds() + ast_strs = load_worlds() - # real randomness - random.seed() - random.shuffle(ast_strs) + # real randomness + random.seed() + random.shuffle(ast_strs) - wins = 0 - for ep_num,ast_str in enumerate(ast_strs): - print("\nEPISODE", ep_num, f"win {wins*100/max(1,ep_num):.2f}%") - lin = ast_str_to_lin(ast_str) - rawbufs = bufs_from_lin(lin) + wins = 0 + for ep_num, ast_str in enumerate(ast_strs): + print("\nEPISODE", ep_num, f"win {wins*100/max(1,ep_num):.2f}%") + lin = ast_str_to_lin(ast_str) + rawbufs = bufs_from_lin(lin) - linhc = deepcopy(lin) - linhc.hand_coded_optimizations() - tmhc = time_linearizer(linhc, rawbufs) - print(f"{tmhc*1e6:10.2f} HC ", linhc.colored_shape()) + linhc = deepcopy(lin) + linhc.hand_coded_optimizations() + tmhc = time_linearizer(linhc, rawbufs) + print(f"{tmhc*1e6:10.2f} HC ", linhc.colored_shape()) - pred_time = float('nan') - tm = float('inf') - while 1: - if VALUE: - acts,feats = [], [] - for k,v in get_linearizer_actions(lin).items(): - acts.append(k) - feats.append(lin_to_feats(v)) - preds = net(Tensor(feats)) - pred_time = math.exp(preds.numpy().min()) - act = acts[preds.numpy().argmin()] - else: - probs = net(Tensor([lin_to_feats(lin)])) - dist = probs.exp().numpy() - act = dist.argmax() - if act == 0: break - try: - lin.apply_opt(actions[act-1]) - except Exception: - print("FAILED") - break - tm = time_linearizer(lin, rawbufs) - print(f"{tm*1e6:10.2f} {pred_time*1e6:10.2f}", lin.colored_shape()) + pred_time = float("nan") + tm = float("inf") + while 1: + if VALUE: + acts, feats = [], [] + for k, v in get_linearizer_actions(lin).items(): + acts.append(k) + feats.append(lin_to_feats(v)) + preds = net(Tensor(feats)) + pred_time = math.exp(preds.numpy().min()) + act = acts[preds.numpy().argmin()] + else: + probs = net(Tensor([lin_to_feats(lin)])) + dist = probs.exp().numpy() + act = dist.argmax() + if act == 0: + break + try: + lin.apply_opt(actions[act - 1]) + except Exception: + print("FAILED") + break + tm = time_linearizer(lin, rawbufs) + print(f"{tm*1e6:10.2f} {pred_time*1e6:10.2f}", lin.colored_shape()) - print(f"{colored('BEAT', 'green') if tm < tmhc else colored('lost', 'red')} hand coded {tmhc/tm:5.2f}x") - wins += int(tm < tmhc) \ No newline at end of file + print( + f"{colored('BEAT', 'green') if tm < tmhc else colored('lost', 'red')} hand coded {tmhc/tm:5.2f}x" + ) + wins += int(tm < tmhc) diff --git a/extra/optimization/test_time_linearizer.py b/extra/optimization/test_time_linearizer.py index 47af6c52b..6bbbe5536 100644 --- a/extra/optimization/test_time_linearizer.py +++ b/extra/optimization/test_time_linearizer.py @@ -1,21 +1,31 @@ from extra.optimization.helpers import load_worlds, ast_str_to_lin -from tinygrad.features.search import bufs_from_lin, time_linearizer, get_linearizer_actions +from tinygrad.features.search import ( + bufs_from_lin, + time_linearizer, + get_linearizer_actions, +) if __name__ == "__main__": - ast_strs = load_worlds() - for i, ast_str in enumerate(ast_strs): - lin = ast_str_to_lin(ast_str) - rawbufs = bufs_from_lin(lin) - test_tm = time_linearizer(lin, rawbufs) - if test_tm < 1e-2: continue - print(f"EXAMPLE {i}") - acted_lins = get_linearizer_actions(lin) - ok_avg, short_avg = 0, 0 - for k,v in acted_lins.items(): - tm1 = time_linearizer(v, rawbufs) - tm2 = time_linearizer(v, rawbufs) - tm3 = time_linearizer(v, rawbufs, False) - print(v.colored_shape(50), f"{tm1*1e3:10.2f} {tm2*1e3:10.2f} {tm3*1e3:10.2f} : {((tm1-tm2)/tm1)*100:5.2f}% vs {((tm1-tm3)/tm1)*100:5.2f}%") - ok_avg += (tm1-tm2)/tm1 - short_avg += (tm1-tm3)/tm1 - print(f"{ok_avg/len(acted_lins)*100:5.2f}% vs {short_avg/len(acted_lins)*100:5.2f}%") + ast_strs = load_worlds() + for i, ast_str in enumerate(ast_strs): + lin = ast_str_to_lin(ast_str) + rawbufs = bufs_from_lin(lin) + test_tm = time_linearizer(lin, rawbufs) + if test_tm < 1e-2: + continue + print(f"EXAMPLE {i}") + acted_lins = get_linearizer_actions(lin) + ok_avg, short_avg = 0, 0 + for k, v in acted_lins.items(): + tm1 = time_linearizer(v, rawbufs) + tm2 = time_linearizer(v, rawbufs) + tm3 = time_linearizer(v, rawbufs, False) + print( + v.colored_shape(50), + f"{tm1*1e3:10.2f} {tm2*1e3:10.2f} {tm3*1e3:10.2f} : {((tm1-tm2)/tm1)*100:5.2f}% vs {((tm1-tm3)/tm1)*100:5.2f}%", + ) + ok_avg += (tm1 - tm2) / tm1 + short_avg += (tm1 - tm3) / tm1 + print( + f"{ok_avg/len(acted_lins)*100:5.2f}% vs {short_avg/len(acted_lins)*100:5.2f}%" + ) diff --git a/extra/thneed.py b/extra/thneed.py index c59f63685..addf17735 100644 --- a/extra/thneed.py +++ b/extra/thneed.py @@ -10,197 +10,267 @@ from tinygrad.helpers import DEBUG, getenv from collections import defaultdict import pyopencl as cl from tinygrad.runtime.ops_gpu import OSX_TIMING_RATIO + CL = Device["GPU"] DEBUGCL = getenv("DEBUGCL", 0) FLOAT16 = getenv("FLOAT16", 0) + class Thneed: - def __init__(self, cl_cache=[], inputs={}): - self.cl_cache, self.inputs = cl_cache[:], inputs - self.gobj = 0 + def __init__(self, cl_cache=[], inputs={}): + self.cl_cache, self.inputs = cl_cache[:], inputs + self.gobj = 0 - # build graph - # NOTE: if CLCACHE=1, this is wrong! - nodes = defaultdict(lambda: {'in_edges': [], 'out_edges': []}) - for _, args in self.cl_cache: - # output is always the first parameter - for a in args[3:]: - nodes[a]['out_edges'].append(args[2]) - nodes[args[2]]['in_edges'].append(a) + # build graph + # NOTE: if CLCACHE=1, this is wrong! + nodes = defaultdict(lambda: {"in_edges": [], "out_edges": []}) + for _, args in self.cl_cache: + # output is always the first parameter + for a in args[3:]: + nodes[a]["out_edges"].append(args[2]) + nodes[args[2]]["in_edges"].append(a) - # get buffers to save - self.buffers_to_save = set() - self.outputs = [] - for n in nodes.keys(): - if len(nodes[n]['in_edges']) == 0: - self.buffers_to_save.add(n) - if len(nodes[n]['out_edges']) == 0: - self.outputs.append(n) + # get buffers to save + self.buffers_to_save = set() + self.outputs = [] + for n in nodes.keys(): + if len(nodes[n]["in_edges"]) == 0: + self.buffers_to_save.add(n) + if len(nodes[n]["out_edges"]) == 0: + self.outputs.append(n) - fake_inputs = [] - for k,n in self.inputs.items(): - if n in self.buffers_to_save: - self.buffers_to_save.remove(n) - else: - print(f"WARNING: {k} was not a used input, removing it") - fake_inputs.append(k) - for k in fake_inputs: - del self.inputs[k] - - def load(self, input_fn): - float32 = not FLOAT16 - - mf = cl.mem_flags - image_fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT) - image_fmt_32 = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT) - - with open(input_fn, "rb") as f: - json_len = struct.unpack("I", f.read(4))[0] - jdat = json.loads(f.read(json_len).decode('latin_1')) - weights = f.read() - - # load in the buffers - bufs = {'\x00\x00\x00\x00\x00\x00\x00\x00': None} - bufs_loaded = {} - ptr = 0 - for o in jdat['objects']: - #print(o) - if o['needs_load']: - nptr = ptr + o['size'] - o['data'] = weights[ptr:nptr] - ptr = nptr - - if o['arg_type'] == "image2d_t" or o['arg_type'] == "image1d_t": - tfmt = image_fmt_32 if 'float32' in o and o['float32'] else image_fmt - if o['arg_type'] == "image2d_t": - if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]: - # hack: use a image1d since we can back that with a buffer - buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']]) - else: - # buffer isn't supported in image2d, copy buffer into image - if 'buffer_id' in o and bufs_loaded[o['buffer_id']]: - arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16) - cl.enqueue_copy(CL.queue, arr, bufs[o['buffer_id']]) - buf = cl.Image(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt, - shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr) - elif o['needs_load']: - buf = cl.Image(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt, - shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data']) + fake_inputs = [] + for k, n in self.inputs.items(): + if n in self.buffers_to_save: + self.buffers_to_save.remove(n) else: - buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'], o['height'])) - if o['arg_type'] == "image1d_t": - assert not o['needs_load'] - assert not bufs_loaded[o['buffer_id']] - buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']]) - else: - if 'data' in o: - buf = cl.Buffer(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data']) - else: - # zero out buffers - buf = cl.Buffer(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size']) + print(f"WARNING: {k} was not a used input, removing it") + fake_inputs.append(k) + for k in fake_inputs: + del self.inputs[k] - bufs[o['id']] = buf - bufs_loaded[o['id']] = 'data' in o - # if it's loaded, it's saved - if 'data' in o: - self.buffers_to_save.add(buf) + def load(self, input_fn): + float32 = not FLOAT16 - # load binaries - prgs = {} - for o in jdat['binaries']: - nptr = ptr + o['length'] - prgs[o['name']] = CLProgram(Device["GPU"], o['name'], weights[ptr:nptr]) - ptr = nptr + mf = cl.mem_flags + image_fmt = cl.ImageFormat( + cl.channel_order.RGBA, + cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT, + ) + image_fmt_32 = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT) - # populate the cl_cache - for i,k in enumerate(jdat['kernels']): - kernel = prgs[k['name']] - aaa = [] - for j,(a,sz) in enumerate(zip(k['args'], k['args_size'])): - if len(a) == 0: - aa = cl.LocalMemory(sz) - elif len(a) == 4: - a = a.encode('latin_1') - aa = np.uint32(struct.unpack("I", a)[0]) - elif len(a) == 2: - a = a.encode('latin_1') - aa = np.uint16(struct.unpack("H", a)[0]) - elif len(a) == 8: - #print(i,j,struct.unpack("Q", a.encode('latin_1'))[0]) - aa = bufs[a] - aaa.append(aa) - self.cl_cache.append((kernel, [k['global_work_size'], k['local_work_size'], *aaa])) + with open(input_fn, "rb") as f: + json_len = struct.unpack("I", f.read(4))[0] + jdat = json.loads(f.read(json_len).decode("latin_1")) + weights = f.read() - if DEBUG >= 1: print(f"thneed: total bufs loaded: {len(bufs.keys())}") + # load in the buffers + bufs = {"\x00\x00\x00\x00\x00\x00\x00\x00": None} + bufs_loaded = {} + ptr = 0 + for o in jdat["objects"]: + # print(o) + if o["needs_load"]: + nptr = ptr + o["size"] + o["data"] = weights[ptr:nptr] + ptr = nptr - # load inputs - for k in jdat['inputs']: - self.inputs[k['name']] = bufs[k['buffer_id']] + if o["arg_type"] == "image2d_t" or o["arg_type"] == "image1d_t": + tfmt = image_fmt_32 if "float32" in o and o["float32"] else image_fmt + if o["arg_type"] == "image2d_t": + if ( + "buffer_id" in o + and o["height"] == 1 + and not bufs_loaded[o["buffer_id"]] + ): + # hack: use a image1d since we can back that with a buffer + buf = cl.Image( + CL.ctx, + mf.READ_WRITE, + tfmt, + shape=(o["width"],), + buffer=bufs[o["buffer_id"]], + ) + else: + # buffer isn't supported in image2d, copy buffer into image + if "buffer_id" in o and bufs_loaded[o["buffer_id"]]: + arr = np.zeros( + bufs[o["buffer_id"]].size // 2, dtype=np.float16 + ) + cl.enqueue_copy(CL.queue, arr, bufs[o["buffer_id"]]) + buf = cl.Image( + CL.ctx, + mf.READ_WRITE | mf.COPY_HOST_PTR, + tfmt, + shape=(o["width"], o["height"]), + pitches=(o["row_pitch"],), + hostbuf=arr, + ) + elif o["needs_load"]: + buf = cl.Image( + CL.ctx, + mf.READ_WRITE | mf.COPY_HOST_PTR, + tfmt, + shape=(o["width"], o["height"]), + pitches=(o["row_pitch"],), + hostbuf=o["data"], + ) + else: + buf = cl.Image( + CL.ctx, + mf.READ_WRITE, + tfmt, + shape=(o["width"], o["height"]), + ) + if o["arg_type"] == "image1d_t": + assert not o["needs_load"] + assert not bufs_loaded[o["buffer_id"]] + buf = cl.Image( + CL.ctx, + mf.READ_WRITE, + tfmt, + shape=(o["width"],), + buffer=bufs[o["buffer_id"]], + ) + else: + if "data" in o: + buf = cl.Buffer( + CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o["data"] + ) + else: + # zero out buffers + buf = cl.Buffer( + CL.ctx, + mf.READ_WRITE | mf.COPY_HOST_PTR, + hostbuf=b"\x00" * o["size"], + ) - # load outputs - for k in jdat['outputs']: - self.outputs.append(bufs[k['buffer_id']]) + bufs[o["id"]] = buf + bufs_loaded[o["id"]] = "data" in o + # if it's loaded, it's saved + if "data" in o: + self.buffers_to_save.add(buf) + # load binaries + prgs = {} + for o in jdat["binaries"]: + nptr = ptr + o["length"] + prgs[o["name"]] = CLProgram(Device["GPU"], o["name"], weights[ptr:nptr]) + ptr = nptr - def save(self, output_fn): - # this is the struct that will be saved - jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []} + # populate the cl_cache + for i, k in enumerate(jdat["kernels"]): + kernel = prgs[k["name"]] + aaa = [] + for j, (a, sz) in enumerate(zip(k["args"], k["args_size"])): + if len(a) == 0: + aa = cl.LocalMemory(sz) + elif len(a) == 4: + a = a.encode("latin_1") + aa = np.uint32(struct.unpack("I", a)[0]) + elif len(a) == 2: + a = a.encode("latin_1") + aa = np.uint16(struct.unpack("H", a)[0]) + elif len(a) == 8: + # print(i,j,struct.unpack("Q", a.encode('latin_1'))[0]) + aa = bufs[a] + aaa.append(aa) + self.cl_cache.append( + (kernel, [k["global_work_size"], k["local_work_size"], *aaa]) + ) - # build the pieces of this struct - weights = [] - binaries = [] - saved_objs = set() - saved_binaries = set() - for prg, args in self.cl_cache: - # get binaries for saving - if prg.name not in saved_binaries: - binary = prg.clprogram.get_info(cl.program_info.BINARIES) - assert len(binary) == 1 - jdat['binaries'].append({"name":prg.name, "length":len(binary[0])}) - binaries.append(binary[0]) - saved_binaries.add(prg.name) + if DEBUG >= 1: + print(f"thneed: total bufs loaded: {len(bufs.keys())}") - # get the args from the kernel, some need the data saved - targs, args_size = [], [] - argdtypes = [None]*(len(args)-2) - for a,d in zip(args[2:], argdtypes): - if d == np.int16: - targs.append(struct.pack("H", a).decode("latin_1")) - args_size.append(2) - elif d == np.int32: - targs.append(struct.pack("I", a).decode("latin_1")) - args_size.append(4) - elif isinstance(a, cl.LocalMemory): - targs.append("") - args_size.append(a.size) - elif d is None: - if getattr(a, "global_id", None) is None: - setattr(a, "global_id", self.gobj) - self.gobj += 1 - ptr = struct.pack("Q", a.global_id).decode("latin_1") - if ptr not in saved_objs: - if isinstance(a, cl.Buffer): - needs_load = a in self.buffers_to_save - jdat['objects'].append({ - "id": ptr, "arg_type": "float*", "needs_load": needs_load, "size": a.size, - }) - if needs_load: - data = np.empty(a.size//4, dtype=np.float32) - cl.enqueue_copy(CL.queue, data, a, is_blocking=True) - weights.append(data.tobytes()) - elif isinstance(a, cl.Image): - assert a.format == cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT), "wrong type" - needs_load = a in self.buffers_to_save - row_pitch = (a.shape[0]*4*(2 if FLOAT16 else 4) + 63)//64 * 64 - size = row_pitch * a.shape[1] - # this is *2 if float16 and *4 if float32 - buf = cl.Buffer(CL.ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1)) + # load inputs + for k in jdat["inputs"]: + self.inputs[k["name"]] = bufs[k["buffer_id"]] - # zero out the buffer - cl.enqueue_copy(CL.queue, buf, b'\x00'*buf.size, is_blocking=True) + # load outputs + for k in jdat["outputs"]: + self.outputs.append(bufs[k["buffer_id"]]) - CLProgram(CL, "from_image_strided", compile_gpu(""" + def save(self, output_fn): + # this is the struct that will be saved + jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []} + + # build the pieces of this struct + weights = [] + binaries = [] + saved_objs = set() + saved_binaries = set() + for prg, args in self.cl_cache: + # get binaries for saving + if prg.name not in saved_binaries: + binary = prg.clprogram.get_info(cl.program_info.BINARIES) + assert len(binary) == 1 + jdat["binaries"].append({"name": prg.name, "length": len(binary[0])}) + binaries.append(binary[0]) + saved_binaries.add(prg.name) + + # get the args from the kernel, some need the data saved + targs, args_size = [], [] + argdtypes = [None] * (len(args) - 2) + for a, d in zip(args[2:], argdtypes): + if d == np.int16: + targs.append(struct.pack("H", a).decode("latin_1")) + args_size.append(2) + elif d == np.int32: + targs.append(struct.pack("I", a).decode("latin_1")) + args_size.append(4) + elif isinstance(a, cl.LocalMemory): + targs.append("") + args_size.append(a.size) + elif d is None: + if getattr(a, "global_id", None) is None: + setattr(a, "global_id", self.gobj) + self.gobj += 1 + ptr = struct.pack("Q", a.global_id).decode("latin_1") + if ptr not in saved_objs: + if isinstance(a, cl.Buffer): + needs_load = a in self.buffers_to_save + jdat["objects"].append( + { + "id": ptr, + "arg_type": "float*", + "needs_load": needs_load, + "size": a.size, + } + ) + if needs_load: + data = np.empty(a.size // 4, dtype=np.float32) + cl.enqueue_copy(CL.queue, data, a, is_blocking=True) + weights.append(data.tobytes()) + elif isinstance(a, cl.Image): + assert a.format == cl.ImageFormat( + cl.channel_order.RGBA, + cl.channel_type.HALF_FLOAT + if FLOAT16 + else cl.channel_type.FLOAT, + ), "wrong type" + needs_load = a in self.buffers_to_save + row_pitch = ( + (a.shape[0] * 4 * (2 if FLOAT16 else 4) + 63) // 64 * 64 + ) + size = row_pitch * a.shape[1] + # this is *2 if float16 and *4 if float32 + buf = cl.Buffer( + CL.ctx, + cl.mem_flags.READ_WRITE, + size=size * (2 if FLOAT16 else 1), + ) + + # zero out the buffer + cl.enqueue_copy( + CL.queue, buf, b"\x00" * buf.size, is_blocking=True + ) + + CLProgram( + CL, + "from_image_strided", + compile_gpu( + """ __kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) { const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; int2 l; @@ -208,80 +278,128 @@ class Thneed: l.x = get_global_id(0); out[l.y*row_pitch + l.x] = read_imagef(in, smp, l); } - """), bufs=2, vars=1)(a, buf, row_pitch//(4*(2 if FLOAT16 else 4)), global_size=a.shape) + """ + ), + bufs=2, + vars=1, + )( + a, + buf, + row_pitch // (4 * (2 if FLOAT16 else 4)), + global_size=a.shape, + ) - # multiple of 32 isn't enough - jdat['objects'].append({ - "id": ptr, "needs_load": needs_load, "size": size, "arg_type": "image2d_t", - "width": a.shape[0], "height": a.shape[1], "row_pitch": row_pitch, "float32": not FLOAT16, - }) + # multiple of 32 isn't enough + jdat["objects"].append( + { + "id": ptr, + "needs_load": needs_load, + "size": size, + "arg_type": "image2d_t", + "width": a.shape[0], + "height": a.shape[1], + "row_pitch": row_pitch, + "float32": not FLOAT16, + } + ) - if needs_load: - data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32) - cl.enqueue_copy(CL.queue, data, buf, is_blocking=True) - if FLOAT16: data = data.astype(np.float16) - weights.append(data.tobytes()) - else: - raise Exception("unknown object", a) - #print(jdat['objects'][-1]) - saved_objs.add(ptr) - targs.append(ptr) - args_size.append(8) - else: - raise Exception("idk this type") + if needs_load: + data = np.empty( + size // (2 if FLOAT16 else 4), dtype=np.float32 + ) + cl.enqueue_copy(CL.queue, data, buf, is_blocking=True) + if FLOAT16: + data = data.astype(np.float16) + weights.append(data.tobytes()) + else: + raise Exception("unknown object", a) + # print(jdat['objects'][-1]) + saved_objs.add(ptr) + targs.append(ptr) + args_size.append(8) + else: + raise Exception("idk this type") - # save the kernel itself - jdat['kernels'].append({ - "name": prg.name, - "work_dim": len(args[0]), - "global_work_size": args[0], - # TODO: C++ thneed requires a local_work_size, so we fill it with ones - "local_work_size": [1 for _ in args[0]] if args[1] is None else args[1], - "num_args": len(args)-2, - "args": targs, - "args_size": args_size - }) + # save the kernel itself + jdat["kernels"].append( + { + "name": prg.name, + "work_dim": len(args[0]), + "global_work_size": args[0], + # TODO: C++ thneed requires a local_work_size, so we fill it with ones + "local_work_size": [1 for _ in args[0]] + if args[1] is None + else args[1], + "num_args": len(args) - 2, + "args": targs, + "args_size": args_size, + } + ) - jdat['outputs'] = [{ - "buffer_id": struct.pack("Q", x.global_id).decode("latin_1"), - "size": x.size, - } for x in self.outputs] + jdat["outputs"] = [ + { + "buffer_id": struct.pack("Q", x.global_id).decode("latin_1"), + "size": x.size, + } + for x in self.outputs + ] - jdat['inputs'] = [{ - "buffer_id": struct.pack("Q", v.global_id).decode("latin_1"), - "size": v.size, - "name": k - } for k,v in self.inputs.items()][::-1] + jdat["inputs"] = [ + { + "buffer_id": struct.pack("Q", v.global_id).decode("latin_1"), + "size": v.size, + "name": k, + } + for k, v in self.inputs.items() + ][::-1] - print(f"saving thneed to {output_fn}") - with open(output_fn, "wb") as f: - j = json.dumps(jdat, ensure_ascii=False).encode('latin_1') - f.write(struct.pack("I", len(j))) - f.write(j) - f.write(b''.join(weights)) - f.write(b''.join(binaries)) + print(f"saving thneed to {output_fn}") + with open(output_fn, "wb") as f: + j = json.dumps(jdat, ensure_ascii=False).encode("latin_1") + f.write(struct.pack("I", len(j))) + f.write(j) + f.write(b"".join(weights)) + f.write(b"".join(binaries)) - def run(self): - events = [] - st = time.monotonic() - for prg, args in self.cl_cache: - events.append(prg.clprg(CL.queue, *args)) - mt = time.monotonic() - Device["GPU"].synchronize() - et = time.monotonic() - st - print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms") + def run(self): + events = [] + st = time.monotonic() + for prg, args in self.cl_cache: + events.append(prg.clprg(CL.queue, *args)) + mt = time.monotonic() + Device["GPU"].synchronize() + et = time.monotonic() - st + print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms") - if DEBUGCL >= 2: - for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)): - print(f"{i:3d} {prg.name:25s} " + "queued @ %5.2f ms, submit @ %5.2fms, start @ %5.2f ms, end @ %5.2f ms" % tuple((x*OSX_TIMING_RATIO - st*1e9)/1e6 for x in [e.profile.queued, e.profile.submit, e.profile.start, e.profile.end])) - if DEBUGCL >= 1: - total_runtime = 0 - for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)): - runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO - print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:25s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}") - if hasattr(prg, 'prg') and ((DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3): - print(prg.prg) - total_runtime += runtime - print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms") - return total_runtime/1e9 - return et + if DEBUGCL >= 2: + for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)): + print( + f"{i:3d} {prg.name:25s} " + + "queued @ %5.2f ms, submit @ %5.2fms, start @ %5.2f ms, end @ %5.2f ms" + % tuple( + (x * OSX_TIMING_RATIO - st * 1e9) / 1e6 + for x in [ + e.profile.queued, + e.profile.submit, + e.profile.start, + e.profile.end, + ] + ) + ) + if DEBUGCL >= 1: + total_runtime = 0 + for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)): + runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO + print( + f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:25s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}" + ) + if hasattr(prg, "prg") and ( + (DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3 + ): + print(prg.prg) + total_runtime += runtime + print( + f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms" + ) + return total_runtime / 1e9 + return et diff --git a/extra/to_movement_ops.py b/extra/to_movement_ops.py index 69560179a..b02bf03b3 100644 --- a/extra/to_movement_ops.py +++ b/extra/to_movement_ops.py @@ -7,126 +7,280 @@ from tinygrad.helpers import prod from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import sym_infer, Node + # ShapeTracker to an equivalent series of MovementOps (https://github.com/tinygrad/tinygrad/pull/2216) def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]: - to_apply:List[Tuple[MovementOps, Tuple]] = [] - for i, v in enumerate(st.views): - real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape - offset = v.offset + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0) - real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0) - real_real_shape = [s for s,st in zip(real_shape, v.strides) if st] - strides: List[Node|int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st] - buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1 - if i: buffer_size = prod(st.views[i-1].shape) - real_offset - def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True) - ordered_shape_strides, order = sort_by_strides(real_real_shape, strides) - to_apply.extend([(MovementOps.RESHAPE, (-1,)), (MovementOps.SHRINK, ((real_offset, real_offset+buffer_size),))]) - if strides: - if (ordered_shape_strides[0][0]*ordered_shape_strides[0][1])-buffer_size>0: to_apply.append((MovementOps.PAD, ((0, (ordered_shape_strides[0][0] * ordered_shape_strides[0][1]) - buffer_size),))) - for i, shape_stride in enumerate(ordered_shape_strides): - if i0 else buffer_size - to_apply.append((MovementOps.EXPAND, (shape_stride[0], *(s[0] for s in ordered_shape_strides[:i]), remaining_buffer))) - to_apply.append((MovementOps.PERMUTE, (*range(1,i+1), 0, i+1))) - to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i]), shape_stride[0]*remaining_buffer))) - to_apply.append((MovementOps.PAD, (*((0,0) for _ in range(i)), (0, shape_stride[0]*shape_stride[1])))) - to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i+1]), remaining_buffer+shape_stride[1]))) - ordered_shape_strides[i] = (ordered_shape_strides[i][0], remaining_buffer+shape_stride[1]) - else: - to_apply.append((MovementOps.SHRINK, (*((0, s[0]) for s in ordered_shape_strides[:i]), (0, shape_stride[0]*shape_stride[1])))) - to_apply.append((MovementOps.RESHAPE, (*[s[0] for s in ordered_shape_strides[:i+1]], shape_stride[1]))) - to_apply.extend([(MovementOps.SHRINK, (*[(0, s[0]) for s in ordered_shape_strides], (0,1))), (MovementOps.RESHAPE, tuple(s[0] for s in ordered_shape_strides))]) - if order != list(range(len(order))): to_apply.append((MovementOps.PERMUTE, tuple(order.index(i) for i in range(len(strides))))) - to_apply.append((MovementOps.RESHAPE, tuple(s if st else 1 for s,st in zip(real_shape, v.strides)))) - if any(i<0 for i in v.strides): to_apply.append((MovementOps.STRIDE, tuple(-1 if st<0 else 1 for st in v.strides))) - # then, we apply pre expand pads - if v.mask is not None: - pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides)) - post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides)) - if any(x != (0,0) for x in pre_expand_pads): - to_apply.append((MovementOps.PAD, pre_expand_pads)) - real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads)) - # then, we do any expands - if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape)) - # lastly, we apply post expand pads - if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads)) - return to_apply + to_apply: List[Tuple[MovementOps, Tuple]] = [] + for i, v in enumerate(st.views): + real_shape = tuple(y - x for x, y in v.mask) if v.mask else v.shape + offset = v.offset + sum( + st * (s - 1) for s, st in zip(real_shape, v.strides) if st < 0 + ) + real_offset = offset + ( + sum(x * st for (x, _), st in zip(v.mask, v.strides)) if v.mask else 0 + ) + real_real_shape = [s for s, st in zip(real_shape, v.strides) if st] + strides: List[Node | int] = [ + abs(st) if isinstance(st, int) else st for st in v.strides if st + ] + buffer_size = sum((s - 1) * st for s, st in zip(real_real_shape, strides)) + 1 + if i: + buffer_size = prod(st.views[i - 1].shape) - real_offset + + def sort_by_strides(shape, strides): + return sorted( + zip(shape, strides), key=lambda k: (k[1], -k[0]), reverse=True + ), sorted( + range(len(strides)), + key=lambda k: (strides[k], -real_real_shape[k]), + reverse=True, + ) + + ordered_shape_strides, order = sort_by_strides(real_real_shape, strides) + to_apply.extend( + [ + (MovementOps.RESHAPE, (-1,)), + (MovementOps.SHRINK, ((real_offset, real_offset + buffer_size),)), + ] + ) + if strides: + if ( + ordered_shape_strides[0][0] * ordered_shape_strides[0][1] + ) - buffer_size > 0: + to_apply.append( + ( + MovementOps.PAD, + ( + ( + 0, + ( + ordered_shape_strides[0][0] + * ordered_shape_strides[0][1] + ) + - buffer_size, + ), + ), + ) + ) + for i, shape_stride in enumerate(ordered_shape_strides): + if ( + i < len(ordered_shape_strides) - 1 + and shape_stride[1] + < ordered_shape_strides[i + 1][0] * ordered_shape_strides[i + 1][1] + ): + remaining_buffer = ( + ordered_shape_strides[i - 1][1] if i > 0 else buffer_size + ) + to_apply.append( + ( + MovementOps.EXPAND, + ( + shape_stride[0], + *(s[0] for s in ordered_shape_strides[:i]), + remaining_buffer, + ), + ) + ) + to_apply.append((MovementOps.PERMUTE, (*range(1, i + 1), 0, i + 1))) + to_apply.append( + ( + MovementOps.RESHAPE, + ( + *(s[0] for s in ordered_shape_strides[:i]), + shape_stride[0] * remaining_buffer, + ), + ) + ) + to_apply.append( + ( + MovementOps.PAD, + ( + *((0, 0) for _ in range(i)), + (0, shape_stride[0] * shape_stride[1]), + ), + ) + ) + to_apply.append( + ( + MovementOps.RESHAPE, + ( + *(s[0] for s in ordered_shape_strides[: i + 1]), + remaining_buffer + shape_stride[1], + ), + ) + ) + ordered_shape_strides[i] = ( + ordered_shape_strides[i][0], + remaining_buffer + shape_stride[1], + ) + else: + to_apply.append( + ( + MovementOps.SHRINK, + ( + *((0, s[0]) for s in ordered_shape_strides[:i]), + (0, shape_stride[0] * shape_stride[1]), + ), + ) + ) + to_apply.append( + ( + MovementOps.RESHAPE, + ( + *[s[0] for s in ordered_shape_strides[: i + 1]], + shape_stride[1], + ), + ) + ) + to_apply.extend( + [ + ( + MovementOps.SHRINK, + (*[(0, s[0]) for s in ordered_shape_strides], (0, 1)), + ), + (MovementOps.RESHAPE, tuple(s[0] for s in ordered_shape_strides)), + ] + ) + if order != list(range(len(order))): + to_apply.append( + ( + MovementOps.PERMUTE, + tuple(order.index(i) for i in range(len(strides))), + ) + ) + to_apply.append( + ( + MovementOps.RESHAPE, + tuple(s if st else 1 for s, st in zip(real_shape, v.strides)), + ) + ) + if any(i < 0 for i in v.strides): + to_apply.append( + (MovementOps.STRIDE, tuple(-1 if st < 0 else 1 for st in v.strides)) + ) + # then, we apply pre expand pads + if v.mask is not None: + pre_expand_pads = tuple( + (x, s - y) if st != 0 else (0, 0) + for (x, y), s, st in zip(v.mask, v.shape, v.strides) + ) + post_expand_pads = tuple( + (x, s - y) if st == 0 else (0, 0) + for (x, y), s, st in zip(v.mask, v.shape, v.strides) + ) + if any(x != (0, 0) for x in pre_expand_pads): + to_apply.append((MovementOps.PAD, pre_expand_pads)) + real_shape = tuple( + x + s[0] + s[1] for x, s in zip(real_shape, pre_expand_pads) + ) + # then, we do any expands + if any(s != 1 and st == 0 for s, st in zip(real_shape, v.strides)): + to_apply.append((MovementOps.EXPAND, real_shape)) + # lastly, we apply post expand pads + if v.mask is not None and any(x != (0, 0) for x in post_expand_pads): + to_apply.append((MovementOps.PAD, post_expand_pads)) + return to_apply + def get_real_view(shape, strides, offset, mask): - real_shape = tuple(y-x for x,y in mask) if mask else shape - offset = offset + sum(st * (s-1) for s,st in zip(real_shape, strides) if st<0) - real_offset = offset + (sum(x*st for (x,_),st in zip(mask, strides)) if mask else 0) - real_real_shape = [s for s,st in zip(real_shape, strides) if st] - strides = [abs(st) if isinstance(st,int) else st for st in strides if st] - return real_real_shape, strides, real_offset + real_shape = tuple(y - x for x, y in mask) if mask else shape + offset = offset + sum(st * (s - 1) for s, st in zip(real_shape, strides) if st < 0) + real_offset = offset + ( + sum(x * st for (x, _), st in zip(mask, strides)) if mask else 0 + ) + real_real_shape = [s for s, st in zip(real_shape, strides) if st] + strides = [abs(st) if isinstance(st, int) else st for st in strides if st] + return real_real_shape, strides, real_offset + def get_buffer_size(shape, strides, offset, mask): - real_real_shape, strides, real_offset = get_real_view(shape, strides, offset, mask) - return real_offset + sum((s-1)*st for s, st in zip(real_real_shape,strides)) + 1 + real_real_shape, strides, real_offset = get_real_view(shape, strides, offset, mask) + return ( + real_offset + sum((s - 1) * st for s, st in zip(real_real_shape, strides)) + 1 + ) + def st_equivalent(st1: ShapeTracker, st2: ShapeTracker): - if (idxs1:=st1.expr_idxs()) == (idxs2:=st2.expr_idxs()): return True - idx1, valid1 = idxs1 - idx2, valid2 = idxs2 - # always invalid - if valid1 == 0 and valid2 == 0: return True + if (idxs1 := st1.expr_idxs()) == (idxs2 := st2.expr_idxs()): + return True + idx1, valid1 = idxs1 + idx2, valid2 = idxs2 + # always invalid + if valid1 == 0 and valid2 == 0: + return True - var1 = idx1.vars() | valid1.vars() - var2 = idx2.vars() | valid2.vars() - # Maybe there are cases that vars are different yet the sts are the same? - if var1 != var2: return False + var1 = idx1.vars() | valid1.vars() + var2 = idx2.vars() | valid2.vars() + # Maybe there are cases that vars are different yet the sts are the same? + if var1 != var2: + return False - # brute force over the vars range - vs = list(var1) - for i, ranges in enumerate(itertools.product(*[range(v.min, v.max+1) for v in vs])): - if i > 1000: - print("WARNING: did not search all possible combinations") - # not happening for now - break - var_vals = {k:v for k,v in zip(vs, ranges)} - r1 = sym_infer(idx1, var_vals) if sym_infer(valid1, var_vals) else 0 - r2 = sym_infer(idx2, var_vals) if sym_infer(valid2, var_vals) else 0 - if r1 != r2: return False + # brute force over the vars range + vs = list(var1) + for i, ranges in enumerate( + itertools.product(*[range(v.min, v.max + 1) for v in vs]) + ): + if i > 1000: + print("WARNING: did not search all possible combinations") + # not happening for now + break + var_vals = {k: v for k, v in zip(vs, ranges)} + r1 = sym_infer(idx1, var_vals) if sym_infer(valid1, var_vals) else 0 + r2 = sym_infer(idx2, var_vals) if sym_infer(valid2, var_vals) else 0 + if r1 != r2: + return False + + return True - return True def test_rebuild(st: ShapeTracker): - rebuilt_st = ShapeTracker.from_shape((get_buffer_size(st.views[0].shape, st.views[0].strides, st.views[0].offset, st.views[0].mask),)) - for mop, arg in to_movement_ops(st): - if mop == MovementOps.RESHAPE: - # shapetracker doesn't allow flattening with -1 but required for MovementOps.RESHAPE - if arg == (-1,): - rebuilt_st = rebuilt_st.reshape((prod(rebuilt_st.views[-1].shape),)) - else: - rebuilt_st = rebuilt_st.reshape(arg) - elif mop == MovementOps.PERMUTE: - rebuilt_st = rebuilt_st.permute(arg) - elif mop == MovementOps.EXPAND: - if len(arg) != len(rebuilt_st.shape): - rebuilt_st = rebuilt_st.reshape((1,*rebuilt_st.shape)) - rebuilt_st = rebuilt_st.expand(arg) - elif mop == MovementOps.PAD: - rebuilt_st = rebuilt_st.pad(arg) - elif mop == MovementOps.SHRINK: - rebuilt_st = rebuilt_st.shrink(arg) - elif mop == MovementOps.STRIDE: - rebuilt_st = rebuilt_st.stride(arg) - else: - raise Exception("invalid mop") - rebuilt_st = rebuilt_st.simplify() - assert st_equivalent(st, rebuilt_st) - last_v1 = st.views[-1] - last_v2 = rebuilt_st.views[-1] - assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}" + rebuilt_st = ShapeTracker.from_shape( + ( + get_buffer_size( + st.views[0].shape, + st.views[0].strides, + st.views[0].offset, + st.views[0].mask, + ), + ) + ) + for mop, arg in to_movement_ops(st): + if mop == MovementOps.RESHAPE: + # shapetracker doesn't allow flattening with -1 but required for MovementOps.RESHAPE + if arg == (-1,): + rebuilt_st = rebuilt_st.reshape((prod(rebuilt_st.views[-1].shape),)) + else: + rebuilt_st = rebuilt_st.reshape(arg) + elif mop == MovementOps.PERMUTE: + rebuilt_st = rebuilt_st.permute(arg) + elif mop == MovementOps.EXPAND: + if len(arg) != len(rebuilt_st.shape): + rebuilt_st = rebuilt_st.reshape((1, *rebuilt_st.shape)) + rebuilt_st = rebuilt_st.expand(arg) + elif mop == MovementOps.PAD: + rebuilt_st = rebuilt_st.pad(arg) + elif mop == MovementOps.SHRINK: + rebuilt_st = rebuilt_st.shrink(arg) + elif mop == MovementOps.STRIDE: + rebuilt_st = rebuilt_st.stride(arg) + else: + raise Exception("invalid mop") + rebuilt_st = rebuilt_st.simplify() + assert st_equivalent(st, rebuilt_st) + last_v1 = st.views[-1] + last_v2 = rebuilt_st.views[-1] + assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}" -def test_interpret_ast(ast:LazyOp): - if ast.op in BufferOps: - test_rebuild(ast.arg.st) - else: - for src in ast.src: test_interpret_ast(src) + +def test_interpret_ast(ast: LazyOp): + if ast.op in BufferOps: + test_rebuild(ast.arg.st) + else: + for src in ast.src: + test_interpret_ast(src) if __name__ == "__main__": - ast_strs = load_worlds(False, False, True)[:4000] - for ast_str in tqdm(ast_strs): - test_interpret_ast(ast_str_to_ast(ast_str)) + ast_strs = load_worlds(False, False, True)[:4000] + for ast_str in tqdm(ast_strs): + test_interpret_ast(ast_str_to_ast(ast_str)) diff --git a/extra/training.py b/extra/training.py index c127ae6f3..c04436690 100644 --- a/extra/training.py +++ b/extra/training.py @@ -5,54 +5,74 @@ from tinygrad.helpers import CI from tinygrad.jit import TinyJit -def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y), - transform=lambda x: x, target_transform=lambda x: x, noloss=False): +def train( + model, + X_train, + Y_train, + optim, + steps, + BS=128, + lossfn=lambda out, y: out.sparse_categorical_crossentropy(y), + transform=lambda x: x, + target_transform=lambda x: x, + noloss=False, +): + @TinyJit + def train_step(x, y): + # network + out = model.forward(x) if hasattr(model, "forward") else model(x) + loss = lossfn(out, y) + optim.zero_grad() + loss.backward() + if noloss: + del loss + optim.step() + if noloss: + return (None, None) + cat = out.argmax(axis=-1) + accuracy = (cat == y).mean() + return loss.realize(), accuracy.realize() - @TinyJit - def train_step(x, y): - # network - out = model.forward(x) if hasattr(model, 'forward') else model(x) - loss = lossfn(out, y) - optim.zero_grad() - loss.backward() - if noloss: del loss - optim.step() - if noloss: return (None, None) - cat = out.argmax(axis=-1) - accuracy = (cat == y).mean() - return loss.realize(), accuracy.realize() - - with Tensor.train(): - losses, accuracies = [], [] - for i in (t := trange(steps, disable=CI)): - samp = np.random.randint(0, X_train.shape[0], size=(BS)) - x = Tensor(transform(X_train[samp]), requires_grad=False) - y = Tensor(target_transform(Y_train[samp])) - loss, accuracy = train_step(x, y) - # printing - if not noloss: - loss, accuracy = loss.numpy(), accuracy.numpy() - losses.append(loss) - accuracies.append(accuracy) - t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy)) - return [losses, accuracies] + with Tensor.train(): + losses, accuracies = [], [] + for i in (t := trange(steps, disable=CI)): + samp = np.random.randint(0, X_train.shape[0], size=(BS)) + x = Tensor(transform(X_train[samp]), requires_grad=False) + y = Tensor(target_transform(Y_train[samp])) + loss, accuracy = train_step(x, y) + # printing + if not noloss: + loss, accuracy = loss.numpy(), accuracy.numpy() + losses.append(loss) + accuracies.append(accuracy) + t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy)) + return [losses, accuracies] -def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False, transform=lambda x: x, - target_transform=lambda y: y): - Tensor.training = False - def numpy_eval(Y_test, num_classes): - Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes]) - for i in trange((len(Y_test)-1)//BS+1, disable=CI): - x = Tensor(transform(X_test[i*BS:(i+1)*BS])) - out = model.forward(x) if hasattr(model, 'forward') else model(x) - Y_test_preds_out[i*BS:(i+1)*BS] = out.numpy() - Y_test_preds = np.argmax(Y_test_preds_out, axis=-1) - Y_test = target_transform(Y_test) - return (Y_test == Y_test_preds).mean(), Y_test_preds +def evaluate( + model, + X_test, + Y_test, + num_classes=None, + BS=128, + return_predict=False, + transform=lambda x: x, + target_transform=lambda y: y, +): + Tensor.training = False - if num_classes is None: num_classes = Y_test.max().astype(int)+1 - acc, Y_test_pred = numpy_eval(Y_test, num_classes) - print("test set accuracy is %f" % acc) - return (acc, Y_test_pred) if return_predict else acc + def numpy_eval(Y_test, num_classes): + Y_test_preds_out = np.zeros(list(Y_test.shape) + [num_classes]) + for i in trange((len(Y_test) - 1) // BS + 1, disable=CI): + x = Tensor(transform(X_test[i * BS : (i + 1) * BS])) + out = model.forward(x) if hasattr(model, "forward") else model(x) + Y_test_preds_out[i * BS : (i + 1) * BS] = out.numpy() + Y_test_preds = np.argmax(Y_test_preds_out, axis=-1) + Y_test = target_transform(Y_test) + return (Y_test == Y_test_preds).mean(), Y_test_preds + if num_classes is None: + num_classes = Y_test.max().astype(int) + 1 + acc, Y_test_pred = numpy_eval(Y_test, num_classes) + print("test set accuracy is %f" % acc) + return (acc, Y_test_pred) if return_predict else acc diff --git a/extra/triton/triton.py b/extra/triton/triton.py index 32453782e..53aee518e 100644 --- a/extra/triton/triton.py +++ b/extra/triton/triton.py @@ -2,130 +2,247 @@ from typing import Dict, List, Final, Callable, DefaultDict from collections import defaultdict from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op from tinygrad.helpers import DType, dtypes, ImageDType, DEBUG, getenv -from tinygrad.codegen.linearizer import UOp, UOps +from tinygrad.codegen.linearizer import UOp, UOps from triton.compiler import compile as triton_compile import linecache import math import re -triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"} -signature_dtypes = {dtypes.double: "*fp64",dtypes.float32: "*fp32", dtypes.float16: "*fp16", dtypes.bool: "*i8", dtypes.int8: "*i1", dtypes.uint8: "*u8", dtypes._arg_int32: "i32", dtypes.int32: "*i32", dtypes.int64: "*i64", dtypes.uint32: "*u32", dtypes.uint64: "*u64", dtypes.int16: "*i16", dtypes.uint16: "*u16"} +triton_dtypes = { + dtypes.double: "tl.float64", + dtypes.float32: "tl.float32", + dtypes.float16: "tl.float16", + dtypes.bool: "tl.int1", + dtypes.int8: "tl.int8", + dtypes.uint8: "tl.uint8", + dtypes.int32: "tl.int32", + dtypes.int64: "tl.int64", + dtypes.uint32: "tl.uint32", + dtypes.uint64: "tl.uint64", + dtypes.int16: "tl.int16", + dtypes.uint16: "tl.uint16", +} +signature_dtypes = { + dtypes.double: "*fp64", + dtypes.float32: "*fp32", + dtypes.float16: "*fp16", + dtypes.bool: "*i8", + dtypes.int8: "*i1", + dtypes.uint8: "*u8", + dtypes._arg_int32: "i32", + dtypes.int32: "*i32", + dtypes.int64: "*i64", + dtypes.uint32: "*u32", + dtypes.uint64: "*u64", + dtypes.int16: "*i16", + dtypes.uint16: "*u16", +} + def next_power_of_2(x): - return 1 << (x - 1).bit_length() + return 1 << (x - 1).bit_length() + def render_valid(valid): - return '(' * (len(valid) -1) + ') and '.join(valid) if len(valid) else 'True' + return "(" * (len(valid) - 1) + ") and ".join(valid) if len(valid) else "True" -#NOTE Triton requires matching dimensions for load/store, disable this and see TestOps::test_output_padded_conv_transpose2d fail to compile + +# NOTE Triton requires matching dimensions for load/store, disable this and see TestOps::test_output_padded_conv_transpose2d fail to compile def fill_dims_for_idx(idx, dims): - return "(" + idx + "+ (" + (f"0*({'+'.join(d for d in dims)})))") if len(dims) else idx + return ( + "(" + idx + "+ (" + (f"0*({'+'.join(d for d in dims)})))") if len(dims) else idx + ) + def get_max(var): - if isinstance(var, int): return var - return re.sub(r'\[(.*?)\]', '', str(var))[1:-1] + if isinstance(var, int): + return var + return re.sub(r"\[(.*?)\]", "", str(var))[1:-1] -#NOTE can be removed after https://github.com/gpuocelot/gpuocelot/issues/8 gets resolved + +# NOTE can be removed after https://github.com/gpuocelot/gpuocelot/issues/8 gets resolved def remove_single_scalar_curly_braces(ptx_code): - return '\n'.join([re.sub(r'\{\s*(%\w+)\s*\}', r'\1', line) for line in ptx_code.split('\n')]) + return "\n".join( + [re.sub(r"\{\s*(%\w+)\s*\}", r"\1", line) for line in ptx_code.split("\n")] + ) -def render_const(args,dtype:DType): - return (('-' if args<0 else '') + 'tl.where(1,float("inf"),0)') if math.isinf(args) else ('tl.where(1,float("nan"),0)' if math.isnan(args) else f"{int(args)}" if dtypes.is_int(dtype) else str(args)) -def render_cast(x:str, dtype:DType): - return f"{x}.to({triton_dtypes[dtype]})" +def render_const(args, dtype: DType): + return ( + (("-" if args < 0 else "") + 'tl.where(1,float("inf"),0)') + if math.isinf(args) + else ( + 'tl.where(1,float("nan"),0)' + if math.isnan(args) + else f"{int(args)}" + if dtypes.is_int(dtype) + else str(args) + ) + ) + + +def render_cast(x: str, dtype: DType): + return f"{x}.to({triton_dtypes[dtype]})" + def define_scalar(local_size, dtype, args): - if len(local_size) > 0: return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})" - return render_const(args,dtype) + if len(local_size) > 0: + return f"tl.full(({','.join([str(next_power_of_2(x)) for x in local_size])},),{render_const(args,dtype)}, dtype={triton_dtypes[dtype]})" + return render_const(args, dtype) -def uops_to_triton(function_name:str, uops:List[UOp]): - local_size: List[int] = [] - depth = 1 - signatures, dims, bufs, kernel, valid = [], [], [], [], [] #type: ignore - c: DefaultDict[str, int] = defaultdict(int) - r: Dict[UOp, str] = {} - def ssa(u, prefix="t"): - nonlocal c, r - c[prefix] += 1 - r[u]=f"{prefix}{c[prefix]-1}" - return r[u] +def uops_to_triton(function_name: str, uops: List[UOp]): + local_size: List[int] = [] + depth = 1 + signatures, dims, bufs, kernel, valid = [], [], [], [], [] # type: ignore - child_count: DefaultDict[UOp, int] = defaultdict(int) - for ru in uops: - for v in ru.vin: - child_count[v] += 1 + c: DefaultDict[str, int] = defaultdict(int) + r: Dict[UOp, str] = {} - def kk(s): kernel.append(" "*depth+s) - code_for_op: Final[Dict[Op, Callable]] = { - UnaryOps.EXP2: lambda x,dtype,: f"tl.math.exp2({x})", - UnaryOps.LOG2: lambda x,dtype,: f"tl.math.log2({x})", - UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})", - UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})", - UnaryOps.NEG: lambda x,dtype: f"-{x}" if dtype != dtypes.bool else f"tl.where({x}, 0, 1)", - BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})", - BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))", - BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})", - BinaryOps.CMPLT: lambda x,y,dtype: f"({x}<{y})", - BinaryOps.MOD: lambda x,y,dtype: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)", - TernaryOps.MULACC: lambda x,y,z,dtype: f"(({x}*{y})+{z})", - TernaryOps.WHERE: lambda x,y,z,dtype: f"tl.where({x},{y},{z})", - } - def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))" - for u in uops: - uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg - if uop == UOps.LOOP: - kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):") - depth += 1 - elif uop == UOps.END: depth -= 1 - elif uop == UOps.ALU: - assert dtype is not None - val = code_for_op[args](*[r[x] for x in vin]) - if child_count[u] <=1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val - else: kk(f"{ssa(u, 'alu')} = ({val})") - elif uop == UOps.LOAD: - assert dtype is not None - if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}") - else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}") - elif uop == UOps.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}") - elif uop == UOps.CONST: r[u] = define_scalar([], dtype, args) - elif uop == UOps.PHI: - kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}") - r[u] = r[vin[0]] - elif uop == UOps.STORE: - assert not isinstance(dtype, ImageDType), "unimplemented: image store" - kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ") - elif uop == UOps.DEFINE_GLOBAL: - bufs.append(args) - signatures.append(signature_dtypes[args[1]]) - r[u] = args[0] - elif uop == UOps.SPECIAL: - dims.append(args[1]) - valid.append(f"{args[1]}<{get_max(args[2])}") - if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}") - elif args[1].startswith("l"): - kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})") - local_size.append(args[2]) - r[u] = args[1] - elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype) - else: raise NotImplementedError(f"unimplemented: {uop}") + def ssa(u, prefix="t"): + nonlocal c, r + c[prefix] += 1 + r[u] = f"{prefix}{c[prefix]-1}" + return r[u] - prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(f"{buf[0]}" for buf in bufs)+"):\n" - for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]" - prg += "\n".join(kernel) + child_count: DefaultDict[UOp, int] = defaultdict(int) + for ru in uops: + for v in ru.vin: + child_count[v] += 1 - acc_local_size = 1 - for x in local_size: acc_local_size *= next_power_of_2(x) - local_size = [acc_local_size] + [1] * (len(local_size) - 1) + def kk(s): + kernel.append(" " * depth + s) - if DEBUG >= 4: print(prg) - getlines = linecache.getlines - linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "" == filename else getlines(filename, module_globals) - exec(compile(prg, "", "exec"), globals()) # pylint: disable=W0122\ - compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None)) - prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0]) - max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")] - for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i]) + code_for_op: Final[Dict[Op, Callable]] = { + UnaryOps.EXP2: lambda x, dtype,: f"tl.math.exp2({x})", + UnaryOps.LOG2: lambda x, dtype,: f"tl.math.log2({x})", + UnaryOps.SIN: lambda x, dtype: f"tl.sin({x})", + UnaryOps.SQRT: lambda x, dtype: f"tl.sqrt({x})", + UnaryOps.NEG: lambda x, dtype: f"-{x}" + if dtype != dtypes.bool + else f"tl.where({x}, 0, 1)", + BinaryOps.ADD: lambda x, y, dtype: f"({x}+{y})", + BinaryOps.SUB: lambda x, y,: f"({x}-{y})", + BinaryOps.MUL: lambda x, y, dtype: f"({x}*{y})", + BinaryOps.DIV: lambda x, y,: f"({x}/{y})" + if y != "0.0" + else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))", + BinaryOps.MAX: lambda x, y, dtype: f"tl.maximum({x},{y})", + BinaryOps.CMPLT: lambda x, y, dtype: f"({x}<{y})", + BinaryOps.MOD: lambda x, y, dtype: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)", + TernaryOps.MULACC: lambda x, y, z, dtype: f"(({x}*{y})+{z})", + TernaryOps.WHERE: lambda x, y, z, dtype: f"tl.where({x},{y},{z})", + } - return prg, {"shared":compiled.metadata["shared"], "local_size":local_size + [1]*(3-len(local_size))} + def int_div(x, y): + return ( + f"({x}//{y})" + if y != "0" + else f"{x}*tl.where({x}==0, float('nan'), float('inf'))" + ) + + for u in uops: + uop, dtype, vin, args = u.uop, u.dtype, u.vin, u.arg + if uop == UOps.LOOP: + kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):") + depth += 1 + elif uop == UOps.END: + depth -= 1 + elif uop == UOps.ALU: + assert dtype is not None + val = code_for_op[args](*[r[x] for x in vin]) + if child_count[u] <= 1 or dtypes.is_int(dtype): + r[u] = ( + int_div(*[r[x] for x in vin]) + if args == BinaryOps.DIV and dtypes.is_int(dtype) + else val + ) + else: + kk(f"{ssa(u, 'alu')} = ({val})") + elif uop == UOps.LOAD: + assert dtype is not None + if len(vin) == 2: + kk( + f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}" + ) + else: + kk( + f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}" + ) + elif uop == UOps.DEFINE_ACC: + kk( + f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}" + ) + elif uop == UOps.CONST: + r[u] = define_scalar([], dtype, args) + elif uop == UOps.PHI: + kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}") + r[u] = r[vin[0]] + elif uop == UOps.STORE: + assert not isinstance(dtype, ImageDType), "unimplemented: image store" + kk( + f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) " + ) + elif uop == UOps.DEFINE_GLOBAL: + bufs.append(args) + signatures.append(signature_dtypes[args[1]]) + r[u] = args[0] + elif uop == UOps.SPECIAL: + dims.append(args[1]) + valid.append(f"{args[1]}<{get_max(args[2])}") + if args[1].startswith("g"): + kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}") + elif args[1].startswith("l"): + kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})") + local_size.append(args[2]) + r[u] = args[1] + elif uop == UOps.CAST and dtype is not None: + r[u] = render_cast(r[vin[0]], dtype) + else: + raise NotImplementedError(f"unimplemented: {uop}") + + prg = ( + f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}(" + + ",".join(f"{buf[0]}" for buf in bufs) + + "):\n" + ) + for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): + kernel[ + kernel.index(line) + ] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]" + prg += "\n".join(kernel) + + acc_local_size = 1 + for x in local_size: + acc_local_size *= next_power_of_2(x) + local_size = [acc_local_size] + [1] * (len(local_size) - 1) + + if DEBUG >= 4: + print(prg) + getlines = linecache.getlines + linecache.getlines = ( + lambda filename, module_globals=None: prg.splitlines(keepends=True) + if "" == filename + else getlines(filename, module_globals) + ) + exec(compile(prg, "", "exec"), globals()) # pylint: disable=W0122\ + compiled = triton_compile( + globals()[function_name], + signature=",".join(signatures), + device_type="cuda", + debug=False, + cc=(35 if getenv("CUDACPU", 0) else None), + ) + prg = remove_single_scalar_curly_braces( + compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0] + ) + max_local_size = [ + int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ") + ] + for i in range(len(local_size)): + local_size[i] = min(local_size[i], max_local_size[i]) + + return prg, { + "shared": compiled.metadata["shared"], + "local_size": local_size + [1] * (3 - len(local_size)), + } diff --git a/openpilot/compile2.py b/openpilot/compile2.py index 4e8631aa9..4382fcf21 100644 --- a/openpilot/compile2.py +++ b/openpilot/compile2.py @@ -1,11 +1,16 @@ #!/usr/bin/env python3 import os, sys, io, pathlib, re + sys.path.insert(0, str(pathlib.Path(__file__).parents[1])) -if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1" -if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2" -if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1" -if "OPT" not in os.environ: os.environ["OPT"] = "99" +if "FLOAT16" not in os.environ: + os.environ["FLOAT16"] = "1" +if "IMAGE" not in os.environ: + os.environ["IMAGE"] = "2" +if "NOLOCALS" not in os.environ: + os.environ["NOLOCALS"] = "1" +if "OPT" not in os.environ: + os.environ["OPT"] = "99" OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx" @@ -15,121 +20,172 @@ from typing import Tuple, List, Optional, Dict from extra.onnx import get_run_onnx from tinygrad.graph import log_schedule_item from tinygrad import Tensor, Device -from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, fetch, getenv, ImageDType, GRAPH, DEBUG +from tinygrad.helpers import ( + dtypes, + partition, + GlobalCounters, + Context, + fetch, + getenv, + ImageDType, + GRAPH, + DEBUG, +) from tinygrad.realize import run_schedule, lower_schedule_item from tinygrad.ops import LoadOps, ScheduleItem + Device.DEFAULT = "GPU" + def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]: - Tensor.no_grad = True - Tensor.training = False + Tensor.no_grad = True + Tensor.training = False - # load the model - onnx_model = onnx.load(io.BytesIO(onnx_data)) - run_onnx = get_run_onnx(onnx_model) - input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} - - # run the model - inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()} - ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous() - schedule = ret.lazydata.schedule() - - # filter schedule that don't depend on the inputs - input_lb = [x.lazydata.base for x in inputs.values()] - depends = set(input_lb) - for si in schedule: - if any(b in depends for b in si.inputs): - depends.add(si.out) - - # run all kernels that don't depend on the inputs - # NOTE: there's two extra kernels due to fusions that now happen since the weights aren't realized - schedule, schedule_independent = partition(schedule, lambda si: si.out in depends) - print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't") - - # confirm no loadops in the (non independent) schedule except for the ones that load the input buffers - assert all(si.ast.op not in LoadOps or si.out in input_lb for si in schedule), "has loadops, can't compile to Thneed" - return schedule, schedule_independent, inputs - -def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[str, Tensor]): - import onnx - #import pyopencl as cl - #from extra.thneed import Thneed - import numpy as np - onnx_model = onnx.load(io.BytesIO(onnx_data)) - - input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} - Tensor.manual_seed(1337) - new_inputs = {k:Tensor.randn(*shp, requires_grad=False)*8 for k,shp in input_shapes.items()} - new_np_inputs = {k:v.realize().numpy() for k,v in new_inputs.items()} - - if getenv("ORT"): - # test with onnxruntime - import onnxruntime as ort - onnx_session = ort.InferenceSession(onnx_data) - onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_np_inputs.items()}) - new_torch_out = onnx_output[0] - print("got ort outputs") - else: - # test with torch - from test.models.test_onnx import run_onnx_torch - new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy() - print("got torch outputs") - - # if you don't have a schedule - if schedule is None: + # load the model + onnx_model = onnx.load(io.BytesIO(onnx_data)) run_onnx = get_run_onnx(onnx_model) - new_tinygrad_out = next(iter(run_onnx(new_inputs).values())).cast(dtypes.float32).numpy() - np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2) - print("classic self-test passed!") - return + input_shapes = { + inp.name: tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) + for inp in onnx_model.graph.input + } - # set inputs - for k,v in inputs.items(): v.lazydata.realized.copyin(new_np_inputs[k].data) + # run the model + inputs = {k: Tensor.empty(*shp) for k, shp in input_shapes.items()} + ret: Tensor = ( + next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous() + ) + schedule = ret.lazydata.schedule() - # run code (all buffers have been allocated) - GlobalCounters.reset() - for si in schedule: lower_schedule_item(si)([si.out.realized] + [x.realized for x in si.inputs], {}) + # filter schedule that don't depend on the inputs + input_lb = [x.lazydata.base for x in inputs.values()] + depends = set(input_lb) + for si in schedule: + if any(b in depends for b in si.inputs): + depends.add(si.out) + + # run all kernels that don't depend on the inputs + # NOTE: there's two extra kernels due to fusions that now happen since the weights aren't realized + schedule, schedule_independent = partition(schedule, lambda si: si.out in depends) + print( + f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't" + ) + + # confirm no loadops in the (non independent) schedule except for the ones that load the input buffers + assert all( + si.ast.op not in LoadOps or si.out in input_lb for si in schedule + ), "has loadops, can't compile to Thneed" + return schedule, schedule_independent, inputs + + +def test_vs_onnx( + onnx_data, schedule: Optional[List[ScheduleItem]], inputs: Dict[str, Tensor] +): + import onnx + + # import pyopencl as cl + # from extra.thneed import Thneed + import numpy as np + + onnx_model = onnx.load(io.BytesIO(onnx_data)) + + input_shapes = { + inp.name: tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) + for inp in onnx_model.graph.input + } + Tensor.manual_seed(1337) + new_inputs = { + k: Tensor.randn(*shp, requires_grad=False) * 8 + for k, shp in input_shapes.items() + } + new_np_inputs = {k: v.realize().numpy() for k, v in new_inputs.items()} + + if getenv("ORT"): + # test with onnxruntime + import onnxruntime as ort + + onnx_session = ort.InferenceSession(onnx_data) + onnx_output = onnx_session.run( + [onnx_model.graph.output[0].name], + {k: v.astype(np.float16) for k, v in new_np_inputs.items()}, + ) + new_torch_out = onnx_output[0] + print("got ort outputs") + else: + # test with torch + from test.models.test_onnx import run_onnx_torch + + new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy() + print("got torch outputs") + + # if you don't have a schedule + if schedule is None: + run_onnx = get_run_onnx(onnx_model) + new_tinygrad_out = ( + next(iter(run_onnx(new_inputs).values())).cast(dtypes.float32).numpy() + ) + np.testing.assert_allclose( + new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2 + ) + print("classic self-test passed!") + return + + # set inputs + for k, v in inputs.items(): + v.lazydata.realized.copyin(new_np_inputs[k].data) + + # run code (all buffers have been allocated) + GlobalCounters.reset() + for si in schedule: + lower_schedule_item(si)([si.out.realized] + [x.realized for x in si.inputs], {}) + + new_tinygrad_out = schedule[-1].out.realized.toCPU() + np.testing.assert_allclose( + new_torch_out.flatten(), new_tinygrad_out, atol=1e-4, rtol=1e-2 + ) + print("semi-thneed self-test passed!") - new_tinygrad_out = schedule[-1].out.realized.toCPU() - np.testing.assert_allclose(new_torch_out.flatten(), new_tinygrad_out, atol=1e-4, rtol=1e-2) - print("semi-thneed self-test passed!") if __name__ == "__main__": - onnx_data = fetch(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL).read_bytes() + onnx_data = fetch( + sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL + ).read_bytes() - # quick test for ONNX issues - #thneed_test_onnx(onnx_data, None) - #exit(0) + # quick test for ONNX issues + # thneed_test_onnx(onnx_data, None) + # exit(0) - schedule, schedule_independent, inputs = get_schedule(onnx_data) - schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps) - print(f"{len(schedule_input)} inputs") + schedule, schedule_independent, inputs = get_schedule(onnx_data) + schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps) + print(f"{len(schedule_input)} inputs") - run_schedule(schedule_independent, disable_logging=True) - run_schedule(schedule_input) - with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")): - image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule) - print(f"**** running real kernels {image_count}/{len(schedule)} images ****") + run_schedule(schedule_independent, disable_logging=True) + run_schedule(schedule_input) + with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")): + image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule) + print(f"**** running real kernels {image_count}/{len(schedule)} images ****") - if GRAPH: - for si in schedule_input: log_schedule_item(si) - for si in schedule: log_schedule_item(si) + if GRAPH: + for si in schedule_input: + log_schedule_item(si) + for si in schedule: + log_schedule_item(si) - GlobalCounters.reset() - run_schedule(schedule[:]) + GlobalCounters.reset() + run_schedule(schedule[:]) - print("kernel count:", len(schedule)) - assert len(schedule) <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!" - - # TODO: thneed is broken - #output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed" - #schedule_to_thneed(schedule, output_fn) - - FLOAT16 = getenv("FLOAT16", 0) - if FLOAT16 == 0: - try: - test_vs_onnx(onnx_data, schedule, inputs) - except ModuleNotFoundError as e: - print(f"TEST NOT HAPPENING {e}") + print("kernel count:", len(schedule)) + assert ( + len(schedule) <= getenv("ALLOWED_KERNEL_COUNT", 0) + or getenv("ALLOWED_KERNEL_COUNT", 0) == 0 + ), "too many kernels!" + # TODO: thneed is broken + # output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed" + # schedule_to_thneed(schedule, output_fn) + FLOAT16 = getenv("FLOAT16", 0) + if FLOAT16 == 0: + try: + test_vs_onnx(onnx_data, schedule, inputs) + except ModuleNotFoundError as e: + print(f"TEST NOT HAPPENING {e}") diff --git a/setup.py b/setup.py index 6feb3aa4f..1f6f97046 100644 --- a/setup.py +++ b/setup.py @@ -4,31 +4,44 @@ from pathlib import Path from setuptools import setup directory = Path(__file__).resolve().parent -with open(directory / 'README.md', encoding='utf-8') as f: - long_description = f.read() +with open(directory / "README.md", encoding="utf-8") as f: + long_description = f.read() -setup(name='tinygrad', - version='0.8.0', - description='You like pytorch? You like micrograd? You love tinygrad! <3', - author='George Hotz', - license='MIT', - long_description=long_description, - long_description_content_type='text/markdown', - packages = ['tinygrad', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.renderer', 'tinygrad.runtime', 'tinygrad.shape', 'tinygrad.features'], - classifiers=[ +setup( + name="tinygrad", + version="0.8.0", + description="You like pytorch? You like micrograd? You love tinygrad! <3", + author="George Hotz", + license="MIT", + long_description=long_description, + long_description_content_type="text/markdown", + packages=[ + "tinygrad", + "tinygrad.codegen", + "tinygrad.nn", + "tinygrad.renderer", + "tinygrad.runtime", + "tinygrad.shape", + "tinygrad.features", + ], + classifiers=[ "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License" - ], - install_requires=["numpy", "tqdm", "gpuctypes", - "pyobjc-framework-Metal; platform_system=='Darwin'", - "pyobjc-framework-libdispatch; platform_system=='Darwin'"], - python_requires='>=3.8', - extras_require={ - 'llvm': ["llvmlite"], - 'arm': ["unicorn"], - 'triton': ["triton-nightly>=2.1.0.dev20231014192330"], - 'webgpu': ["wgpu>=v0.12.0"], - 'linting': [ + "License :: OSI Approved :: MIT License", + ], + install_requires=[ + "numpy", + "tqdm", + "gpuctypes", + "pyobjc-framework-Metal; platform_system=='Darwin'", + "pyobjc-framework-libdispatch; platform_system=='Darwin'", + ], + python_requires=">=3.8", + extras_require={ + "llvm": ["llvmlite"], + "arm": ["unicorn"], + "triton": ["triton-nightly>=2.1.0.dev20231014192330"], + "webgpu": ["wgpu>=v0.12.0"], + "linting": [ "pylint", "mypy", "typing-extensions", @@ -36,7 +49,7 @@ setup(name='tinygrad', "ruff", "types-tqdm", ], - 'testing': [ + "testing": [ "torch", "pillow", "pytest", @@ -52,6 +65,7 @@ setup(name='tinygrad', "tiktoken", "librosa", "networkx", - ] - }, - include_package_data=True) + ], + }, + include_package_data=True, +) diff --git a/sz.py b/sz.py index b64de6564..3cf37f7f5 100755 --- a/sz.py +++ b/sz.py @@ -7,62 +7,117 @@ from tabulate import tabulate TOKEN_WHITELIST = [token.OP, token.NAME, token.NUMBER, token.STRING] + def gen_stats(base_path="."): - table = [] - for path, _, files in os.walk(os.path.join(base_path, "tinygrad")): - for name in files: - if not name.endswith(".py"): continue - filepath = os.path.join(path, name) - relfilepath = os.path.relpath(filepath, base_path) - with tokenize.open(filepath) as file_: - tokens = [t for t in tokenize.generate_tokens(file_.readline) if t.type in TOKEN_WHITELIST] - token_count, line_count = len(tokens), len(set([t.start[0] for t in tokens])) - table.append([relfilepath, line_count, token_count/line_count]) - return table + table = [] + for path, _, files in os.walk(os.path.join(base_path, "tinygrad")): + for name in files: + if not name.endswith(".py"): + continue + filepath = os.path.join(path, name) + relfilepath = os.path.relpath(filepath, base_path) + with tokenize.open(filepath) as file_: + tokens = [ + t + for t in tokenize.generate_tokens(file_.readline) + if t.type in TOKEN_WHITELIST + ] + token_count, line_count = len(tokens), len( + set([t.start[0] for t in tokens]) + ) + table.append([relfilepath, line_count, token_count / line_count]) + return table + def gen_diff(table_old, table_new): - table = [] - files_new = set([x[0] for x in table_new]) - files_old = set([x[0] for x in table_old]) - added, deleted, unchanged = files_new - files_old, files_old - files_new, files_new & files_old - if added: - for file in added: - file_stat = [stats for stats in table_new if file in stats] - table.append([file_stat[0][0], file_stat[0][1], file_stat[0][1]-0, file_stat[0][2], file_stat[0][2]-0]) - if deleted: - for file in deleted: - file_stat = [stats for stats in table_old if file in stats] - table.append([file_stat[0][0], 0, 0 - file_stat[0][1], 0, 0-file_stat[0][2]]) - if unchanged: - for file in unchanged: - file_stat_old = [stats for stats in table_old if file in stats] - file_stat_new = [stats for stats in table_new if file in stats] - if file_stat_new[0][1]-file_stat_old[0][1] != 0 or file_stat_new[0][2]-file_stat_old[0][2] != 0: - table.append([file_stat_new[0][0], file_stat_new[0][1], file_stat_new[0][1]-file_stat_old[0][1], file_stat_new[0][2], file_stat_new[0][2]-file_stat_old[0][2]]) - return table + table = [] + files_new = set([x[0] for x in table_new]) + files_old = set([x[0] for x in table_old]) + added, deleted, unchanged = ( + files_new - files_old, + files_old - files_new, + files_new & files_old, + ) + if added: + for file in added: + file_stat = [stats for stats in table_new if file in stats] + table.append( + [ + file_stat[0][0], + file_stat[0][1], + file_stat[0][1] - 0, + file_stat[0][2], + file_stat[0][2] - 0, + ] + ) + if deleted: + for file in deleted: + file_stat = [stats for stats in table_old if file in stats] + table.append( + [file_stat[0][0], 0, 0 - file_stat[0][1], 0, 0 - file_stat[0][2]] + ) + if unchanged: + for file in unchanged: + file_stat_old = [stats for stats in table_old if file in stats] + file_stat_new = [stats for stats in table_new if file in stats] + if ( + file_stat_new[0][1] - file_stat_old[0][1] != 0 + or file_stat_new[0][2] - file_stat_old[0][2] != 0 + ): + table.append( + [ + file_stat_new[0][0], + file_stat_new[0][1], + file_stat_new[0][1] - file_stat_old[0][1], + file_stat_new[0][2], + file_stat_new[0][2] - file_stat_old[0][2], + ] + ) + return table + + +def display_diff(diff): + return "+" + str(diff) if diff > 0 else str(diff) -def display_diff(diff): return "+"+str(diff) if diff > 0 else str(diff) if __name__ == "__main__": - if len(sys.argv) == 3: - headers = ["Name", "Lines", "Diff", "Tokens/Line", "Diff"] - table = gen_diff(gen_stats(sys.argv[1]), gen_stats(sys.argv[2])) - elif len(sys.argv) == 2: - headers = ["Name", "Lines", "Tokens/Line"] - table = gen_stats(sys.argv[1]) - else: - headers = ["Name", "Lines", "Tokens/Line"] - table = gen_stats(".") - - if table: if len(sys.argv) == 3: - print("### Changes") - print("```") - print(tabulate([headers] + sorted(table, key=lambda x: -x[1]), headers="firstrow", intfmt=(..., "d", "+d"), floatfmt=(..., ..., ..., ".1f", "+.1f"))+"\n") - print(f"\ntotal lines changes: {display_diff(sum([x[2] for x in table]))}") - print("```") + headers = ["Name", "Lines", "Diff", "Tokens/Line", "Diff"] + table = gen_diff(gen_stats(sys.argv[1]), gen_stats(sys.argv[2])) + elif len(sys.argv) == 2: + headers = ["Name", "Lines", "Tokens/Line"] + table = gen_stats(sys.argv[1]) else: - print(tabulate([headers] + sorted(table, key=lambda x: -x[1]), headers="firstrow", floatfmt=".1f")+"\n") - for dir_name, group in itertools.groupby(sorted([(x[0].rsplit("/", 1)[0], x[1], x[2]) for x in table]), key=lambda x:x[0]): - print(f"{dir_name:30s} : {sum([x[1] for x in group]):6d}") - print(f"\ntotal line count: {sum([x[1] for x in table])}") + headers = ["Name", "Lines", "Tokens/Line"] + table = gen_stats(".") + + if table: + if len(sys.argv) == 3: + print("### Changes") + print("```") + print( + tabulate( + [headers] + sorted(table, key=lambda x: -x[1]), + headers="firstrow", + intfmt=(..., "d", "+d"), + floatfmt=(..., ..., ..., ".1f", "+.1f"), + ) + + "\n" + ) + print(f"\ntotal lines changes: {display_diff(sum([x[2] for x in table]))}") + print("```") + else: + print( + tabulate( + [headers] + sorted(table, key=lambda x: -x[1]), + headers="firstrow", + floatfmt=".1f", + ) + + "\n" + ) + for dir_name, group in itertools.groupby( + sorted([(x[0].rsplit("/", 1)[0], x[1], x[2]) for x in table]), + key=lambda x: x[0], + ): + print(f"{dir_name:30s} : {sum([x[1] for x in group]):6d}") + print(f"\ntotal line count: {sum([x[1] for x in table])}") diff --git a/test/external/dist/test_collectives.py b/test/external/dist/test_collectives.py index 843bc4947..9f821b272 100644 --- a/test/external/dist/test_collectives.py +++ b/test/external/dist/test_collectives.py @@ -1,61 +1,82 @@ from extra import dist from tinygrad.jit import TinyJit + if __name__ == "__main__": - dist.preinit() + dist.preinit() from extra.dist import collectives from tinygrad.helpers import CI, getenv from tinygrad.tensor import Tensor import numpy as np + @TinyJit -def allreduce_jit(t:Tensor) -> Tensor: - return collectives.allreduce(t).realize() +def allreduce_jit(t: Tensor) -> Tensor: + return collectives.allreduce(t).realize() + SIZE = 2048 if not CI else 2 SIZE_2 = 255 if not CI else 3 + def run(): - # set a deterministic seed so that both ranks generate the same random tensor - Tensor.manual_seed(42) + # set a deterministic seed so that both ranks generate the same random tensor + Tensor.manual_seed(42) - rank = getenv("RANK") + rank = getenv("RANK") - # loop 3 times to make sure it works with the jit - for _ in range(3): - # create a tensor to send - t = Tensor.zeros(SIZE, SIZE) if rank != 0 else Tensor.ones(SIZE, SIZE) - t2 = allreduce_jit(t.contiguous().realize()) - assert np.allclose(np.ones((SIZE, SIZE)), t2.numpy()), f"{t2.numpy()} wasn't ones" + # loop 3 times to make sure it works with the jit + for _ in range(3): + # create a tensor to send + t = Tensor.zeros(SIZE, SIZE) if rank != 0 else Tensor.ones(SIZE, SIZE) + t2 = allreduce_jit(t.contiguous().realize()) + assert np.allclose( + np.ones((SIZE, SIZE)), t2.numpy() + ), f"{t2.numpy()} wasn't ones" - # reset jit - allreduce_jit.cnt = 0 + # reset jit + allreduce_jit.cnt = 0 - # test uneven chunk sizes - for _ in range(3): - # create a tensor to send - t = Tensor.ones(SIZE_2, SIZE_2, SIZE_2) if rank == 0 else Tensor.zeros(SIZE_2, SIZE_2, SIZE_2) - t2 = allreduce_jit(t.contiguous().realize()) - assert np.allclose(np.ones((SIZE_2, SIZE_2, SIZE_2)), t2.numpy()), f"{t2.numpy()} wasn't ones" + # test uneven chunk sizes + for _ in range(3): + # create a tensor to send + t = ( + Tensor.ones(SIZE_2, SIZE_2, SIZE_2) + if rank == 0 + else Tensor.zeros(SIZE_2, SIZE_2, SIZE_2) + ) + t2 = allreduce_jit(t.contiguous().realize()) + assert np.allclose( + np.ones((SIZE_2, SIZE_2, SIZE_2)), t2.numpy() + ), f"{t2.numpy()} wasn't ones" + + print(f"rank {rank} passed") - print(f"rank {rank} passed") if __name__ == "__main__": - if getenv("HIP"): - from tinygrad.runtime.ops_hip import HIP - devices = [f"hip:{i}" for i in range(HIP.device_count)] - else: - from tinygrad.runtime.ops_gpu import CL - devices = [f"gpu:{i}" for i in range(len(CL.devices))] if not CI else ["gpu:0", "gpu:0"] - world_size = len(devices) + if getenv("HIP"): + from tinygrad.runtime.ops_hip import HIP - dist.init_oob(world_size) + devices = [f"hip:{i}" for i in range(HIP.device_count)] + else: + from tinygrad.runtime.ops_gpu import CL - processes = [] - for rank, device in enumerate(devices): - processes.append(dist.spawn(rank, device, fn=run, args=())) - for p in processes: p.join() + devices = ( + [f"gpu:{i}" for i in range(len(CL.devices))] + if not CI + else ["gpu:0", "gpu:0"] + ) + world_size = len(devices) - # exit with error code if any of the processes failed - for p in processes: - if p.exitcode != 0: exit(p.exitcode) + dist.init_oob(world_size) + + processes = [] + for rank, device in enumerate(devices): + processes.append(dist.spawn(rank, device, fn=run, args=())) + for p in processes: + p.join() + + # exit with error code if any of the processes failed + for p in processes: + if p.exitcode != 0: + exit(p.exitcode) diff --git a/test/external/dist/test_world.py b/test/external/dist/test_world.py index de0525da4..a5a4633fe 100644 --- a/test/external/dist/test_world.py +++ b/test/external/dist/test_world.py @@ -1,68 +1,78 @@ from extra import dist from tinygrad.jit import TinyJit + if __name__ == "__main__": - dist.preinit() + dist.preinit() from extra.dist import world from tinygrad.helpers import CI, getenv from tinygrad.tensor import Tensor import numpy as np + @TinyJit def send_jit(t, target_rank) -> Tensor: - return world.send(t, target_rank).realize() + return world.send(t, target_rank).realize() + @TinyJit def recv_jit(t, target_rank) -> Tensor: - return world.recv(t, target_rank).realize() + return world.recv(t, target_rank).realize() + SIZE = 2048 if not CI else 2 + def run(): - # set a deterministic seed so that both ranks generate the same random tensor - Tensor.manual_seed(42) + # set a deterministic seed so that both ranks generate the same random tensor + Tensor.manual_seed(42) - rank = getenv("RANK") + rank = getenv("RANK") - # loop 3 times to make sure it works with the jit - for _ in range(3): - # create a tensor to send - t = Tensor.randn(SIZE, SIZE) + # loop 3 times to make sure it works with the jit + for _ in range(3): + # create a tensor to send + t = Tensor.randn(SIZE, SIZE) - # send to rank 1 - if rank == 0: - send_jit(t, 1) - elif rank == 1: - t2 = Tensor.empty(SIZE, SIZE) - recv_jit(t2, 0) + # send to rank 1 + if rank == 0: + send_jit(t, 1) + elif rank == 1: + t2 = Tensor.empty(SIZE, SIZE) + recv_jit(t2, 0) - # recv from rank 1 - if rank == 0: - t2 = Tensor.empty(SIZE, SIZE) - recv_jit(t2, 1) - elif rank == 1: - send_jit(t2, 0) + # recv from rank 1 + if rank == 0: + t2 = Tensor.empty(SIZE, SIZE) + recv_jit(t2, 1) + elif rank == 1: + send_jit(t2, 0) - # check that the received tensor is the same as the sent tensor - if rank == 0: - assert np.allclose(t.numpy(), t2.numpy()), f"{t2.numpy()} wasn't equal to {t.numpy()}" + # check that the received tensor is the same as the sent tensor + if rank == 0: + assert np.allclose( + t.numpy(), t2.numpy() + ), f"{t2.numpy()} wasn't equal to {t.numpy()}" + + print(f"rank {rank} passed") - print(f"rank {rank} passed") if __name__ == "__main__": - if getenv("HIP"): - devices = ["hip:0", "hip:1"] - else: - devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"] - world_size = len(devices) + if getenv("HIP"): + devices = ["hip:0", "hip:1"] + else: + devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"] + world_size = len(devices) - dist.init_oob(world_size) + dist.init_oob(world_size) - processes = [] - for rank, device in enumerate(devices): - processes.append(dist.spawn(rank, device, fn=run, args=())) - for p in processes: p.join() + processes = [] + for rank, device in enumerate(devices): + processes.append(dist.spawn(rank, device, fn=run, args=())) + for p in processes: + p.join() - # exit with error code if any of the processes failed - for p in processes: - if p.exitcode != 0: exit(p.exitcode) + # exit with error code if any of the processes failed + for p in processes: + if p.exitcode != 0: + exit(p.exitcode) diff --git a/test/external/external_benchmark_hip_compile.py b/test/external/external_benchmark_hip_compile.py index 2b1d48034..d27d5b436 100644 --- a/test/external/external_benchmark_hip_compile.py +++ b/test/external/external_benchmark_hip_compile.py @@ -9,23 +9,22 @@ from tinygrad.runtime.ops_gpu import compile_cl, CLDevice # issue is in https://github.com/ROCm-Developer-Tools/clr/ if __name__ == "__main__": - HIPDevice() - CLDevice() - - # warmup - name = "none"+str(random.randint(0, 1000000)) - compile_cl.__wrapped__(f"void {name}() {{}}") - print("compile cl warmed up") - compile_hip.__wrapped__(f"void {name}() {{}}") - print("compile hip warmed up") - - print("**** benchmark ****") - name = "none"+str(random.randint(0, 1000000)) - # this uses AMD_COMGR_ACTION_COMPILE_SOURCE_TO_BC, then it links the lib on the next step - with Timing("compile cl: "): compile_cl.__wrapped__(f"void {name}() {{}}") - # this uses AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, much slower - with Timing("compile hip: "): compile_hip.__wrapped__(f"void {name}() {{}}") - os._exit(0) - + HIPDevice() + CLDevice() + # warmup + name = "none" + str(random.randint(0, 1000000)) + compile_cl.__wrapped__(f"void {name}() {{}}") + print("compile cl warmed up") + compile_hip.__wrapped__(f"void {name}() {{}}") + print("compile hip warmed up") + print("**** benchmark ****") + name = "none" + str(random.randint(0, 1000000)) + # this uses AMD_COMGR_ACTION_COMPILE_SOURCE_TO_BC, then it links the lib on the next step + with Timing("compile cl: "): + compile_cl.__wrapped__(f"void {name}() {{}}") + # this uses AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, much slower + with Timing("compile hip: "): + compile_hip.__wrapped__(f"void {name}() {{}}") + os._exit(0) diff --git a/test/external/external_benchmark_load_stable_diffusion.py b/test/external/external_benchmark_load_stable_diffusion.py index 6ee0d0a32..84c625212 100644 --- a/test/external/external_benchmark_load_stable_diffusion.py +++ b/test/external/external_benchmark_load_stable_diffusion.py @@ -6,8 +6,11 @@ from examples.stable_diffusion import StableDiffusion # run "sudo purge" before testing on OS X to avoid the memory cache if __name__ == "__main__": - fn = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt') - model = StableDiffusion() - with Timing(): - load_state_dict(model, torch_load(fn)['state_dict'], strict=False) - Device[Device.DEFAULT].synchronize() + fn = fetch( + "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt", + "sd-v1-4.ckpt", + ) + model = StableDiffusion() + with Timing(): + load_state_dict(model, torch_load(fn)["state_dict"], strict=False) + Device[Device.DEFAULT].synchronize() diff --git a/test/external/external_cl_half_max.py b/test/external/external_cl_half_max.py index 7cd6b0c50..f0021ef2a 100644 --- a/test/external/external_cl_half_max.py +++ b/test/external/external_cl_half_max.py @@ -1,13 +1,14 @@ from tinygrad.runtime.ops_gpu import CLDevice, CLProgram, compile_cl if __name__ == "__main__": - dev = CLDevice() - lib = compile_cl(""" + dev = CLDevice() + lib = compile_cl( + """ #pragma OPENCL EXTENSION cl_khr_fp16 : enable __kernel void test(__global half *out, __global half *a, __global half *b) { int gid = get_global_id(0); out[gid] = max(a[gid], b[gid]); } -""") - prg = CLProgram(dev, "test", lib) - +""" + ) + prg = CLProgram(dev, "test", lib) diff --git a/test/external/external_llama_eval.py b/test/external/external_llama_eval.py index ce1cbd800..12f6c3d0e 100644 --- a/test/external/external_llama_eval.py +++ b/test/external/external_llama_eval.py @@ -6,97 +6,136 @@ from examples.llama import LLaMa from tinygrad.tensor import Tensor from tinygrad import Device + class LLaMaAdaptor(BaseLM): - def __init__( - self, - model_size="7B", - model_gen=1, - device="", - quantize=False, - batch_size=1, - max_batch_size=1, - do_sample=False, - temperature=1.0, - checkpoint_path="", - tokenizer_path="", - ): - super().__init__() + def __init__( + self, + model_size="7B", + model_gen=1, + device="", + quantize=False, + batch_size=1, + max_batch_size=1, + do_sample=False, + temperature=1.0, + checkpoint_path="", + tokenizer_path="", + ): + super().__init__() - if batch_size is None: - batch_size = 1 - self.do_sample = do_sample - self.temperature = temperature - self._device = device + if batch_size is None: + batch_size = 1 + self.do_sample = do_sample + self.temperature = temperature + self._device = device - assert isinstance(model_gen, int) - assert isinstance(model_size, str) - assert isinstance(batch_size, int) - assert isinstance(checkpoint_path, str) - assert isinstance(tokenizer_path, str) + assert isinstance(model_gen, int) + assert isinstance(model_size, str) + assert isinstance(batch_size, int) + assert isinstance(checkpoint_path, str) + assert isinstance(tokenizer_path, str) - self.llama = LLaMa.build(checkpoint_path, tokenizer_path, model_gen, model_size, quantize) + self.llama = LLaMa.build( + checkpoint_path, tokenizer_path, model_gen, model_size, quantize + ) - @classmethod - def create_from_arg_string(cls, arg_string, additional_config=None): - kwargs = {el.split("=")[0]: el.split("=")[1] for el in arg_string.split(",")} - return cls(**kwargs, **additional_config) + @classmethod + def create_from_arg_string(cls, arg_string, additional_config=None): + kwargs = {el.split("=")[0]: el.split("=")[1] for el in arg_string.split(",")} + return cls(**kwargs, **additional_config) - @property - def eot_token_id(self): - # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* - return self.llama.tokenizer.eos_id() + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.llama.tokenizer.eos_id() - @property - def max_length(self): - return 1024 + @property + def max_length(self): + return 1024 - @property - def max_gen_toks(self): - return 256 + @property + def max_gen_toks(self): + return 256 - @property - def batch_size(self): - return 1 + @property + def batch_size(self): + return 1 - @property - def device(self): - return self._device + @property + def device(self): + return self._device - def tok_encode(self, string: str): - return [self.llama.tokenizer.bos_id()] + self.llama.tokenizer.encode(string) + def tok_encode(self, string: str): + return [self.llama.tokenizer.bos_id()] + self.llama.tokenizer.encode(string) - def tok_decode(self, tokens): - return self.llama.tokenizer.decode(tokens) + def tok_decode(self, tokens): + return self.llama.tokenizer.decode(tokens) - def _model_call(self, inps): - Tensor.no_grad = True - return torch.Tensor(self.llama.model(Tensor(inps.numpy()), 0).numpy()) + def _model_call(self, inps): + Tensor.no_grad = True + return torch.Tensor(self.llama.model(Tensor(inps.numpy()), 0).numpy()) - def greedy_until(self, requests): - continuations = [] - for request in requests: - prompt, until = request[0], request[1]['until'] - output = self.llama.greedy_until(prompt, until, max_length=128, temperature=0.0) - continuations.append(output[len(prompt):]) - return continuations + def greedy_until(self, requests): + continuations = [] + for request in requests: + prompt, until = request[0], request[1]["until"] + output = self.llama.greedy_until( + prompt, until, max_length=128, temperature=0.0 + ) + continuations.append(output[len(prompt) :]) + return continuations - def _model_generate(self, context, max_length, eos_token_id): - raise NotImplementedError() + def _model_generate(self, context, max_length, eos_token_id): + raise NotImplementedError() -if __name__ == '__main__': - print(f"using {Device.DEFAULT} backend") - parser = argparse.ArgumentParser(description='Run LLaMA evals in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--size', type=str, default="7B", help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B] for Gen 2") - parser.add_argument('--gen', type=int, default="1", help="Generation of the model to use [1, 2]") - parser.add_argument('--quantize', action='store_true', help="Quantize the weights to int8 in memory") - parser.add_argument('--eval', type=str, default="arc_easy", help="Run in evaluation mode") - parser.add_argument('--limit', type=int, default=None, help="Limit tests in eval") - parser.add_argument('--weights', type=str, default="./weights/LLaMa/", help="Location of the weights") - parser.add_argument('--tokenizer', type=str, default="./weights/LLaMa/tokenizer.model", help="Location of the tokenizer") - args = parser.parse_args() +if __name__ == "__main__": + print(f"using {Device.DEFAULT} backend") - # run eval and exit - adaptor = LLaMaAdaptor(model_gen=args.gen, model_size=args.size, quantize=args.quantize, checkpoint_path=args.weights, tokenizer_path=args.tokenizer, device="cpu") - results = evaluator.evaluate(adaptor, tasks.get_task_dict(args.eval.split(",")), False, 0, args.limit) - print(json.dumps(results, indent=2)) + parser = argparse.ArgumentParser( + description="Run LLaMA evals in tinygrad", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--size", + type=str, + default="7B", + help="Size of model to use [7B, 13B, 30B, 65B] for Gen 1, [7B, 13B] for Gen 2", + ) + parser.add_argument( + "--gen", type=int, default="1", help="Generation of the model to use [1, 2]" + ) + parser.add_argument( + "--quantize", action="store_true", help="Quantize the weights to int8 in memory" + ) + parser.add_argument( + "--eval", type=str, default="arc_easy", help="Run in evaluation mode" + ) + parser.add_argument("--limit", type=int, default=None, help="Limit tests in eval") + parser.add_argument( + "--weights", + type=str, + default="./weights/LLaMa/", + help="Location of the weights", + ) + parser.add_argument( + "--tokenizer", + type=str, + default="./weights/LLaMa/tokenizer.model", + help="Location of the tokenizer", + ) + args = parser.parse_args() + + # run eval and exit + adaptor = LLaMaAdaptor( + model_gen=args.gen, + model_size=args.size, + quantize=args.quantize, + checkpoint_path=args.weights, + tokenizer_path=args.tokenizer, + device="cpu", + ) + results = evaluator.evaluate( + adaptor, tasks.get_task_dict(args.eval.split(",")), False, 0, args.limit + ) + print(json.dumps(results, indent=2)) diff --git a/test/external/external_model_benchmark.py b/test/external/external_model_benchmark.py index 01742126d..f7817ed12 100644 --- a/test/external/external_model_benchmark.py +++ b/test/external/external_model_benchmark.py @@ -1,6 +1,7 @@ import csv, pathlib, time, numpy as np from os import getenv import torch + torch.set_num_threads(1) import onnx from onnx.helper import tensor_dtype_to_np_dtype @@ -12,119 +13,177 @@ from tinygrad.tensor import Tensor from tinygrad import Device MODELS = { - "resnet50": "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-caffe2-v1-9.onnx", - "openpilot": "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx", - "efficientnet": "https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx", - "shufflenet": "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx", - "commavq": "https://huggingface.co/commaai/commavq-gpt2m/resolve/main/gpt2m.onnx", - - # broken in torch MPS - #"zfnet": "https://github.com/onnx/models/raw/main/vision/classification/zfnet-512/model/zfnet512-9.onnx", - # TypeError: BatchNormalization() got an unexpected keyword argument 'is_test' - #"densenet": "https://github.com/onnx/models/raw/main/vision/classification/densenet-121/model/densenet-3.onnx", - # AssertionError: only onnx version >= 10 supported for slice - #"bert": "https://github.com/onnx/models/raw/main/text/machine_comprehension/bert-squad/model/bertsquad-8.onnx", - # really slow - #"resnet18": "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx", + "resnet50": "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-caffe2-v1-9.onnx", + "openpilot": "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx", + "efficientnet": "https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx", + "shufflenet": "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx", + "commavq": "https://huggingface.co/commaai/commavq-gpt2m/resolve/main/gpt2m.onnx", + # broken in torch MPS + # "zfnet": "https://github.com/onnx/models/raw/main/vision/classification/zfnet-512/model/zfnet512-9.onnx", + # TypeError: BatchNormalization() got an unexpected keyword argument 'is_test' + # "densenet": "https://github.com/onnx/models/raw/main/vision/classification/densenet-121/model/densenet-3.onnx", + # AssertionError: only onnx version >= 10 supported for slice + # "bert": "https://github.com/onnx/models/raw/main/text/machine_comprehension/bert-squad/model/bertsquad-8.onnx", + # really slow + # "resnet18": "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx", } CSV = {} open_csv = None + def benchmark(mnm, nm, fxn): - tms = [] - for _ in range(3): - st = time.perf_counter_ns() - ret = fxn() - tms.append(time.perf_counter_ns() - st) - print(f"{mnm:15s} {nm:25s} {min(tms)*1e-6:7.2f} ms") - CSV[nm] = min(tms)*1e-6 - return min(tms), ret + tms = [] + for _ in range(3): + st = time.perf_counter_ns() + ret = fxn() + tms.append(time.perf_counter_ns() - st) + print(f"{mnm:15s} {nm:25s} {min(tms)*1e-6:7.2f} ms") + CSV[nm] = min(tms) * 1e-6 + return min(tms), ret -#BASE = pathlib.Path(__file__).parents[2] / "weights" / "onnx" + +# BASE = pathlib.Path(__file__).parents[2] / "weights" / "onnx" BASE = pathlib.Path("/tmp/onnx") + + def benchmark_model(m, devices, validate_outs=False): - torch.manual_seed(1) - global open_csv, CSV - CSV = {"model": m} + torch.manual_seed(1) + global open_csv, CSV + CSV = {"model": m} - fn = fetch(MODELS[m]) - onnx_model = onnx.load(fn) - output_names = [out.name for out in onnx_model.graph.output] - excluded = {inp.name for inp in onnx_model.graph.initializer} - input_shapes = {inp.name:tuple(x.dim_value if x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded} - input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input if inp.name not in excluded} - #input_types = {k:v if v!=np.float16 else np.float32 for k,v in input_types.items()} # cast - np_inputs = {k:torch.randn(shp).numpy().astype(input_types[k]) for k,shp in input_shapes.items()} - assert len(input_shapes) < 30, f"too many input shapes {len(input_shapes)}" + fn = fetch(MODELS[m]) + onnx_model = onnx.load(fn) + output_names = [out.name for out in onnx_model.graph.output] + excluded = {inp.name for inp in onnx_model.graph.initializer} + input_shapes = { + inp.name: tuple( + x.dim_value if x.dim_value != 0 else 1 + for x in inp.type.tensor_type.shape.dim + ) + for inp in onnx_model.graph.input + if inp.name not in excluded + } + input_types = { + inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) + for inp in onnx_model.graph.input + if inp.name not in excluded + } + # input_types = {k:v if v!=np.float16 else np.float32 for k,v in input_types.items()} # cast + np_inputs = { + k: torch.randn(shp).numpy().astype(input_types[k]) + for k, shp in input_shapes.items() + } + assert len(input_shapes) < 30, f"too many input shapes {len(input_shapes)}" - # print input names - if DEBUG >= 2: print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded]) + # print input names + if DEBUG >= 2: + print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded]) - for device in devices: - Device.DEFAULT = device - inputs = {k:Tensor(inp) for k,inp in np_inputs.items()} - tinygrad_model = get_run_onnx(onnx_model) - benchmark(m, f"tinygrad_{device.lower()}_jitless", lambda: {k:v.numpy() for k,v in tinygrad_model(inputs).items()}) + for device in devices: + Device.DEFAULT = device + inputs = {k: Tensor(inp) for k, inp in np_inputs.items()} + tinygrad_model = get_run_onnx(onnx_model) + benchmark( + m, + f"tinygrad_{device.lower()}_jitless", + lambda: {k: v.numpy() for k, v in tinygrad_model(inputs).items()}, + ) - from tinygrad.jit import TinyJit - tinygrad_jitted_model = TinyJit(lambda **kwargs: {k:v.realize() for k,v in tinygrad_model(kwargs).items()}) - for _ in range(3): {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()} - benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}) # noqa: F821 - del inputs, tinygrad_model, tinygrad_jitted_model + from tinygrad.jit import TinyJit - try: - torch_model = convert(onnx_model) - torch_inputs = [torch.tensor(x) for x in np_inputs.values()] - benchmark(m, "torch_cpu", lambda: torch_model(*torch_inputs)) + tinygrad_jitted_model = TinyJit( + lambda **kwargs: {k: v.realize() for k, v in tinygrad_model(kwargs).items()} + ) + for _ in range(3): + {k: v.numpy() for k, v in tinygrad_jitted_model(**inputs).items()} + benchmark( + m, + f"tinygrad_{device.lower()}_jit", + lambda: {k: v.numpy() for k, v in tinygrad_jitted_model(**inputs).items()}, + ) # noqa: F821 + del inputs, tinygrad_model, tinygrad_jitted_model - torch_device = "mps" if OSX else "cuda" - torch_mps_model = torch_model.to(torch_device) - torch_mps_inputs = [x.to(torch_device) for x in torch_inputs] - benchmark(m, f"torch_{torch_device}", lambda: torch_mps_model(*torch_mps_inputs)) - except Exception as e: print(f"{m:16s}onnx2torch {type(e).__name__:>25}") - - # bench onnxruntime - ort_options = ort.SessionOptions() - ort_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - ort_options.log_severity_level = 3 # no warnings - for backend in ["CPU", "CUDA" if not OSX else "CoreML"]: # https://onnxruntime.ai/docs/execution-providers/ - provider = backend+"ExecutionProvider" - if provider not in ort.get_available_providers(): continue - ort_sess = ort.InferenceSession(str(fn), ort_options, [provider]) try: - benchmark(m, f"onnxruntime_{backend.lower()}", lambda: ort_sess.run(output_names, np_inputs)) - except Exception as e: print(f"{m:16s}onnxruntime_{backend.lower()} {type(e).__name__:>25}") - del ort_sess + torch_model = convert(onnx_model) + torch_inputs = [torch.tensor(x) for x in np_inputs.values()] + benchmark(m, "torch_cpu", lambda: torch_model(*torch_inputs)) - if validate_outs: - rtol, atol = 2e-3, 2e-3 # tolerance for fp16 models - if m == "openpilot" and 'CUDA' in devices: rtol, atol = 0.1, 0.1 # TODO: why is this broken? - inputs = {k:Tensor(inp) for k,inp in np_inputs.items()} - tinygrad_model = get_run_onnx(onnx_model) - tinygrad_out = tinygrad_model(inputs) + torch_device = "mps" if OSX else "cuda" + torch_mps_model = torch_model.to(torch_device) + torch_mps_inputs = [x.to(torch_device) for x in torch_inputs] + benchmark( + m, f"torch_{torch_device}", lambda: torch_mps_model(*torch_mps_inputs) + ) + except Exception as e: + print(f"{m:16s}onnx2torch {type(e).__name__:>25}") - ort_sess = ort.InferenceSession(str(fn), ort_options, ["CPUExecutionProvider"]) - onnx_out = ort_sess.run(output_names, np_inputs) - onnx_out = dict([*[(name,x) for name, x in zip(output_names, onnx_out)]]) + # bench onnxruntime + ort_options = ort.SessionOptions() + ort_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + ort_options.log_severity_level = 3 # no warnings + for backend in [ + "CPU", + "CUDA" if not OSX else "CoreML", + ]: # https://onnxruntime.ai/docs/execution-providers/ + provider = backend + "ExecutionProvider" + if provider not in ort.get_available_providers(): + continue + ort_sess = ort.InferenceSession(str(fn), ort_options, [provider]) + try: + benchmark( + m, + f"onnxruntime_{backend.lower()}", + lambda: ort_sess.run(output_names, np_inputs), + ) + except Exception as e: + print(f"{m:16s}onnxruntime_{backend.lower()} {type(e).__name__:>25}") + del ort_sess - assert_allclose(tinygrad_out, onnx_out, rtol=rtol, atol=atol) - print(f"{m:16s}outputs validated with rtol={rtol:.1e}, atol={atol:.1e}") + if validate_outs: + rtol, atol = 2e-3, 2e-3 # tolerance for fp16 models + if m == "openpilot" and "CUDA" in devices: + rtol, atol = 0.1, 0.1 # TODO: why is this broken? + inputs = {k: Tensor(inp) for k, inp in np_inputs.items()} + tinygrad_model = get_run_onnx(onnx_model) + tinygrad_out = tinygrad_model(inputs) - if open_csv is None: - open_csv = csv.DictWriter(open('onnx_inference_speed.csv', 'w', newline=''), fieldnames=list(CSV.keys())) - open_csv.writeheader() - open_csv.writerow(CSV) + ort_sess = ort.InferenceSession(str(fn), ort_options, ["CPUExecutionProvider"]) + onnx_out = ort_sess.run(output_names, np_inputs) + onnx_out = dict([*[(name, x) for name, x in zip(output_names, onnx_out)]]) + + assert_allclose(tinygrad_out, onnx_out, rtol=rtol, atol=atol) + print(f"{m:16s}outputs validated with rtol={rtol:.1e}, atol={atol:.1e}") + + if open_csv is None: + open_csv = csv.DictWriter( + open("onnx_inference_speed.csv", "w", newline=""), + fieldnames=list(CSV.keys()), + ) + open_csv.writeheader() + open_csv.writerow(CSV) + + +def assert_allclose(tiny_out: dict, onnx_out: dict, rtol=1e-5, atol=1e-5): + assert len(tiny_out) == len(onnx_out) and tiny_out.keys() == onnx_out.keys() + for k in tiny_out.keys(): + tiny_v, onnx_v = tiny_out[k], onnx_out[k] + if tiny_v is None: + assert tiny_v == onnx_v + else: + np.testing.assert_allclose( + tiny_v.numpy(), + onnx_v, + rtol=rtol, + atol=atol, + err_msg=f"For tensor '{k}' in {tiny_out.keys()}", + ) -def assert_allclose(tiny_out:dict, onnx_out:dict, rtol=1e-5, atol=1e-5): - assert len(tiny_out) == len(onnx_out) and tiny_out.keys() == onnx_out.keys() - for k in tiny_out.keys(): - tiny_v, onnx_v = tiny_out[k], onnx_out[k] - if tiny_v is None: assert tiny_v == onnx_v - else: np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tiny_out.keys()}") if __name__ == "__main__": - devices = [Device.DEFAULT] if getenv("NOCLANG") else [Device.DEFAULT, "CLANG"] - if getenv("MODEL", "") != "": benchmark_model(getenv("MODEL", ""), devices, True) - else: - for m in MODELS: benchmark_model(m, devices, True) + devices = [Device.DEFAULT] if getenv("NOCLANG") else [Device.DEFAULT, "CLANG"] + if getenv("MODEL", "") != "": + benchmark_model(getenv("MODEL", ""), devices, True) + else: + for m in MODELS: + benchmark_model(m, devices, True) diff --git a/test/external/external_multi_gpu.py b/test/external/external_multi_gpu.py index 542f1e621..bf8bce73a 100644 --- a/test/external/external_multi_gpu.py +++ b/test/external/external_multi_gpu.py @@ -6,74 +6,81 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import colored, Timing, getenv from tinygrad.device import Device -d0, d1 = f'{Device.DEFAULT}:0', f'{Device.DEFAULT}:1' +d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1" + def sync(): - Device[d0].synchronize() - Device[d1].synchronize() + Device[d0].synchronize() + Device[d1].synchronize() + if __name__ == "__main__": - print("GPU devices", d0, d1) - sz = getenv("N", 1024*1024*256) # 1 GB + print("GPU devices", d0, d1) + sz = getenv("N", 1024 * 1024 * 256) # 1 GB - with Timing("GPU initial sync: "): sync() + with Timing("GPU initial sync: "): + sync() - with Timing("CPU creation: ", on_exit=lambda x: f", {(sz*4*2)/x:.2f} GB/sec"): - c0 = (Tensor.ones(sz, device="clang")/2).realize() - c1 = (Tensor.ones(sz, device="clang")/4).realize() - print(c0.lazydata.realized) - print(c1.lazydata.realized) + with Timing("CPU creation: ", on_exit=lambda x: f", {(sz*4*2)/x:.2f} GB/sec"): + c0 = (Tensor.ones(sz, device="clang") / 2).realize() + c1 = (Tensor.ones(sz, device="clang") / 4).realize() + print(c0.lazydata.realized) + print(c1.lazydata.realized) - with Timing("CPU -> 0: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): - a0 = c0.to(d0).realize() - sync() - with Timing("CPU -> 1: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): - b1 = c1.to(d1).realize() - sync() + with Timing("CPU -> 0: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): + a0 = c0.to(d0).realize() + sync() + with Timing("CPU -> 1: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): + b1 = c1.to(d1).realize() + sync() - # cross copy. this is (sometimes) going through the CPU - with Timing("0 -> 1: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): - a1 = a0.to(d1).realize() - sync() - with Timing("1 -> 0: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): - b0 = b1.to(d0).realize() - sync() + # cross copy. this is (sometimes) going through the CPU + with Timing("0 -> 1: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): + a1 = a0.to(d1).realize() + sync() + with Timing("1 -> 0: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): + b0 = b1.to(d0).realize() + sync() - # sum - with Timing("0+0 -> 0 (sum): ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): - ab0 = (a0 + b0).realize() - sync() - with Timing("1+1 -> 1 (sum): ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): - ab1 = (a1 + b1).realize() - sync() + # sum + with Timing("0+0 -> 0 (sum): ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): + ab0 = (a0 + b0).realize() + sync() + with Timing("1+1 -> 1 (sum): ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): + ab1 = (a1 + b1).realize() + sync() - # cross device sum (does this work?) - with Timing(colored("0+1 -> 0 (sum): ", "red"), on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): - abx0 = (a0 + b1.to(d0)).realize() - sync() + # cross device sum (does this work?) + with Timing( + colored("0+1 -> 0 (sum): ", "red"), on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec" + ): + abx0 = (a0 + b1.to(d0)).realize() + sync() - with Timing(colored("1+0 -> 1 (sum): ", "red"), on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): - abx1 = (b1 + a0.to(d1)).realize() - sync() + with Timing( + colored("1+0 -> 1 (sum): ", "red"), on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec" + ): + abx1 = (b1 + a0.to(d1)).realize() + sync() - # copy back - # NOTE: half of this slowness is caused by allocating memory on the CPU - with Timing("0 -> CPU: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): - cc0 = ab0.numpy() - with Timing("1 -> CPU: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): - cc1 = ab1.numpy() + # copy back + # NOTE: half of this slowness is caused by allocating memory on the CPU + with Timing("0 -> CPU: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): + cc0 = ab0.numpy() + with Timing("1 -> CPU: ", on_exit=lambda x: f", {(sz*4)/x:.2f} GB/sec"): + cc1 = ab1.numpy() - # same - print("testing") - np.testing.assert_allclose(cc0, cc1) + # same + print("testing") + np.testing.assert_allclose(cc0, cc1) - # same (cross) - print("testing (cross)") - np.testing.assert_allclose(cc0, abx0.numpy()) - np.testing.assert_allclose(cc0, abx1.numpy()) + # same (cross) + print("testing (cross)") + np.testing.assert_allclose(cc0, abx0.numpy()) + np.testing.assert_allclose(cc0, abx1.numpy()) - # devices - print(ab0) - print(ab1) - print(abx0) - print(abx1) + # devices + print(ab0) + print(ab1) + print(abx0) + print(abx1) diff --git a/test/external/external_osx_profiling.py b/test/external/external_osx_profiling.py index 6f9b215a5..268f7d836 100644 --- a/test/external/external_osx_profiling.py +++ b/test/external/external_osx_profiling.py @@ -7,18 +7,48 @@ a = CLBuffer(N, dtypes.float32) b = CLBuffer(N, dtypes.float32) c = CLBuffer(N, dtypes.float32) -prg = CLProgram("test", """__kernel void test(__global float *a, __global float *b, __global float *c) { +prg = CLProgram( + "test", + """__kernel void test(__global float *a, __global float *b, __global float *c) { int idx = get_global_id(0); a[idx] = b[idx] + c[idx]; -}""") -prg.clprgs[0](CL.cl_queue[0], [N,], None, a._buf, b._buf, c._buf) +}""", +) +prg.clprgs[0]( + CL.cl_queue[0], + [ + N, + ], + None, + a._buf, + b._buf, + c._buf, +) t1 = time.monotonic_ns() -e1 = prg.clprgs[0](CL.cl_queue[0], [N,], None, a._buf, b._buf, c._buf) +e1 = prg.clprgs[0]( + CL.cl_queue[0], + [ + N, + ], + None, + a._buf, + b._buf, + c._buf, +) CL.synchronize() t2 = time.monotonic_ns() time.sleep(3) t3 = time.monotonic_ns() -e2 = prg.clprgs[0](CL.cl_queue[0], [N,], None, a._buf, b._buf, c._buf) +e2 = prg.clprgs[0]( + CL.cl_queue[0], + [ + N, + ], + None, + a._buf, + b._buf, + c._buf, +) CL.synchronize() t4 = time.monotonic_ns() @@ -28,12 +58,12 @@ print(e1.profile.start) print(e1.profile.end) print(e1, e2) -print(t2-t1, e1.profile.end - e1.profile.start) -print(t4-t3, e2.profile.end - e2.profile.start) -print(t3-t2, e2.profile.queued-e1.profile.end) -print((t3-t2) / (e2.profile.start-e1.profile.end), "ratio") +print(t2 - t1, e1.profile.end - e1.profile.start) +print(t4 - t3, e2.profile.end - e2.profile.start) +print(t3 - t2, e2.profile.queued - e1.profile.end) +print((t3 - t2) / (e2.profile.start - e1.profile.end), "ratio") -print("ratio since boot", t1/e1.profile.start) +print("ratio since boot", t1 / e1.profile.start) print(e1.profile.start) print(e1.profile.end) diff --git a/test/external/external_test_embedding.py b/test/external/external_test_embedding.py index 9d6bd7f2b..acdbb501f 100644 --- a/test/external/external_test_embedding.py +++ b/test/external/external_test_embedding.py @@ -2,7 +2,7 @@ from tinygrad.tensor import Tensor from tinygrad.nn import Embedding if __name__ == "__main__": - vocab_size = 50257 - dim = 128 - test = Embedding(vocab_size, dim) - ret = test(Tensor([[1,2,3]])).numpy() + vocab_size = 50257 + dim = 128 + test = Embedding(vocab_size, dim) + ret = test(Tensor([[1, 2, 3]])).numpy() diff --git a/test/external/external_test_image.py b/test/external/external_test_image.py index 3e246eef7..c5b59f729 100644 --- a/test/external/external_test_image.py +++ b/test/external/external_test_image.py @@ -2,51 +2,55 @@ import os import unittest import numpy as np -if 'IMAGE' not in os.environ: - os.environ['IMAGE'] = '2' -os.environ['GPU'] = '1' -os.environ['OPT'] = '2' + +if "IMAGE" not in os.environ: + os.environ["IMAGE"] = "2" +os.environ["GPU"] = "1" +os.environ["OPT"] = "2" from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d + Tensor.no_grad = True + class TestImage(unittest.TestCase): - def test_create_image(self): - t = Tensor.ones(128, 128, 1) - t = t.reshape(128, 32, 4) + 3 - t.realize() - np.testing.assert_array_equal(t.numpy(), np.ones((128,32,4))*4) + def test_create_image(self): + t = Tensor.ones(128, 128, 1) + t = t.reshape(128, 32, 4) + 3 + t.realize() + np.testing.assert_array_equal(t.numpy(), np.ones((128, 32, 4)) * 4) - def test_sum_image(self): - t1 = Tensor.ones(16, 16, 1).reshape(16, 4, 4) + 3 - t1.realize() - t1 = t1.sum() - t1.realize() - assert t1.numpy() == 16*4*4*4, f"got {t1.numpy()}" + def test_sum_image(self): + t1 = Tensor.ones(16, 16, 1).reshape(16, 4, 4) + 3 + t1.realize() + t1 = t1.sum() + t1.realize() + assert t1.numpy() == 16 * 4 * 4 * 4, f"got {t1.numpy()}" - def test_add_image(self): - t1 = Tensor.ones(16, 16, 1).reshape(16, 4, 4) + 3 - t2 = Tensor.ones(16, 16, 1).reshape(16, 4, 4) + 4 - t1.realize() - t2.realize() - t3 = t1 + t2 - t3.realize() - np.testing.assert_array_equal(t3.numpy(), np.ones((16,4,4))*9) + def test_add_image(self): + t1 = Tensor.ones(16, 16, 1).reshape(16, 4, 4) + 3 + t2 = Tensor.ones(16, 16, 1).reshape(16, 4, 4) + 4 + t1.realize() + t2.realize() + t3 = t1 + t2 + t3.realize() + np.testing.assert_array_equal(t3.numpy(), np.ones((16, 4, 4)) * 9) - def test_padded_conv(self): - bs, in_chans, out_chans = 1,12,32 - tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1) - tiny_dat = Tensor.ones(bs, 12, 64, 128) - tiny_conv(tiny_dat).realize() + def test_padded_conv(self): + bs, in_chans, out_chans = 1, 12, 32 + tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1) + tiny_dat = Tensor.ones(bs, 12, 64, 128) + tiny_conv(tiny_dat).realize() - def test_op_conv(self): - bs, in_chans, out_chans = 1,12,32 - tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1) - tiny_dconv = Conv2d(out_chans, out_chans, 1, bias=None, padding=0) - tiny_dat = Tensor.ones(bs, 12, 64, 128) - p2 = tiny_conv(tiny_dat).relu() - p2 = tiny_dconv(p2) - p2.realize() + def test_op_conv(self): + bs, in_chans, out_chans = 1, 12, 32 + tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1) + tiny_dconv = Conv2d(out_chans, out_chans, 1, bias=None, padding=0) + tiny_dat = Tensor.ones(bs, 12, 64, 128) + p2 = tiny_conv(tiny_dat).relu() + p2 = tiny_dconv(p2) + p2.realize() -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/external/external_test_jit_on_models.py b/test/external/external_test_jit_on_models.py index 0eaff63c1..383bfea5c 100644 --- a/test/external/external_test_jit_on_models.py +++ b/test/external/external_test_jit_on_models.py @@ -8,36 +8,72 @@ from test.helpers import derandomize_model from examples.llama import Transformer + def helper_test_jitted_correctness(gen, train, train_jit): - nojit = train(*gen()).numpy() - for _ in range(5): jit = train_jit(*gen()).numpy() - np.testing.assert_allclose(nojit, jit, rtol=1e-3, atol=1e-5) + nojit = train(*gen()).numpy() + for _ in range(5): + jit = train_jit(*gen()).numpy() + np.testing.assert_allclose(nojit, jit, rtol=1e-3, atol=1e-5) + class TestJittedModels(unittest.TestCase): - def test_jitted_tiny_llama(self): - old_type = Tensor.default_type - Tensor.default_type = dtypes.float16 + def test_jitted_tiny_llama(self): + old_type = Tensor.default_type + Tensor.default_type = dtypes.float16 - args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000} - model = Transformer(**args_tiny) - derandomize_model(model) - def test(t): return model(t, 0).realize() + args_tiny = { + "dim": 1024, + "hidden_dim": 1024, + "n_heads": 8, + "n_layers": 8, + "norm_eps": 1e-05, + "vocab_size": 1000, + } + model = Transformer(**args_tiny) + derandomize_model(model) - @TinyJit - def test_jit(t): return model(t, 0).realize() - helper_test_jitted_correctness(lambda: (Tensor([[1,]]),), test, test_jit) - Tensor.default_type = old_type + def test(t): + return model(t, 0).realize() - @unittest.skipUnless(not CI, "huge for CI") - def test_jitted_stable_diffusion(self): - from examples.stable_diffusion import UNetModel - model = UNetModel() - derandomize_model(model) - def test(t, t2): return model(t, 801, t2).realize() + @TinyJit + def test_jit(t): + return model(t, 0).realize() + + helper_test_jitted_correctness( + lambda: ( + Tensor( + [ + [ + 1, + ] + ] + ), + ), + test, + test_jit, + ) + Tensor.default_type = old_type + + @unittest.skipUnless(not CI, "huge for CI") + def test_jitted_stable_diffusion(self): + from examples.stable_diffusion import UNetModel + + model = UNetModel() + derandomize_model(model) + + def test(t, t2): + return model(t, 801, t2).realize() + + @TinyJit + def test_jit(t, t2): + return model(t, 801, t2).realize() + + helper_test_jitted_correctness( + lambda: (Tensor.randn(1, 4, 16, 16), Tensor.randn(1, 77, 768)), + test, + test_jit, + ) - @TinyJit - def test_jit(t, t2): return model(t, 801, t2).realize() - helper_test_jitted_correctness(lambda: (Tensor.randn(1, 4, 16, 16),Tensor.randn(1, 77, 768)), test, test_jit) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index c3059203b..1c0558fd4 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -8,205 +8,243 @@ from tinygrad.helpers import getenv, CI from tinygrad import Device # pip3 install tabulate -pytest_plugins = 'onnx.backend.test.report', +pytest_plugins = ("onnx.backend.test.report",) from extra.onnx import get_run_onnx -class TinygradModel(BackendRep): - def __init__(self, run_onnx, input_names): - super().__init__() - self.fxn = run_onnx - self.input_names = input_names - def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]: - real_inputs = {k:v for k,v in zip(self.input_names, inputs)} - ret = self.fxn(real_inputs, debug=True) - return tuple(x.numpy() if isinstance(x, Tensor) else [i.numpy() for i in x] if isinstance(x, list) else np.array(x) for x in ret.values()) +class TinygradModel(BackendRep): + def __init__(self, run_onnx, input_names): + super().__init__() + self.fxn = run_onnx + self.input_names = input_names + + def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]: + real_inputs = {k: v for k, v in zip(self.input_names, inputs)} + ret = self.fxn(real_inputs, debug=True) + return tuple( + x.numpy() + if isinstance(x, Tensor) + else [i.numpy() for i in x] + if isinstance(x, list) + else np.array(x) + for x in ret.values() + ) + class TinygradBackend(Backend): - @classmethod - def prepare(cls, model, device): - input_all = [x.name for x in model.graph.input] - input_initializer = [x.name for x in model.graph.initializer] - net_feed_input = [x for x in input_all if x not in input_initializer] - print("prepare", cls, device, net_feed_input) - run_onnx = get_run_onnx(model) - return TinygradModel(run_onnx, net_feed_input) + @classmethod + def prepare(cls, model, device): + input_all = [x.name for x in model.graph.input] + input_initializer = [x.name for x in model.graph.initializer] + net_feed_input = [x for x in input_all if x not in input_initializer] + print("prepare", cls, device, net_feed_input) + run_onnx = get_run_onnx(model) + return TinygradModel(run_onnx, net_feed_input) + + @classmethod + def supports_device(cls, device: str) -> bool: + return device == "CPU" - @classmethod - def supports_device(cls, device: str) -> bool: - return device == "CPU" backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__) # no support for reduce with multiply (needs llop) -backend_test.exclude('test_reduce_prod_*') +backend_test.exclude("test_reduce_prod_*") # TODO figure out why it's returning wrong values, geohotstan's uneducated guess is it's due to imprecision from float64 (double) -> float32 # see Type Constraints: https://onnx.ai/onnx/operators/onnx_aionnxpreviewtraining_Adam.html#type-constraints -backend_test.exclude('test_adam_multiple_cpu') -backend_test.exclude('test_nesterov_momentum_cpu') +backend_test.exclude("test_adam_multiple_cpu") +backend_test.exclude("test_nesterov_momentum_cpu") # we only support float32 -backend_test.exclude('uint8') -backend_test.exclude('uint16') -backend_test.exclude('uint32') -backend_test.exclude('uint64') -backend_test.exclude('int8') -backend_test.exclude('int16') -backend_test.exclude('float64') -backend_test.exclude('string') +backend_test.exclude("uint8") +backend_test.exclude("uint16") +backend_test.exclude("uint32") +backend_test.exclude("uint64") +backend_test.exclude("int8") +backend_test.exclude("int16") +backend_test.exclude("float64") +backend_test.exclude("string") -backend_test.exclude('test_pow_types_int*') -backend_test.exclude('test_cast_*') -backend_test.exclude('test_castlike_*') -backend_test.exclude('test_convinteger_*') -backend_test.exclude('test_matmulinteger_*') +backend_test.exclude("test_pow_types_int*") +backend_test.exclude("test_cast_*") +backend_test.exclude("test_castlike_*") +backend_test.exclude("test_convinteger_*") +backend_test.exclude("test_matmulinteger_*") -backend_test.exclude('test_reduce_log_sum_exp*') # dependent on actual float64 implementation for backends -backend_test.exclude('test_operator_add*') # dependent on float64 math. Without it values default to 0 or inf +backend_test.exclude( + "test_reduce_log_sum_exp*" +) # dependent on actual float64 implementation for backends +backend_test.exclude( + "test_operator_add*" +) # dependent on float64 math. Without it values default to 0 or inf # we don't support indexes # backend_test.exclude('test_argmax_*') # Needs more work: select_last_index # backend_test.exclude('test_argmin_*') # Needs more work: select_last_index -backend_test.exclude('test_nonzero_*') +backend_test.exclude("test_nonzero_*") # no support for mod -backend_test.exclude('test_mod_*') +backend_test.exclude("test_mod_*") # no boolean ops (2d, 3d, 4d) -backend_test.exclude('test_bitshift_*') +backend_test.exclude("test_bitshift_*") # no scatternd gathernd -backend_test.exclude('test_gathernd_*') -backend_test.exclude('test_scatternd_*') +backend_test.exclude("test_gathernd_*") +backend_test.exclude("test_scatternd_*") # no quantize -backend_test.exclude('test_dynamicquantizelinear_*') -backend_test.exclude('test_qlinearmatmul_*') -backend_test.exclude('test_qlinearconv_*') -backend_test.exclude('test_quantizelinear_*') +backend_test.exclude("test_dynamicquantizelinear_*") +backend_test.exclude("test_qlinearmatmul_*") +backend_test.exclude("test_qlinearconv_*") +backend_test.exclude("test_quantizelinear_*") # no rnn -backend_test.exclude('test_gru_*') -backend_test.exclude('test_rnn_*') -backend_test.exclude('test_lstm_*') -backend_test.exclude('test_simple_rnn_*') +backend_test.exclude("test_gru_*") +backend_test.exclude("test_rnn_*") +backend_test.exclude("test_lstm_*") +backend_test.exclude("test_simple_rnn_*") # no control flow # control flow uses AttributeProto.GRAPH -backend_test.exclude('test_if_*') -backend_test.exclude('test_loop*') -backend_test.exclude('test_range_float_type_positive_delta_expanded_cpu') # requires loop -backend_test.exclude('test_affine_grid_2d_align_corners_expanded_cpu') -backend_test.exclude('test_affine_grid_2d_expanded_cpu') -backend_test.exclude('test_affine_grid_3d_align_corners_expanded_cpu') -backend_test.exclude('test_affine_grid_3d_expanded_cpu') -backend_test.exclude('test_range_int32_type_negative_delta_expanded_cpu') +backend_test.exclude("test_if_*") +backend_test.exclude("test_loop*") +backend_test.exclude( + "test_range_float_type_positive_delta_expanded_cpu" +) # requires loop +backend_test.exclude("test_affine_grid_2d_align_corners_expanded_cpu") +backend_test.exclude("test_affine_grid_2d_expanded_cpu") +backend_test.exclude("test_affine_grid_3d_align_corners_expanded_cpu") +backend_test.exclude("test_affine_grid_3d_expanded_cpu") +backend_test.exclude("test_range_int32_type_negative_delta_expanded_cpu") # unsupported (strange) ops -backend_test.exclude('test_bitwise_*') -backend_test.exclude('test_blackmanwindow_*') -backend_test.exclude('test_bernoulli_*') -backend_test.exclude('test_cumsum_*') -backend_test.exclude('test_det_*') +backend_test.exclude("test_bitwise_*") +backend_test.exclude("test_blackmanwindow_*") +backend_test.exclude("test_bernoulli_*") +backend_test.exclude("test_cumsum_*") +backend_test.exclude("test_det_*") -backend_test.exclude('test_tril_zero_cpu') # TODO: zero array tril support -backend_test.exclude('test_triu_zero_cpu') # TODO: zero array triu support +backend_test.exclude("test_tril_zero_cpu") # TODO: zero array tril support +backend_test.exclude("test_triu_zero_cpu") # TODO: zero array triu support -backend_test.exclude('test_col2im_*') -backend_test.exclude('test_hammingwindow_*') -backend_test.exclude('test_hannwindow_*') -backend_test.exclude('test_hardmax_*') -backend_test.exclude('test_gridsample_*') -backend_test.exclude('test_dft_*') -backend_test.exclude('test_einsum_*') -backend_test.exclude('test_strnorm_*') -backend_test.exclude('test_unique_*') -backend_test.exclude('test_sequence_*') -backend_test.exclude('test_nonmaxsuppression_*') -backend_test.exclude('test_reversesequence_*') -backend_test.exclude('test_roialign_*') -backend_test.exclude('test_top_k_*') -backend_test.exclude('test_tfidfvectorizer_*') -backend_test.exclude('test_stft_*') -backend_test.exclude('test_melweightmatrix_*') +backend_test.exclude("test_col2im_*") +backend_test.exclude("test_hammingwindow_*") +backend_test.exclude("test_hannwindow_*") +backend_test.exclude("test_hardmax_*") +backend_test.exclude("test_gridsample_*") +backend_test.exclude("test_dft_*") +backend_test.exclude("test_einsum_*") +backend_test.exclude("test_strnorm_*") +backend_test.exclude("test_unique_*") +backend_test.exclude("test_sequence_*") +backend_test.exclude("test_nonmaxsuppression_*") +backend_test.exclude("test_reversesequence_*") +backend_test.exclude("test_roialign_*") +backend_test.exclude("test_top_k_*") +backend_test.exclude("test_tfidfvectorizer_*") +backend_test.exclude("test_stft_*") +backend_test.exclude("test_melweightmatrix_*") # more strange ops -backend_test.exclude('test_basic_deform_conv_*') -backend_test.exclude('test_deform_conv_*') -backend_test.exclude('test_lppool_*') -backend_test.exclude('test_depthtospace_*') -backend_test.exclude('test_spacetodepth_*') -backend_test.exclude('test_scan*') -backend_test.exclude('test_split_to_sequence_*') -backend_test.exclude('test_resize_downsample_scales_cubic_*') # unsure how to implement cubic -backend_test.exclude('test_resize_downsample_sizes_cubic_*') # unsure how to implement cubic -backend_test.exclude('test_resize_upsample_scales_cubic_*') # unsure how to implement cubic -backend_test.exclude('test_resize_upsample_sizes_cubic_*') # unsure how to implement cubic +backend_test.exclude("test_basic_deform_conv_*") +backend_test.exclude("test_deform_conv_*") +backend_test.exclude("test_lppool_*") +backend_test.exclude("test_depthtospace_*") +backend_test.exclude("test_spacetodepth_*") +backend_test.exclude("test_scan*") +backend_test.exclude("test_split_to_sequence_*") +backend_test.exclude( + "test_resize_downsample_scales_cubic_*" +) # unsure how to implement cubic +backend_test.exclude( + "test_resize_downsample_sizes_cubic_*" +) # unsure how to implement cubic +backend_test.exclude( + "test_resize_upsample_scales_cubic_*" +) # unsure how to implement cubic +backend_test.exclude( + "test_resize_upsample_sizes_cubic_*" +) # unsure how to implement cubic # rest of the failing tests -backend_test.exclude('test_regex_*') # does not support string Tensors -backend_test.exclude('test_reshape_allowzero_reordered_cpu') # reshaping to shape with 0, also allowzero -backend_test.exclude('test_resize_downsample_scales_linear_antialias_cpu') # antialias not implemented -backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # antialias not implemented -backend_test.exclude('test_resize_tf_crop_and_resize_cpu') # unsure about fill value after clip -backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cpu') # bad data type string -backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string +backend_test.exclude("test_regex_*") # does not support string Tensors +backend_test.exclude( + "test_reshape_allowzero_reordered_cpu" +) # reshaping to shape with 0, also allowzero +backend_test.exclude( + "test_resize_downsample_scales_linear_antialias_cpu" +) # antialias not implemented +backend_test.exclude( + "test_resize_downsample_sizes_linear_antialias_cpu" +) # antialias not implemented +backend_test.exclude( + "test_resize_tf_crop_and_resize_cpu" +) # unsure about fill value after clip +backend_test.exclude( + "test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cpu" +) # bad data type string +backend_test.exclude( + "test_ai_onnx_ml_label_encoder_tensor_mapping_cpu" +) # bad data type string # issue 1556 https://github.com/tinygrad/tinygrad/issues/1556 -backend_test.exclude('test_isinf_cpu') -backend_test.exclude('test_isinf_negative_cpu') -backend_test.exclude('test_isinf_positive_cpu') -backend_test.exclude('test_isinf_float16_cpu') -backend_test.exclude('test_isnan_float16_cpu') -backend_test.exclude('test_isnan_cpu') +backend_test.exclude("test_isinf_cpu") +backend_test.exclude("test_isinf_negative_cpu") +backend_test.exclude("test_isinf_positive_cpu") +backend_test.exclude("test_isinf_float16_cpu") +backend_test.exclude("test_isnan_float16_cpu") +backend_test.exclude("test_isnan_cpu") # issue 1791 fast math messes with these https://github.com/tinygrad/tinygrad/issues/1791 -backend_test.exclude('test_resize_upsample_sizes_nearest_axes_2_3_cpu') -backend_test.exclude('test_resize_upsample_sizes_nearest_axes_3_2_cpu') -backend_test.exclude('test_resize_upsample_sizes_nearest_cpu') +backend_test.exclude("test_resize_upsample_sizes_nearest_axes_2_3_cpu") +backend_test.exclude("test_resize_upsample_sizes_nearest_axes_3_2_cpu") +backend_test.exclude("test_resize_upsample_sizes_nearest_cpu") # issue 2067 potentially also a fastmath issue https://github.com/tinygrad/tinygrad/issues/2067 -if Device.DEFAULT in ['METAL']: - backend_test.exclude('test_maxpool_2d_pads_cpu') - backend_test.exclude('test_maxpool_2d_same_lower_cpu') +if Device.DEFAULT in ["METAL"]: + backend_test.exclude("test_maxpool_2d_pads_cpu") + backend_test.exclude("test_maxpool_2d_same_lower_cpu") -if Device.DEFAULT in ['GPU', 'METAL']: - backend_test.exclude('test_mish_cpu') # weird inaccuracy - backend_test.exclude('test_mish_expanded_cpu') # weird inaccuracy - backend_test.exclude('test_eyelike_with_dtype_cpu') # backend does not support dtype: Double +if Device.DEFAULT in ["GPU", "METAL"]: + backend_test.exclude("test_mish_cpu") # weird inaccuracy + backend_test.exclude("test_mish_expanded_cpu") # weird inaccuracy + backend_test.exclude( + "test_eyelike_with_dtype_cpu" + ) # backend does not support dtype: Double # Segfaults in CI, GPU requires cl_khr_fp16 -if Device.DEFAULT in ['LLVM', 'CUDA', 'GPU'] and CI: - backend_test.exclude('test_max_float16_cpu') - backend_test.exclude('test_min_float16_cpu') +if Device.DEFAULT in ["LLVM", "CUDA", "GPU"] and CI: + backend_test.exclude("test_max_float16_cpu") + backend_test.exclude("test_min_float16_cpu") # error: casting to type 'half' is not allowed -backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu') +backend_test.exclude("test_dequantizelinear_e4m3fn_float16_cpu") # TODO: this somehow passes in CI but does not pass if run locally -if Device.DEFAULT in ['GPU', 'METAL', 'LLVM', 'CLANG']: - backend_test.exclude('test_MaxPool3d_stride_padding_cpu') +if Device.DEFAULT in ["GPU", "METAL", "LLVM", "CLANG"]: + backend_test.exclude("test_MaxPool3d_stride_padding_cpu") # disable model tests for now since they are slow if not getenv("MODELTESTS"): - for x in backend_test.test_suite: - if 'OnnxBackendRealModelTest' in str(type(x)): - backend_test.exclude(str(x).split(" ")[0]) + for x in backend_test.test_suite: + if "OnnxBackendRealModelTest" in str(type(x)): + backend_test.exclude(str(x).split(" ")[0]) else: - # model tests all pass! - backend_test.include('test_resnet50') - backend_test.include('test_inception_v1') - backend_test.include('test_inception_v2') - backend_test.include('test_densenet121') - backend_test.include('test_shufflenet') - backend_test.include('test_squeezenet') - backend_test.include('test_bvlc_alexnet') - backend_test.include('test_zfnet512') - backend_test.include('test_vgg19') + # model tests all pass! + backend_test.include("test_resnet50") + backend_test.include("test_inception_v1") + backend_test.include("test_inception_v2") + backend_test.include("test_densenet121") + backend_test.include("test_shufflenet") + backend_test.include("test_squeezenet") + backend_test.include("test_bvlc_alexnet") + backend_test.include("test_zfnet512") + backend_test.include("test_vgg19") globals().update(backend_test.enable_report().test_cases) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index a24ea38a6..def1f872f 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -2,10 +2,11 @@ import os import torch + if "OPT" not in os.environ: - os.environ["OPT"] = "2" + os.environ["OPT"] = "2" else: - assert int(os.environ["OPT"]) >= 2, "test is broken with OPT=0 or OPT=1" + assert int(os.environ["OPT"]) >= 2, "test is broken with OPT=0 or OPT=1" import gc import numpy as np @@ -19,22 +20,36 @@ from tinygrad.helpers import GlobalCounters from tinygrad.lazy import PUSH_PERMUTES from tinygrad.jit import CacheCollector + class CLCache: - def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None): - self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {} - def __enter__(self): - if self.preclear: - gc.collect() - for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]: - x.realize() - GlobalCounters.reset() - CacheCollector.start(self.var_vals) - print("cache: entering") - def __exit__(self, type, value, traceback): - cache = CacheCollector.finish() - print(f"cache: exiting with size {len(cache)}", f"allowed {self.allowed}" if self.allowed is not None else "") - if self.allowed is not None: - assert len(cache) <= self.allowed and (not self.strict or len(cache) == self.allowed), f"used too many kernels! {len(cache)} > {self.allowed}" + def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None): + self.allowed, self.strict, self.preclear, self.var_vals = ( + allowed, + strict, + preclear, + var_vals if var_vals is not None else {}, + ) + + def __enter__(self): + if self.preclear: + gc.collect() + for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]: + x.realize() + GlobalCounters.reset() + CacheCollector.start(self.var_vals) + print("cache: entering") + + def __exit__(self, type, value, traceback): + cache = CacheCollector.finish() + print( + f"cache: exiting with size {len(cache)}", + f"allowed {self.allowed}" if self.allowed is not None else "", + ) + if self.allowed is not None: + assert len(cache) <= self.allowed and ( + not self.strict or len(cache) == self.allowed + ), f"used too many kernels! {len(cache)} > {self.allowed}" + from extra.models.convnext import ConvNeXt from extra.models.efficientnet import EfficientNet @@ -42,338 +57,420 @@ from extra.models.resnet import ResNet18 from extra.models.vit import ViT from tinygrad.nn.state import get_parameters + @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") class TestInferenceMinKernels(unittest.TestCase): - def setUp(self): - self.training_old = Tensor.training - Tensor.training = False - def tearDown(self): - Tensor.training = self.training_old + def setUp(self): + self.training_old = Tensor.training + Tensor.training = False - @unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES") - def test_convnext(self): - model = ConvNeXt() - for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) - img = Tensor.randn(1, 3, 224, 224) - with CLCache(129): - model(img).realize() + def tearDown(self): + Tensor.training = self.training_old - def test_enet(self): - model = EfficientNet(getenv("ENET_NUM", 0), has_se=False) - for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) - img = Tensor.randn(1, 3, 224, 224) - with CLCache(51): - model.forward(img).realize() + @unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES") + def test_convnext(self): + model = ConvNeXt() + for p in get_parameters(model): + p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + img = Tensor.randn(1, 3, 224, 224) + with CLCache(129): + model(img).realize() - def test_enet_se(self): - model = EfficientNet(getenv("ENET_NUM", 0), has_se=True) - for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) - img = Tensor.randn(1, 3, 224, 224) - # TODO: this seems very high - with CLCache(115): - model.forward(img).realize() + def test_enet(self): + model = EfficientNet(getenv("ENET_NUM", 0), has_se=False) + for p in get_parameters(model): + p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + img = Tensor.randn(1, 3, 224, 224) + with CLCache(51): + model.forward(img).realize() - def test_resnet(self): - model = ResNet18() - for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) - img = Tensor.randn(1, 3, 224, 224) - with CLCache(26): - model.forward(img).realize() + def test_enet_se(self): + model = EfficientNet(getenv("ENET_NUM", 0), has_se=True) + for p in get_parameters(model): + p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + img = Tensor.randn(1, 3, 224, 224) + # TODO: this seems very high + with CLCache(115): + model.forward(img).realize() - def test_vit(self): - model = ViT(embed_dim=192, num_heads=3) - for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) - img = Tensor.randn(1, 3, 224, 224) - with CLCache(222): # NOTE: this is way too high - out = model.forward(img) - assert len(CacheCollector.cache) == 0, "ViT prerealized?" - out.realize() + def test_resnet(self): + model = ResNet18() + for p in get_parameters(model): + p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + img = Tensor.randn(1, 3, 224, 224) + with CLCache(26): + model.forward(img).realize() + + def test_vit(self): + model = ViT(embed_dim=192, num_heads=3) + for p in get_parameters(model): + p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + img = Tensor.randn(1, 3, 224, 224) + with CLCache(222): # NOTE: this is way too high + out = model.forward(img) + assert len(CacheCollector.cache) == 0, "ViT prerealized?" + out.realize() + + def test_llama(self): + from examples.llama import Transformer + + args_tiny = { + "dim": 512, + "hidden_dim": 1024, + "n_heads": 8, + "n_layers": 4, + "norm_eps": 1e-05, + "vocab_size": 1000, + } + model = Transformer(**args_tiny) + for p in get_parameters(model): + p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + inp = Tensor([[1, 2, 3, 4]]) + with CLCache(100): + model(inp, 0).realize() - def test_llama(self): - from examples.llama import Transformer - args_tiny = {"dim": 512, "hidden_dim": 1024, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000} - model = Transformer(**args_tiny) - for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) - inp = Tensor([[1,2,3,4]]) - with CLCache(100): - model(inp, 0).realize() @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") class TestOptBinOp(unittest.TestCase): - def _test_no_binop_rerun(self, f1, f2=None, allowed=1): - a = Tensor.randn(16, 16) - b = Tensor.randn(16, 16) - with CLCache(): - c = f1(a, b) - if f2 is not None: d = f2(a, b) - c.realize() - if f2 is not None: d.realize() - assert len(CacheCollector.cache) == allowed, "binop was rerun!" - if f2 is not None: np.testing.assert_allclose(c.numpy().ravel(), d.numpy().ravel(), rtol=1e-3, atol=1e-5) + def _test_no_binop_rerun(self, f1, f2=None, allowed=1): + a = Tensor.randn(16, 16) + b = Tensor.randn(16, 16) + with CLCache(): + c = f1(a, b) + if f2 is not None: + d = f2(a, b) + c.realize() + if f2 is not None: + d.realize() + assert len(CacheCollector.cache) == allowed, "binop was rerun!" + if f2 is not None: + np.testing.assert_allclose( + c.numpy().ravel(), d.numpy().ravel(), rtol=1e-3, atol=1e-5 + ) - def test_no_binop_rerun(self): return self._test_no_binop_rerun(lambda a,b: a*b, lambda a,b: (a*b).reshape(16, 16, 1)) - def test_no_binop_rerun_alt(self): return self._test_no_binop_rerun(lambda a,b: (a*b).reshape(16, 16, 1), lambda a,b: a*b) - def test_no_binop_rerun_reduce_broadcast(self): return self._test_no_binop_rerun(lambda a,b: a.sum()+b, lambda a,b: a.sum().reshape(1,1)+b, allowed=2) - @unittest.skip("this test started failing with the new change, based movementop issue") - def test_no_binop_rerun_transposed(self): return self._test_no_binop_rerun(lambda a,b: (a.T*b.T).T, lambda a,b: a*b) - def test_no_binop_rerun_mid_reshape(self): return self._test_no_binop_rerun(lambda a,b: (a*b).reshape(256)+a.reshape(256)) + def test_no_binop_rerun(self): + return self._test_no_binop_rerun( + lambda a, b: a * b, lambda a, b: (a * b).reshape(16, 16, 1) + ) + + def test_no_binop_rerun_alt(self): + return self._test_no_binop_rerun( + lambda a, b: (a * b).reshape(16, 16, 1), lambda a, b: a * b + ) + + def test_no_binop_rerun_reduce_broadcast(self): + return self._test_no_binop_rerun( + lambda a, b: a.sum() + b, lambda a, b: a.sum().reshape(1, 1) + b, allowed=2 + ) + + @unittest.skip( + "this test started failing with the new change, based movementop issue" + ) + def test_no_binop_rerun_transposed(self): + return self._test_no_binop_rerun(lambda a, b: (a.T * b.T).T, lambda a, b: a * b) + + def test_no_binop_rerun_mid_reshape(self): + return self._test_no_binop_rerun( + lambda a, b: (a * b).reshape(256) + a.reshape(256) + ) + + # currently non working tests + # def test_no_binop_rerun_preshape(self): return self._test_no_binop_rerun(lambda a,b: a.reshape(16, 16, 1)*b.reshape(16, 16, 1), lambda a,b: a*b) + # def test_no_binop_rerun_reduce(self): return self._test_no_binop_rerun(lambda a,b: (a*b).sum(), lambda a,b: (a*b).reshape(16, 16, 1).sum()) + # def test_no_binop_rerun_reduce_alt(self): return self._test_no_binop_rerun(lambda a,b: a.sum(1)+b[0], lambda a,b: a.sum(1).reshape(1,16)+b[0]) - # currently non working tests - #def test_no_binop_rerun_preshape(self): return self._test_no_binop_rerun(lambda a,b: a.reshape(16, 16, 1)*b.reshape(16, 16, 1), lambda a,b: a*b) - #def test_no_binop_rerun_reduce(self): return self._test_no_binop_rerun(lambda a,b: (a*b).sum(), lambda a,b: (a*b).reshape(16, 16, 1).sum()) - #def test_no_binop_rerun_reduce_alt(self): return self._test_no_binop_rerun(lambda a,b: a.sum(1)+b[0], lambda a,b: a.sum(1).reshape(1,16)+b[0]) @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") class TestOptReduceLoop(unittest.TestCase): - @unittest.skip("this is broken") - def test_loop_left(self): - a = Tensor.randn(16, 16) - b = Tensor.randn(16, 16) - with CLCache(): - t = a.sum(0) - b = t.reshape(16,1).expand(16,16).sum(0) - c = (t+b) - c.realize() - assert len(CacheCollector.cache) == 2, "loop left fusion broken" + @unittest.skip("this is broken") + def test_loop_left(self): + a = Tensor.randn(16, 16) + b = Tensor.randn(16, 16) + with CLCache(): + t = a.sum(0) + b = t.reshape(16, 1).expand(16, 16).sum(0) + c = t + b + c.realize() + assert len(CacheCollector.cache) == 2, "loop left fusion broken" + + def test_loop_right(self): + a = Tensor.randn(16, 16) + b = Tensor.randn(16, 16) + with CLCache(): + t = a.sum(0) + b = t.reshape(16, 1).expand(16, 16).sum(0) + c = b + t + c.realize() + assert len(CacheCollector.cache) == 2, "loop right fusion broken" - def test_loop_right(self): - a = Tensor.randn(16, 16) - b = Tensor.randn(16, 16) - with CLCache(): - t = a.sum(0) - b = t.reshape(16,1).expand(16,16).sum(0) - c = (b+t) - c.realize() - assert len(CacheCollector.cache) == 2, "loop right fusion broken" @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") class TestOptWChild(unittest.TestCase): - def test_unrealized_child(self): - a = Tensor.randn(16, 16) - b = Tensor.randn(16, 16) - with CLCache(): - c = (a*b).sum() - d = c+1 - e = c+2 # noqa: F841 - d.realize() - assert len(CacheCollector.cache) == 2, "don't fuse if you have children" + def test_unrealized_child(self): + a = Tensor.randn(16, 16) + b = Tensor.randn(16, 16) + with CLCache(): + c = (a * b).sum() + d = c + 1 + e = c + 2 # noqa: F841 + d.realize() + assert len(CacheCollector.cache) == 2, "don't fuse if you have children" + @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") class TestOpt(unittest.TestCase): - def test_muladd(self): - a,b,c = [Tensor.randn(2,2).realize() for _ in range(3)] - na,nb,nc = a.numpy(),b.numpy(),c.numpy() - with CLCache(allowed=1): - d = a * b + c - d.realize() - np.testing.assert_allclose(d.numpy(), na*nb+nc, rtol=1e-5) + def test_muladd(self): + a, b, c = [Tensor.randn(2, 2).realize() for _ in range(3)] + na, nb, nc = a.numpy(), b.numpy(), c.numpy() + with CLCache(allowed=1): + d = a * b + c + d.realize() + np.testing.assert_allclose(d.numpy(), na * nb + nc, rtol=1e-5) - def test_fold_reduce_elementwise(self): - img = Tensor.ones(32) - addme = Tensor.ones(1) - with CLCache(): - ret = img.sum() + addme - ret.realize() - assert len(CacheCollector.cache) == 1, "optimizer didn't fold reduce/elementwise" - assert ret.item() == 33 + def test_fold_reduce_elementwise(self): + img = Tensor.ones(32) + addme = Tensor.ones(1) + with CLCache(): + ret = img.sum() + addme + ret.realize() + assert ( + len(CacheCollector.cache) == 1 + ), "optimizer didn't fold reduce/elementwise" + assert ret.item() == 33 - def test_fold_batchnorm(self): - with Tensor.train(): - img = Tensor.ones(1,32,4,4) - bn = nn.BatchNorm2d(32, track_running_stats=False) - with CLCache(): - img_bn = bn(img).realize() - print(img_bn) - assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}" + def test_fold_batchnorm(self): + with Tensor.train(): + img = Tensor.ones(1, 32, 4, 4) + bn = nn.BatchNorm2d(32, track_running_stats=False) + with CLCache(): + img_bn = bn(img).realize() + print(img_bn) + assert ( + len(CacheCollector.cache) == 3 + ), f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}" - def test_fold_conv_sgd(self): - with Tensor.train(): - img = Tensor.ones(2,3,4,4) - c1 = nn.Conv2d(3,32,3) - opt = optim.SGD(get_parameters(c1)) - with CLCache(): - opt.zero_grad() - c1(img).relu().sum().backward() - opt.step() - # TODO: this should be 4, but the sum output child stays around - # with pushing_permutes it can be 3 - # TODO: broken with optim fixes - assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}" + def test_fold_conv_sgd(self): + with Tensor.train(): + img = Tensor.ones(2, 3, 4, 4) + c1 = nn.Conv2d(3, 32, 3) + opt = optim.SGD(get_parameters(c1)) + with CLCache(): + opt.zero_grad() + c1(img).relu().sum().backward() + opt.step() + # TODO: this should be 4, but the sum output child stays around + # with pushing_permutes it can be 3 + # TODO: broken with optim fixes + assert len(CacheCollector.cache) in [ + 4, + 5, + 6, + ], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}" - def test_fold_2convs_sgd(self): - with Tensor.train(): - img = Tensor.ones(2,3,64,64) - c1 = nn.Conv2d(3,16,3,bias=False) - c2 = nn.Conv2d(16,32,3,bias=False) - opt = optim.SGD(get_parameters([c1, c2])) - with CLCache(allowed=9): - opt.zero_grad() - c2(c1(img).relu()).relu().sum().backward() - opt.step() + def test_fold_2convs_sgd(self): + with Tensor.train(): + img = Tensor.ones(2, 3, 64, 64) + c1 = nn.Conv2d(3, 16, 3, bias=False) + c2 = nn.Conv2d(16, 32, 3, bias=False) + opt = optim.SGD(get_parameters([c1, c2])) + with CLCache(allowed=9): + opt.zero_grad() + c2(c1(img).relu()).relu().sum().backward() + opt.step() - def test_fold_4convs_sgd(self): - with Tensor.train(): - img = Tensor.ones(2,3,64,64) - c1 = nn.Conv2d(3,4,3,bias=False) - c2 = nn.Conv2d(4,8,3,bias=False) - c3 = nn.Conv2d(8,16,3,bias=False) - c4 = nn.Conv2d(16,32,3,bias=False) - opt = optim.SGD(get_parameters([c1, c2, c3, c4])) - with CLCache(allowed=19): - opt.zero_grad() - c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward() - opt.step() + def test_fold_4convs_sgd(self): + with Tensor.train(): + img = Tensor.ones(2, 3, 64, 64) + c1 = nn.Conv2d(3, 4, 3, bias=False) + c2 = nn.Conv2d(4, 8, 3, bias=False) + c3 = nn.Conv2d(8, 16, 3, bias=False) + c4 = nn.Conv2d(16, 32, 3, bias=False) + opt = optim.SGD(get_parameters([c1, c2, c3, c4])) + with CLCache(allowed=19): + opt.zero_grad() + c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward() + opt.step() - def test_fold_conv_batchnorm_sgd(self): - with Tensor.train(): - img = Tensor.ones(1,3,4,4) - c1 = nn.Conv2d(3,32,3) - bn = nn.BatchNorm2d(32, track_running_stats=False) - opt = optim.SGD(get_parameters([c1, bn])) - with CLCache(allowed=17): # this is too high - img_bn = bn(c1(img)).elu().sum() - opt.zero_grad() - img_bn.backward() - opt.step() + def test_fold_conv_batchnorm_sgd(self): + with Tensor.train(): + img = Tensor.ones(1, 3, 4, 4) + c1 = nn.Conv2d(3, 32, 3) + bn = nn.BatchNorm2d(32, track_running_stats=False) + opt = optim.SGD(get_parameters([c1, bn])) + with CLCache(allowed=17): # this is too high + img_bn = bn(c1(img)).elu().sum() + opt.zero_grad() + img_bn.backward() + opt.step() - def test_fold_conv_batchnorm_notrain(self): - img = Tensor.ones(1,3,8,8) - c1 = nn.Conv2d(3,32,3) - bn = nn.BatchNorm2d(32, track_running_stats=False) - # precache the bn - bn(c1(img)).relu().realize() - with CLCache(): - bn(c1(img)).relu().realize() - assert len(CacheCollector.cache) == 1, f"optimizer didn't fold conv-batchnorm at test time, got {len(CacheCollector.cache)}" + def test_fold_conv_batchnorm_notrain(self): + img = Tensor.ones(1, 3, 8, 8) + c1 = nn.Conv2d(3, 32, 3) + bn = nn.BatchNorm2d(32, track_running_stats=False) + # precache the bn + bn(c1(img)).relu().realize() + with CLCache(): + bn(c1(img)).relu().realize() + assert ( + len(CacheCollector.cache) == 1 + ), f"optimizer didn't fold conv-batchnorm at test time, got {len(CacheCollector.cache)}" - def test_fold_conv_batchnorm(self): - with Tensor.train(): - img = Tensor.ones(1,3,8,8) - c1 = nn.Conv2d(3,32,3) - bn = nn.BatchNorm2d(32, track_running_stats=False) - with CLCache(): - img_conv = bn(c1(img)).relu().realize() - print(img_conv) - assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}" + def test_fold_conv_batchnorm(self): + with Tensor.train(): + img = Tensor.ones(1, 3, 8, 8) + c1 = nn.Conv2d(3, 32, 3) + bn = nn.BatchNorm2d(32, track_running_stats=False) + with CLCache(): + img_conv = bn(c1(img)).relu().realize() + print(img_conv) + assert ( + len(CacheCollector.cache) == 4 + ), f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}" - def test_fold_conv_elu(self): - img = Tensor.ones(1,4,8,8) - c1 = nn.Conv2d(4, 4, kernel_size=3) - c2 = nn.Conv2d(4, 4, kernel_size=3) - with CLCache(): - img_conv = img.sequential([c1, Tensor.elu, c2, Tensor.elu]).realize() - print(img_conv) - assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/elu" + def test_fold_conv_elu(self): + img = Tensor.ones(1, 4, 8, 8) + c1 = nn.Conv2d(4, 4, kernel_size=3) + c2 = nn.Conv2d(4, 4, kernel_size=3) + with CLCache(): + img_conv = img.sequential([c1, Tensor.elu, c2, Tensor.elu]).realize() + print(img_conv) + assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/elu" - def test_fold_conv_relu(self): - img = Tensor.ones(1,4,8,8) - c1 = nn.Conv2d(4, 4, kernel_size=3) - c2 = nn.Conv2d(4, 4, kernel_size=3) - with CLCache(): - img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize() - print(img_conv) - assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/relu" + def test_fold_conv_relu(self): + img = Tensor.ones(1, 4, 8, 8) + c1 = nn.Conv2d(4, 4, kernel_size=3) + c2 = nn.Conv2d(4, 4, kernel_size=3) + with CLCache(): + img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize() + print(img_conv) + assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/relu" - def test_fold_conv_relu_nobias(self): - img = Tensor.ones(1,4,8,8) - c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False) - c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False) - with CLCache(): - img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize() - print(img_conv) - assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/relu" + def test_fold_conv_relu_nobias(self): + img = Tensor.ones(1, 4, 8, 8) + c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False) + c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False) + with CLCache(): + img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize() + print(img_conv) + assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/relu" - def test_permute_was_pushed(self): - a = Tensor.randn(16, 16, 16) - with CLCache(2): - c = a.sum(2) - d = c.permute(1,0).contiguous() - d.realize() - cache_len = len(CacheCollector.cache) - np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5) - if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" + def test_permute_was_pushed(self): + a = Tensor.randn(16, 16, 16) + with CLCache(2): + c = a.sum(2) + d = c.permute(1, 0).contiguous() + d.realize() + cache_len = len(CacheCollector.cache) + np.testing.assert_allclose( + a.numpy().sum(2).transpose(1, 0), d.numpy(), rtol=1e-3, atol=1e-5 + ) + if PUSH_PERMUTES: + assert cache_len == 1, "permute wasn't pushed!" - def test_permute_was_pushed_through_contract_reshape(self): - a = Tensor.randn(4, 4, 4, 4, 4) - with CLCache(2): - c = a.sum(-1) - d = c.reshape(16,16).permute(1,0).contiguous() - d.realize() - cache_len = len(CacheCollector.cache) - np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,16).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5) - if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" + def test_permute_was_pushed_through_contract_reshape(self): + a = Tensor.randn(4, 4, 4, 4, 4) + with CLCache(2): + c = a.sum(-1) + d = c.reshape(16, 16).permute(1, 0).contiguous() + d.realize() + cache_len = len(CacheCollector.cache) + np.testing.assert_allclose( + a.numpy().sum(-1).reshape(16, 16).transpose(1, 0), + d.numpy(), + rtol=1e-3, + atol=1e-5, + ) + if PUSH_PERMUTES: + assert cache_len == 1, "permute wasn't pushed!" - def test_permute_was_pushed_through_contractw1s_reshape(self): - a = Tensor.randn(4, 4, 4, 4, 4) - with CLCache(2): - c = a.sum(-1) - d = c.reshape(16,1,16).permute(2,1,0).contiguous() - d.realize() - cache_len = len(CacheCollector.cache) - np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,1,16).transpose(2,1,0), d.numpy(), rtol=1e-3, atol=1e-5) - if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" + def test_permute_was_pushed_through_contractw1s_reshape(self): + a = Tensor.randn(4, 4, 4, 4, 4) + with CLCache(2): + c = a.sum(-1) + d = c.reshape(16, 1, 16).permute(2, 1, 0).contiguous() + d.realize() + cache_len = len(CacheCollector.cache) + np.testing.assert_allclose( + a.numpy().sum(-1).reshape(16, 1, 16).transpose(2, 1, 0), + d.numpy(), + rtol=1e-3, + atol=1e-5, + ) + if PUSH_PERMUTES: + assert cache_len == 1, "permute wasn't pushed!" - # TODO: push permute through expansion reshape - @unittest.skip("expansion can't push expand permute yet") - @unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES") - def test_permute_was_pushed_through_expand_reshape(self): - a = Tensor.randn(16, 16, 16) - with CLCache(): - c = a.sum(2) - d = c.reshape(4,4,4,4).permute(2,3,0,1).contiguous() - d.realize() - cache_len = len(CacheCollector.cache) - np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0).reshape(4,4,4,4), d.numpy(), rtol=1e-3, atol=1e-5) - if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" + # TODO: push permute through expansion reshape + @unittest.skip("expansion can't push expand permute yet") + @unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES") + def test_permute_was_pushed_through_expand_reshape(self): + a = Tensor.randn(16, 16, 16) + with CLCache(): + c = a.sum(2) + d = c.reshape(4, 4, 4, 4).permute(2, 3, 0, 1).contiguous() + d.realize() + cache_len = len(CacheCollector.cache) + np.testing.assert_allclose( + a.numpy().sum(2).transpose(1, 0).reshape(4, 4, 4, 4), + d.numpy(), + rtol=1e-3, + atol=1e-5, + ) + if PUSH_PERMUTES: + assert cache_len == 1, "permute wasn't pushed!" - @unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES") - def test_no_reduceop_rerun(self): - a = Tensor.randn(16, 16, 16) - with CLCache(): - c = a.sum(2) - d = a.sum(2).permute(1,0) - c.realize() - d.realize() - cache_len = len(CacheCollector.cache) - np.testing.assert_allclose(c.numpy().transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5) - assert cache_len == 1, "reduceop was rerun!" + @unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES") + def test_no_reduceop_rerun(self): + a = Tensor.randn(16, 16, 16) + with CLCache(): + c = a.sum(2) + d = a.sum(2).permute(1, 0) + c.realize() + d.realize() + cache_len = len(CacheCollector.cache) + np.testing.assert_allclose( + c.numpy().transpose(1, 0), d.numpy(), rtol=1e-3, atol=1e-5 + ) + assert cache_len == 1, "reduceop was rerun!" - @unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES") - def test_no_reduceop_rerun_alt(self): - a = Tensor.randn(16, 16, 16) - with CLCache(): - c = a.sum(2).permute(1,0) - d = a.sum(2) - c.realize() - d.realize() - cache_len = len(CacheCollector.cache) - np.testing.assert_allclose(c.numpy(), d.numpy().transpose(1,0), rtol=1e-3, atol=1e-5) - assert cache_len == 1, "reduceop was rerun!" + @unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES") + def test_no_reduceop_rerun_alt(self): + a = Tensor.randn(16, 16, 16) + with CLCache(): + c = a.sum(2).permute(1, 0) + d = a.sum(2) + c.realize() + d.realize() + cache_len = len(CacheCollector.cache) + np.testing.assert_allclose( + c.numpy(), d.numpy().transpose(1, 0), rtol=1e-3, atol=1e-5 + ) + assert cache_len == 1, "reduceop was rerun!" - def test_fold_with_contiguous(self): - a = Tensor.randn(16, 16, 16) - b = Tensor.randn(16, 16) - with CLCache(1): - c = (a.sum(2).contiguous() + b).contiguous() - c.realize() + def test_fold_with_contiguous(self): + a = Tensor.randn(16, 16, 16) + b = Tensor.randn(16, 16) + with CLCache(1): + c = (a.sum(2).contiguous() + b).contiguous() + c.realize() - def test_expand_reduce_is_folded_on_same_axis(self): - for axis in [0, 1]: - for n in [4, 8, 16]: - b = torch.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis) - with CLCache(allowed=2): - a = Tensor.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis) - a.realize() - np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5) + def test_expand_reduce_is_folded_on_same_axis(self): + for axis in [0, 1]: + for n in [4, 8, 16]: + b = torch.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis) + with CLCache(allowed=2): + a = Tensor.ones(n, n).sum(axis).reshape(n, 1).expand(n, n).sum(axis) + a.realize() + np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5) - def test_expand_reduce_is_not_folded_on_different_axes(self): - axis1, axis2 = 0, 1 - for n in [4, 8, 16]: - b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2) - with CLCache(allowed=2): - a = Tensor.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2) - a.realize() - np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5) + def test_expand_reduce_is_not_folded_on_different_axes(self): + axis1, axis2 = 0, 1 + for n in [4, 8, 16]: + b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2) + with CLCache(allowed=2): + a = Tensor.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2) + a.realize() + np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/external/external_test_optim.py b/test/external/external_test_optim.py index 2851f1198..d13c2e4d7 100644 --- a/test/external/external_test_optim.py +++ b/test/external/external_test_optim.py @@ -7,69 +7,84 @@ from tinygrad.tensor import Tensor from tinygrad.nn.optim import LAMB np.random.seed(1337) -x_init = np.random.randn(1,4).astype(np.float32) -W_init = np.random.randn(4,4).astype(np.float32) -m_init = np.random.randn(1,4).astype(np.float32) +x_init = np.random.randn(1, 4).astype(np.float32) +W_init = np.random.randn(4, 4).astype(np.float32) +m_init = np.random.randn(1, 4).astype(np.float32) + class TinyNet: - def __init__(self): - self.x = Tensor(x_init.copy(), requires_grad=True) - self.W = Tensor(W_init.copy(), requires_grad=True) - self.m = Tensor(m_init.copy()) + def __init__(self): + self.x = Tensor(x_init.copy(), requires_grad=True) + self.W = Tensor(W_init.copy(), requires_grad=True) + self.m = Tensor(m_init.copy()) + + def forward(self): + out = self.x.matmul(self.W).relu() + out = out.log_softmax(1) + out = out.mul(self.m).add(self.m).sum() + return out - def forward(self): - out = self.x.matmul(self.W).relu() - out = out.log_softmax(1) - out = out.mul(self.m).add(self.m).sum() - return out class TinyNetTF: - def __init__(self): - self.x = tf.Variable(x_init.copy(), trainable=True) - self.W = tf.Variable(W_init.copy(), trainable=True) - self.m = tf.constant(m_init.copy()) + def __init__(self): + self.x = tf.Variable(x_init.copy(), trainable=True) + self.W = tf.Variable(W_init.copy(), trainable=True) + self.m = tf.constant(m_init.copy()) + + def forward(self): + out = tf.matmul(self.x, self.W) + out = tf.nn.relu(out) + out = tf.nn.log_softmax(out, axis=1) + out = tf.multiply(out, self.m) + self.m + out = tf.reduce_sum(out) + return out - def forward(self): - out = tf.matmul(self.x, self.W) - out = tf.nn.relu(out) - out = tf.nn.log_softmax(out, axis=1) - out = tf.multiply(out, self.m) + self.m - out = tf.reduce_sum(out) - return out def step(optim, steps=1, kwargs={}): - net = TinyNet() - optim = optim([net.x, net.W], **kwargs) - for _ in range(steps): - out = net.forward() - optim.zero_grad() - out.backward() - optim.step() - return net.x.detach().numpy(), net.W.detach().numpy() + net = TinyNet() + optim = optim([net.x, net.W], **kwargs) + for _ in range(steps): + out = net.forward() + optim.zero_grad() + out.backward() + optim.step() + return net.x.detach().numpy(), net.W.detach().numpy() + def step_tf(optim, steps=1, kwargs={}): - net = TinyNetTF() - optim = optim(**kwargs) - for _ in range(steps): - with tf.GradientTape() as tape: - out = net.forward() - grads = tape.gradient(out, [net.x, net.W]) - optim.apply_gradients(zip(grads, [net.x, net.W])) - return net.x.numpy(), net.W.numpy() + net = TinyNetTF() + optim = optim(**kwargs) + for _ in range(steps): + with tf.GradientTape() as tape: + out = net.forward() + grads = tape.gradient(out, [net.x, net.W]) + optim.apply_gradients(zip(grads, [net.x, net.W])) + return net.x.numpy(), net.W.numpy() + class ExternalTestOptim(unittest.TestCase): - def _test_optim(self, tinygrad_optim, tensorflow_optim, steps, opts, atol, rtol): - for x,y in zip(step(tinygrad_optim, steps, kwargs=opts), - step_tf(tensorflow_optim, steps, kwargs=opts)): - np.testing.assert_allclose(x, y, atol=atol, rtol=rtol) + def _test_optim(self, tinygrad_optim, tensorflow_optim, steps, opts, atol, rtol): + for x, y in zip( + step(tinygrad_optim, steps, kwargs=opts), + step_tf(tensorflow_optim, steps, kwargs=opts), + ): + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol) - def _test_lamb(self, steps, opts, atol, rtol): self._test_optim(LAMB, tfa.optimizers.LAMB, steps, opts, atol, rtol) + def _test_lamb(self, steps, opts, atol, rtol): + self._test_optim(LAMB, tfa.optimizers.LAMB, steps, opts, atol, rtol) - def test_lamb(self): self._test_lamb(1, {'lr': 0.001}, 1e-5, 0) - def test_lamb_high_lr(self): self._test_lamb(1, {'lr': 10}, 1e-5, 1e-5) + def test_lamb(self): + self._test_lamb(1, {"lr": 0.001}, 1e-5, 0) - def test_multistep_lamb(self): self._test_lamb(10, {'lr': 0.001}, 1e-5, 0) - def test_multistep_lamb_high_lr(self): self._test_lamb(10, {'lr': 10}, 1e-5, 3e-4) + def test_lamb_high_lr(self): + self._test_lamb(1, {"lr": 10}, 1e-5, 1e-5) -if __name__ == '__main__': - unittest.main() + def test_multistep_lamb(self): + self._test_lamb(10, {"lr": 0.001}, 1e-5, 0) + + def test_multistep_lamb_high_lr(self): + self._test_lamb(10, {"lr": 10}, 1e-5, 3e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/external/external_test_speed_llama.py b/test/external/external_test_speed_llama.py index 1ae0023ac..9bfbd0279 100644 --- a/test/external/external_test_speed_llama.py +++ b/test/external/external_test_speed_llama.py @@ -7,46 +7,64 @@ from tinygrad.nn.state import get_state_dict from tinygrad.device import Compiled, Allocator from tinygrad.helpers import Profiling + class FakeProgram: - def __init__(self, name:str, prg:bytes): pass - def __call__(self, *bufs, global_size, local_size, vals=(), wait=False): pass + def __init__(self, name: str, prg: bytes): + pass + + def __call__(self, *bufs, global_size, local_size, vals=(), wait=False): + pass + class FakeAllocator(Allocator): - def _alloc(self, sz): return None - def copyin(self, dest, src:memoryview): pass + def _alloc(self, sz): + return None + + def copyin(self, dest, src: memoryview): + pass + class TestLLaMASpeed(unittest.TestCase): - @unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends") - def test_llama_compile(self): - backup_program = Device[Device.DEFAULT].runtime - backup_allocator = Device[Device.DEFAULT].allocator - Device[Device.DEFAULT].runtime = FakeProgram - Device[Device.DEFAULT].allocator = FakeAllocator() + @unittest.skipIf( + not isinstance(Device[Device.DEFAULT], Compiled), + "only test for compiled backends", + ) + def test_llama_compile(self): + backup_program = Device[Device.DEFAULT].runtime + backup_allocator = Device[Device.DEFAULT].allocator + Device[Device.DEFAULT].runtime = FakeProgram + Device[Device.DEFAULT].allocator = FakeAllocator() - print("testing llama python run time") - model = Transformer(**MODEL_PARAMS["1"]["7B"]["args"]) - print("built model") - # assign fake tensors to the values - for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype)) - print("assigned empty tensors, doing warmup") + print("testing llama python run time") + model = Transformer(**MODEL_PARAMS["1"]["7B"]["args"]) + print("built model") + # assign fake tensors to the values + for v in get_state_dict(model).values(): + v.assign(Tensor.empty(*v.shape, dtype=v.dtype)) + print("assigned empty tensors, doing warmup") - def run_llama(st, empty_method_cache=True): - if empty_method_cache: Device[Device.DEFAULT].get_runner.cache_clear() - tms = [time.perf_counter()] - for i in range(10): - model(Tensor([[1,2,3,4]]), i).realize() - tms.append(time.perf_counter()) - timings = [(tms[i+1]-tms[i])*1000 for i in range(len(tms)-1)] - print(f"{st:15s} mean runtime: {sum(timings)/len(timings):7.2f}ms, runs: ", ", ".join(f'{x:7.2f}' for x in timings)) + def run_llama(st, empty_method_cache=True): + if empty_method_cache: + Device[Device.DEFAULT].get_runner.cache_clear() + tms = [time.perf_counter()] + for i in range(10): + model(Tensor([[1, 2, 3, 4]]), i).realize() + tms.append(time.perf_counter()) + timings = [(tms[i + 1] - tms[i]) * 1000 for i in range(len(tms) - 1)] + print( + f"{st:15s} mean runtime: {sum(timings)/len(timings):7.2f}ms, runs: ", + ", ".join(f"{x:7.2f}" for x in timings), + ) - run_llama("codegen") - run_llama("methodcache", False) + run_llama("codegen") + run_llama("methodcache", False) - with Profiling(sort='time', frac=0.1): - run_llama("profile") + with Profiling(sort="time", frac=0.1): + run_llama("profile") - Device[Device.DEFAULT].runtime = backup_program - Device[Device.DEFAULT].allocator = backup_allocator + Device[Device.DEFAULT].runtime = backup_program + Device[Device.DEFAULT].allocator = backup_allocator -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/external/external_test_uops_graphing.py b/test/external/external_test_uops_graphing.py index c6394803f..94df19869 100644 --- a/test/external/external_test_uops_graphing.py +++ b/test/external/external_test_uops_graphing.py @@ -6,39 +6,42 @@ from tinygrad.renderer.cstyle import OpenCLRenderer from tinygrad.graph import graph_uops from tinygrad.nn import Conv2d + class TestUopsGraph(unittest.TestCase): - def test_matmul(self): - N = 1024 - a = Tensor.rand(N,N) - b = Tensor.rand(N,N) - si = (a@b).lazydata.schedule()[-1] - lin = Linearizer(si.ast) - lin.hand_coded_optimizations() - print(lin.colored_shape()) - uops = lin.linearize().uops - graph_uops(uops) - for u in uops: print(u) - print(OpenCLRenderer("matmul", uops)[0]) + def test_matmul(self): + N = 1024 + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + si = (a @ b).lazydata.schedule()[-1] + lin = Linearizer(si.ast) + lin.hand_coded_optimizations() + print(lin.colored_shape()) + uops = lin.linearize().uops + graph_uops(uops) + for u in uops: + print(u) + print(OpenCLRenderer("matmul", uops)[0]) - def test_reduce(self): - a = Tensor.rand(1024*1024) - si = a.sum().lazydata.schedule()[-1] - lin = Linearizer(si.ast) - lin.hand_coded_optimizations() - uops = lin.linearize().uops - graph_uops(uops) - #print(OpenCLRenderer("reduce", uops)[0]) + def test_reduce(self): + a = Tensor.rand(1024 * 1024) + si = a.sum().lazydata.schedule()[-1] + lin = Linearizer(si.ast) + lin.hand_coded_optimizations() + uops = lin.linearize().uops + graph_uops(uops) + # print(OpenCLRenderer("reduce", uops)[0]) - def test_conv(self): - x = Tensor.rand(1,3,16,16) - c = Conv2d(3, 16, (3,3)) - si = c(x).elu().lazydata.schedule()[-1] - lin = Linearizer(si.ast) - lin.hand_coded_optimizations() - uops = lin.linearize().uops - graph_uops(uops) - print(lin.colored_shape()) - print(OpenCLRenderer("conv", uops)[0]) + def test_conv(self): + x = Tensor.rand(1, 3, 16, 16) + c = Conv2d(3, 16, (3, 3)) + si = c(x).elu().lazydata.schedule()[-1] + lin = Linearizer(si.ast) + lin.hand_coded_optimizations() + uops = lin.linearize().uops + graph_uops(uops) + print(lin.colored_shape()) + print(OpenCLRenderer("conv", uops)[0]) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/external/external_test_whisper_librispeech.py b/test/external/external_test_whisper_librispeech.py index 32c891ba6..d9be4e251 100644 --- a/test/external/external_test_whisper_librispeech.py +++ b/test/external/external_test_whisper_librispeech.py @@ -9,75 +9,84 @@ import numpy as np from whisper.normalizers import EnglishTextNormalizer from examples.whisper import init_whisper, transcribe_waveform + class TestWhisperLibriSpeech(unittest.TestCase): - # reference WERs determined by running https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb - # the values should be consistent with the paper D.1.1 https://cdn.openai.com/papers/whisper.pdf#page=22 - # tinygrad WERs do not perfectly match due to what seem to be precision differences vs torch - def test_en_tiny(self): - run_evaluation("tiny.en", 0.056629001883239174, 0.05655609406528749) + # reference WERs determined by running https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb + # the values should be consistent with the paper D.1.1 https://cdn.openai.com/papers/whisper.pdf#page=22 + # tinygrad WERs do not perfectly match due to what seem to be precision differences vs torch + def test_en_tiny(self): + run_evaluation("tiny.en", 0.056629001883239174, 0.05655609406528749) - def test_tiny(self): - run_evaluation("tiny", 0.0771121409407306, 0.07558413638335187) + def test_tiny(self): + run_evaluation("tiny", 0.0771121409407306, 0.07558413638335187) - def test_en_base(self): - run_evaluation("base.en", 0.041412520064205455, 0.04271408904897505) + def test_en_base(self): + run_evaluation("base.en", 0.041412520064205455, 0.04271408904897505) + + def test_en_small(self): + run_evaluation("small.en", 0.03369011117172363, 0.030531615969223228) - def test_en_small(self): - run_evaluation("small.en", 0.03369011117172363, 0.030531615969223228) def run_evaluation(model_name, tinygrad_expected_wer, reference_wer): - dataset = LibriSpeech() - batch_size=16 - loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size) + dataset = LibriSpeech() + batch_size = 16 + loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size) - model, enc = init_whisper(model_name, batch_size=batch_size) + model, enc = init_whisper(model_name, batch_size=batch_size) - hypotheses = [] - references = [] + hypotheses = [] + references = [] - for audio, texts in tqdm.tqdm(loader): - transcriptions = transcribe_waveform(model, enc, audio.numpy(), truncate=True) - hypotheses.extend(transcriptions) - references.extend(texts) + for audio, texts in tqdm.tqdm(loader): + transcriptions = transcribe_waveform(model, enc, audio.numpy(), truncate=True) + hypotheses.extend(transcriptions) + references.extend(texts) - normalizer = EnglishTextNormalizer() - normalized_hypotheses = [normalizer(text) for text in hypotheses] - normalized_references = [normalizer(text) for text in references] - wer = jiwer.wer(normalized_hypotheses, normalized_references) + normalizer = EnglishTextNormalizer() + normalized_hypotheses = [normalizer(text) for text in hypotheses] + normalized_references = [normalizer(text) for text in references] + wer = jiwer.wer(normalized_hypotheses, normalized_references) + + np.testing.assert_almost_equal(wer, tinygrad_expected_wer) + print(f"tinygrad WER {wer} vs reference WER {reference_wer}") + del model, enc - np.testing.assert_almost_equal(wer, tinygrad_expected_wer) - print(f'tinygrad WER {wer} vs reference WER {reference_wer}') - del model, enc class LibriSpeech(torch.utils.data.Dataset): - def __init__(self): - dir = pathlib.Path(__file__).parent.parent.parent / "extra" / "datasets" / "librispeech" - if not os.path.exists(dir): - os.makedirs(dir) + def __init__(self): + dir = ( + pathlib.Path(__file__).parent.parent.parent + / "extra" + / "datasets" + / "librispeech" + ) + if not os.path.exists(dir): + os.makedirs(dir) - self.dataset = torchaudio.datasets.LIBRISPEECH( - root=dir, - url="test-clean", - download=True, - ) + self.dataset = torchaudio.datasets.LIBRISPEECH( + root=dir, + url="test-clean", + download=True, + ) - def __len__(self): - return len(self.dataset) + def __len__(self): + return len(self.dataset) + + def __getitem__(self, item): + audio, sample_rate, text, _, _, _ = self.dataset[item] + assert sample_rate == 16000 + return pad_or_trim_tensor(audio[0]), text - def __getitem__(self, item): - audio, sample_rate, text, _, _, _ = self.dataset[item] - assert sample_rate == 16000 - return pad_or_trim_tensor(audio[0]), text def pad_or_trim_tensor(tensor, target_len=480000): - curr_len = len(tensor) - if curr_len == target_len: - return tensor - elif curr_len < target_len: - return torch.cat((tensor, torch.zeros(target_len - curr_len))) - else: - return tensor[:target_len] + curr_len = len(tensor) + if curr_len == target_len: + return tensor + elif curr_len < target_len: + return torch.cat((tensor, torch.zeros(target_len - curr_len))) + else: + return tensor[:target_len] -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/test/external/external_test_yolo.py b/test/external/external_test_yolo.py index f28f23aa5..4f047fcdb 100644 --- a/test/external/external_test_yolo.py +++ b/test/external/external_test_yolo.py @@ -6,27 +6,35 @@ import cv2 from examples.yolov3 import Darknet, infer, show_labels from tinygrad.helpers import fetch -chicken_img = cv2.imread(str(Path(__file__).parent.parent / 'models/efficientnet/Chicken.jpg')) -car_img = cv2.imread(str(Path(__file__).parent.parent / 'models/efficientnet/car.jpg')) +chicken_img = cv2.imread( + str(Path(__file__).parent.parent / "models/efficientnet/Chicken.jpg") +) +car_img = cv2.imread(str(Path(__file__).parent.parent / "models/efficientnet/car.jpg")) + class TestYOLO(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = Darknet(fetch("https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg").read_bytes()) - print("Loading weights file (237MB). This might take a while…") - cls.model.load_weights("https://pjreddie.com/media/files/yolov3.weights") + @classmethod + def setUpClass(cls): + cls.model = Darknet( + fetch( + "https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg" + ).read_bytes() + ) + print("Loading weights file (237MB). This might take a while…") + cls.model.load_weights("https://pjreddie.com/media/files/yolov3.weights") - @classmethod - def tearDownClass(cls): - del cls.model + @classmethod + def tearDownClass(cls): + del cls.model - def test_chicken(self): - labels = show_labels(infer(self.model, chicken_img), confidence=0.56) - self.assertEqual(labels, ["bird"]) + def test_chicken(self): + labels = show_labels(infer(self.model, chicken_img), confidence=0.56) + self.assertEqual(labels, ["bird"]) - def test_car(self): - labels = show_labels(infer(self.model, car_img)) - self.assertEqual(labels, ["car"]) + def test_car(self): + labels = show_labels(infer(self.model, car_img)) + self.assertEqual(labels, ["car"]) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/external/external_test_yolov8.py b/test/external/external_test_yolov8.py index c98a4266e..cb3844889 100644 --- a/test/external/external_test_yolov8.py +++ b/test/external/external_test_yolov8.py @@ -1,5 +1,11 @@ import numpy as np -from examples.yolov8 import YOLOv8, get_variant_multiples, preprocess, postprocess, label_predictions +from examples.yolov8 import ( + YOLOv8, + get_variant_multiples, + preprocess, + postprocess, + label_predictions, +) import unittest import io, cv2 import onnxruntime as ort @@ -7,63 +13,102 @@ import ultralytics from tinygrad.nn.state import safe_load, load_state_dict from tinygrad.helpers import fetch + class TestYOLOv8(unittest.TestCase): - def test_all_load_weights(self): - for variant in ['n', 's', 'm', 'l', 'x']: - depth, width, ratio = get_variant_multiples(variant) - TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) - state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors')) - load_state_dict(TinyYolov8, state_dict) - print(f'successfully loaded weights for yolov{variant}') + def test_all_load_weights(self): + for variant in ["n", "s", "m", "l", "x"]: + depth, width, ratio = get_variant_multiples(variant) + TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) + state_dict = safe_load( + fetch( + f"https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors" + ) + ) + load_state_dict(TinyYolov8, state_dict) + print(f"successfully loaded weights for yolov{variant}") - def test_predictions(self): - test_image_urls = ['https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg', 'https://www.aljazeera.com/wp-content/uploads/2022/10/2022-04-28T192650Z_1186456067_UP1EI4S1I0P14_RTRMADP_3_SOCCER-ENGLAND-MUN-CHE-REPORT.jpg'] - variant = 'n' - depth, width, ratio = get_variant_multiples(variant) - TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) - state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors')) - load_state_dict(TinyYolov8, state_dict) + def test_predictions(self): + test_image_urls = [ + "https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg", + "https://www.aljazeera.com/wp-content/uploads/2022/10/2022-04-28T192650Z_1186456067_UP1EI4S1I0P14_RTRMADP_3_SOCCER-ENGLAND-MUN-CHE-REPORT.jpg", + ] + variant = "n" + depth, width, ratio = get_variant_multiples(variant) + TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) + state_dict = safe_load( + fetch( + f"https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors" + ) + ) + load_state_dict(TinyYolov8, state_dict) - for i in range(len(test_image_urls)): - img = cv2.imdecode(np.frombuffer(fetch(test_image_urls[i]).read_bytes(), np.uint8), 1) - test_image = preprocess([img]) - predictions = TinyYolov8(test_image) - post_predictions = postprocess(preds=predictions, img=test_image, orig_imgs=[img]) - labels = label_predictions(post_predictions) - assert labels == {5: 1, 0: 4, 11: 1} if i == 0 else labels == {0: 13, 29: 1, 32: 1} + for i in range(len(test_image_urls)): + img = cv2.imdecode( + np.frombuffer(fetch(test_image_urls[i]).read_bytes(), np.uint8), 1 + ) + test_image = preprocess([img]) + predictions = TinyYolov8(test_image) + post_predictions = postprocess( + preds=predictions, img=test_image, orig_imgs=[img] + ) + labels = label_predictions(post_predictions) + assert ( + labels == {5: 1, 0: 4, 11: 1} + if i == 0 + else labels == {0: 13, 29: 1, 32: 1} + ) - def test_forward_pass_torch_onnx(self): - variant = 'n' - weights_location = fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors') - weights_location_pt = fetch(f'https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8{variant}.pt', name=f"yolov8{variant}.pt") # it needs the pt extension - weights_location_onnx = weights_location_pt.parent / f"yolov8{variant}.onnx" + def test_forward_pass_torch_onnx(self): + variant = "n" + weights_location = fetch( + f"https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors" + ) + weights_location_pt = fetch( + f"https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8{variant}.pt", + name=f"yolov8{variant}.pt", + ) # it needs the pt extension + weights_location_onnx = weights_location_pt.parent / f"yolov8{variant}.onnx" - # the ultralytics export prints a lot of unneccesary things - if not weights_location_onnx.is_file(): - model = ultralytics.YOLO(model=weights_location_pt, task='Detect') - model.export(format="onnx",imgsz=[640, 480]) + # the ultralytics export prints a lot of unneccesary things + if not weights_location_onnx.is_file(): + model = ultralytics.YOLO(model=weights_location_pt, task="Detect") + model.export(format="onnx", imgsz=[640, 480]) - depth, width, ratio = get_variant_multiples(variant) - TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) - state_dict = safe_load(weights_location) - load_state_dict(TinyYolov8, state_dict) + depth, width, ratio = get_variant_multiples(variant) + TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) + state_dict = safe_load(weights_location) + load_state_dict(TinyYolov8, state_dict) - image_location = [np.frombuffer(io.BytesIO(fetch('https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg').read_bytes()).read(), np.uint8)] - orig_image = [cv2.imdecode(image_location[0], 1)] + image_location = [ + np.frombuffer( + io.BytesIO( + fetch( + "https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg" + ).read_bytes() + ).read(), + np.uint8, + ) + ] + orig_image = [cv2.imdecode(image_location[0], 1)] - input_image = preprocess(orig_image) + input_image = preprocess(orig_image) - onnx_session = ort.InferenceSession(weights_location_onnx) - onnx_input_name = onnx_session.get_inputs()[0].name - onnx_output_name = onnx_session.get_outputs()[0].name - onnx_output = onnx_session.run([onnx_output_name], {onnx_input_name: input_image.numpy()}) + onnx_session = ort.InferenceSession(weights_location_onnx) + onnx_input_name = onnx_session.get_inputs()[0].name + onnx_output_name = onnx_session.get_outputs()[0].name + onnx_output = onnx_session.run( + [onnx_output_name], {onnx_input_name: input_image.numpy()} + ) - tiny_output = TinyYolov8(input_image) + tiny_output = TinyYolov8(input_image) - # currently rtol is 0.025 because there is a 1-2% difference in our predictions - # because of the zero padding in SPPF module (line 280) maxpooling layers rather than the -infinity in torch. - # This difference does not make a difference "visually". - np.testing.assert_allclose(onnx_output[0], tiny_output.numpy(), atol=5e-4, rtol=0.025) + # currently rtol is 0.025 because there is a 1-2% difference in our predictions + # because of the zero padding in SPPF module (line 280) maxpooling layers rather than the -infinity in torch. + # This difference does not make a difference "visually". + np.testing.assert_allclose( + onnx_output[0], tiny_output.numpy(), atol=5e-4, rtol=0.025 + ) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index ab53c1e3e..f7dcf957c 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -11,92 +11,101 @@ from tinygrad.lazy import vars_from_ast device = Device[Device.DEFAULT] + def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None): - if rawbufs is None: rawbufs = bufs_from_lin(lin) - if var_vals is None: var_vals = {v: v.min for v in vars_from_ast(lin.ast)} + if rawbufs is None: + rawbufs = bufs_from_lin(lin) + if var_vals is None: + var_vals = {v: v.min for v in vars_from_ast(lin.ast)} - # TODO: images needs required_optimization - try: - if isinstance(device, Compiled): - prg = device.to_program(lin) - else: - prg = device.get_runner(lin.ast) - except Exception: - print(lin.ast) - traceback.print_exc() - print("COMPILE FAILED!!") - return "COMPILE_ERROR" + # TODO: images needs required_optimization + try: + if isinstance(device, Compiled): + prg = device.to_program(lin) + else: + prg = device.get_runner(lin.ast) + except Exception: + print(lin.ast) + traceback.print_exc() + print("COMPILE FAILED!!") + return "COMPILE_ERROR" - try: - prg.exec(rawbufs, var_vals) - except Exception: - print(lin.ast) - traceback.print_exc() - print("EXEC FAILED!!") - return "EXEC_ERROR" + try: + prg.exec(rawbufs, var_vals) + except Exception: + print(lin.ast) + traceback.print_exc() + print("EXEC FAILED!!") + return "EXEC_ERROR" - return "PASS" + return "PASS" def fuzz_linearizer(lin: Linearizer): - random.seed(42) - np.random.seed(42) - print_tree(lin.ast) - print(lin.colored_shape()) - rawbufs = bufs_from_lin(lin) - - seen_uops = {} - ground_truth = None - while 1: - if len(seen_uops) >= 20: break # enough for this kernel - actions = get_linearizer_actions(lin, include_0=False) - if not actions: break - lin = random.choice(list(actions.values())) - if lin.applied_opts: print(f"applied action: {lin.applied_opts[-1]}") - - # stop if kernel uops repeat - tuops = tuplize_uops(lin.linearize().uops) - if tuops in seen_uops: break - seen_uops[tuops] = tuple(lin.applied_opts) - + random.seed(42) + np.random.seed(42) + print_tree(lin.ast) print(lin.colored_shape()) - # get a new output buffer - rawbufs[0] = type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype) - var_vals = {v: random.randint(v.min, v.max) for v in vars_from_ast(lin.ast)} - if (msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS": - print(f"{lin.applied_opts=}") - return msg + rawbufs = bufs_from_lin(lin) - result = rawbufs[0].toCPU() - if ground_truth is None: - ground_truth = result - else: - try: - np.testing.assert_allclose(result, ground_truth, rtol=1e-2, atol=1e-2) - except AssertionError: - print(lin.ast) - traceback.print_exc() - print(f"{lin.applied_opts=}") - return "NOT_ALLCLOSE" - return "PASS" + seen_uops = {} + ground_truth = None + while 1: + if len(seen_uops) >= 20: + break # enough for this kernel + actions = get_linearizer_actions(lin, include_0=False) + if not actions: + break + lin = random.choice(list(actions.values())) + if lin.applied_opts: + print(f"applied action: {lin.applied_opts[-1]}") + + # stop if kernel uops repeat + tuops = tuplize_uops(lin.linearize().uops) + if tuops in seen_uops: + break + seen_uops[tuops] = tuple(lin.applied_opts) + + print(lin.colored_shape()) + # get a new output buffer + rawbufs[0] = type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype) + var_vals = {v: random.randint(v.min, v.max) for v in vars_from_ast(lin.ast)} + if (msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS": + print(f"{lin.applied_opts=}") + return msg + + result = rawbufs[0].toCPU() + if ground_truth is None: + ground_truth = result + else: + try: + np.testing.assert_allclose(result, ground_truth, rtol=1e-2, atol=1e-2) + except AssertionError: + print(lin.ast) + traceback.print_exc() + print(f"{lin.applied_opts=}") + return "NOT_ALLCLOSE" + return "PASS" if __name__ == "__main__": - ast_strs = load_worlds() - print(f"{len(ast_strs)=}") - tested = 0 - c = Counter() - failed = [] - for i, ast in enumerate(ast_strs[:getenv("FUZZ_N", len(ast_strs))]): - if "Variable" in ast and isinstance(device, Interpreted): continue # no symbolic shape for Interpreted - if "dtypes.image" in ast and Device.DEFAULT != "GPU": continue # IMAGE is only for GPU - print(f"testing ast {i}") - tested += 1 - lin = ast_str_to_lin(ast) - fuzz = str(fuzz_linearizer(lin)) - c[fuzz] += 1 - if fuzz != "PASS": - failed.append(i) - print(f"{tested=}") - print(c.most_common()) - print(f"{failed=}") \ No newline at end of file + ast_strs = load_worlds() + print(f"{len(ast_strs)=}") + tested = 0 + c = Counter() + failed = [] + for i, ast in enumerate(ast_strs[: getenv("FUZZ_N", len(ast_strs))]): + if "Variable" in ast and isinstance(device, Interpreted): + continue # no symbolic shape for Interpreted + if "dtypes.image" in ast and Device.DEFAULT != "GPU": + continue # IMAGE is only for GPU + print(f"testing ast {i}") + tested += 1 + lin = ast_str_to_lin(ast) + fuzz = str(fuzz_linearizer(lin)) + c[fuzz] += 1 + if fuzz != "PASS": + failed.append(i) + print(f"{tested=}") + print(c.most_common()) + print(f"{failed=}") diff --git a/test/external/fuzz_shapetracker.py b/test/external/fuzz_shapetracker.py index 4e89dad8c..4fd4d9c6d 100644 --- a/test/external/fuzz_shapetracker.py +++ b/test/external/fuzz_shapetracker.py @@ -1,61 +1,101 @@ import random from tinygrad.helpers import DEBUG, getenv from test.unit.test_shapetracker import CheckingShapeTracker + random.seed(42) + def do_permute(st): - perm = list(range(0, len(st.shape))) - random.shuffle(perm) - perm = tuple(perm) - if DEBUG >= 1: print("st.permute(", perm, ")") - st.permute(perm) + perm = list(range(0, len(st.shape))) + random.shuffle(perm) + perm = tuple(perm) + if DEBUG >= 1: + print("st.permute(", perm, ")") + st.permute(perm) + def do_pad(st): - c = random.randint(0, len(st.shape)-1) - pad = tuple((random.randint(0,2), random.randint(0,2)) if i==c else (0,0) for i in range(len(st.shape))) - if DEBUG >= 1: print("st.pad(", pad, ")") - st.pad(pad) + c = random.randint(0, len(st.shape) - 1) + pad = tuple( + (random.randint(0, 2), random.randint(0, 2)) if i == c else (0, 0) + for i in range(len(st.shape)) + ) + if DEBUG >= 1: + print("st.pad(", pad, ")") + st.pad(pad) + def do_reshape_split_one(st): - c = random.randint(0, len(st.shape)-1) - poss = [n for n in [1,2,3,4,5] if st.shape[c]%n == 0] - spl = random.choice(poss) - shp = st.shape[0:c] + (st.shape[c]//spl, spl) + st.shape[c+1:] - if DEBUG >= 1: print("st.reshape(", shp, ")") - st.reshape(shp) + c = random.randint(0, len(st.shape) - 1) + poss = [n for n in [1, 2, 3, 4, 5] if st.shape[c] % n == 0] + spl = random.choice(poss) + shp = st.shape[0:c] + (st.shape[c] // spl, spl) + st.shape[c + 1 :] + if DEBUG >= 1: + print("st.reshape(", shp, ")") + st.reshape(shp) + def do_reshape_combine_two(st): - if len(st.shape) < 2: return - c = random.randint(0, len(st.shape)-2) - shp = st.shape[:c] + (st.shape[c] * st.shape[c+1], ) + st.shape[c+2:] - if DEBUG >= 1: print("st.reshape(", shp, ")") - st.reshape(shp) + if len(st.shape) < 2: + return + c = random.randint(0, len(st.shape) - 2) + shp = st.shape[:c] + (st.shape[c] * st.shape[c + 1],) + st.shape[c + 2 :] + if DEBUG >= 1: + print("st.reshape(", shp, ")") + st.reshape(shp) + def do_shrink(st): - c = random.randint(0, len(st.shape)-1) - while 1: - shrink = tuple((random.randint(0,s), random.randint(0,s)) if i == c else (0,s) for i,s in enumerate(st.shape)) - if all(x= 1: print("st.shrink(", shrink, ")") - st.shrink(shrink) + c = random.randint(0, len(st.shape) - 1) + while 1: + shrink = tuple( + (random.randint(0, s), random.randint(0, s)) if i == c else (0, s) + for i, s in enumerate(st.shape) + ) + if all(x < y for (x, y) in shrink): + break + if DEBUG >= 1: + print("st.shrink(", shrink, ")") + st.shrink(shrink) + def do_stride(st): - c = random.randint(0, len(st.shape)-1) - stride = tuple(random.choice([-2,-1,2]) if i==c else 1 for i in range(len(st.shape))) - if DEBUG >= 1: print("st.stride(", stride, ")") - st.stride(stride) + c = random.randint(0, len(st.shape) - 1) + stride = tuple( + random.choice([-2, -1, 2]) if i == c else 1 for i in range(len(st.shape)) + ) + if DEBUG >= 1: + print("st.stride(", stride, ")") + st.stride(stride) + def do_expand(st): - c = [i for i,s in enumerate(st.shape) if s==1] - if len(c) == 0: return - c = random.choice(c) - expand = tuple(random.choice([2,3,4]) if i==c else s for i,s in enumerate(st.shape)) - if DEBUG >= 1: print("st.expand(", expand, ")") - st.expand(expand) + c = [i for i, s in enumerate(st.shape) if s == 1] + if len(c) == 0: + return + c = random.choice(c) + expand = tuple( + random.choice([2, 3, 4]) if i == c else s for i, s in enumerate(st.shape) + ) + if DEBUG >= 1: + print("st.expand(", expand, ")") + st.expand(expand) + if __name__ == "__main__": - ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_stride, do_expand] - for _ in range(getenv("CNT", 200)): - st = CheckingShapeTracker((random.randint(2, 10), random.randint(2, 10), random.randint(2, 10))) - for i in range(8): random.choice(ops)(st) - st.assert_same() + ops = [ + do_permute, + do_pad, + do_shrink, + do_reshape_split_one, + do_reshape_combine_two, + do_stride, + do_expand, + ] + for _ in range(getenv("CNT", 200)): + st = CheckingShapeTracker( + (random.randint(2, 10), random.randint(2, 10), random.randint(2, 10)) + ) + for i in range(8): + random.choice(ops)(st) + st.assert_same() diff --git a/test/external/fuzz_symbolic.py b/test/external/fuzz_symbolic.py index d17191413..612600270 100644 --- a/test/external/fuzz_symbolic.py +++ b/test/external/fuzz_symbolic.py @@ -2,68 +2,101 @@ import itertools import random from tinygrad.helpers import DEBUG from tinygrad.shape.symbolic import Variable, NumNode + random.seed(42) + def add_v(expr, rng=None): - if rng is None: rng = random.randint(0,2) - return expr + v[rng], rng + if rng is None: + rng = random.randint(0, 2) + return expr + v[rng], rng + def div(expr, rng=None): - if rng is None: rng = random.randint(1,9) - return expr // rng, rng + if rng is None: + rng = random.randint(1, 9) + return expr // rng, rng + def mul(expr, rng=None): - if rng is None: rng = random.randint(-4,4) - return expr * rng, rng + if rng is None: + rng = random.randint(-4, 4) + return expr * rng, rng + def mod(expr, rng=None): - if rng is None: rng = random.randint(1,9) - return expr % rng, rng + if rng is None: + rng = random.randint(1, 9) + return expr % rng, rng + def add_num(expr, rng=None): - if rng is None: rng = random.randint(-4,4) - return expr + rng, rng + if rng is None: + rng = random.randint(-4, 4) + return expr + rng, rng + def lt(expr, rng=None): - if rng is None: rng = random.randint(-4,4) - return expr < rng, rng + if rng is None: + rng = random.randint(-4, 4) + return expr < rng, rng + def ge(expr, rng=None): - if rng is None: rng = random.randint(-4,4) - return expr >= rng, rng + if rng is None: + rng = random.randint(-4, 4) + return expr >= rng, rng + def le(expr, rng=None): - if rng is None: rng = random.randint(-4,4) - return expr <= rng, rng + if rng is None: + rng = random.randint(-4, 4) + return expr <= rng, rng + def gt(expr, rng=None): - if rng is None: rng = random.randint(-4,4) - return expr > rng, rng + if rng is None: + rng = random.randint(-4, 4) + return expr > rng, rng + if __name__ == "__main__": - ops = [add_v, div, mul, add_num, mod] - for _ in range(1000): - upper_bounds = [*list(range(1, 10)), 16, 32, 64, 128, 256] - u1 = Variable("v1", 0, random.choice(upper_bounds)) - u2 = Variable("v2", 0, random.choice(upper_bounds)) - u3 = Variable("v3", 0, random.choice(upper_bounds)) - v = [u1,u2,u3] - tape = [random.choice(ops) for _ in range(random.randint(2, 30))] - # 10% of the time, add one of lt, le, gt, ge - if random.random() < 0.1: tape.append(random.choice([lt, le, gt, ge])) - expr = NumNode(0) - rngs = [] - for t in tape: - expr, rng = t(expr) - if DEBUG >= 1: print(t.__name__, rng) - rngs.append(rng) - if DEBUG >=1: print(expr) - space = list(itertools.product(range(u1.min, u1.max+1), range(u2.min, u2.max+1), range(u3.min, u3.max+1))) - volume = len(space) - for (v1, v2, v3) in random.sample(space, min(100, volume)): - v = [v1,v2,v3] - rn = 0 - for t,r in zip(tape, rngs): rn, _ = t(rn, r) - num = eval(expr.render()) - assert num == rn, f"mismatched {expr.render()} at {v1=} {v2=} {v3=} = {num} != {rn}" - if DEBUG >= 1: print(f"matched {expr.render()} at {v1=} {v2=} {v3=} = {num} == {rn}") + ops = [add_v, div, mul, add_num, mod] + for _ in range(1000): + upper_bounds = [*list(range(1, 10)), 16, 32, 64, 128, 256] + u1 = Variable("v1", 0, random.choice(upper_bounds)) + u2 = Variable("v2", 0, random.choice(upper_bounds)) + u3 = Variable("v3", 0, random.choice(upper_bounds)) + v = [u1, u2, u3] + tape = [random.choice(ops) for _ in range(random.randint(2, 30))] + # 10% of the time, add one of lt, le, gt, ge + if random.random() < 0.1: + tape.append(random.choice([lt, le, gt, ge])) + expr = NumNode(0) + rngs = [] + for t in tape: + expr, rng = t(expr) + if DEBUG >= 1: + print(t.__name__, rng) + rngs.append(rng) + if DEBUG >= 1: + print(expr) + space = list( + itertools.product( + range(u1.min, u1.max + 1), + range(u2.min, u2.max + 1), + range(u3.min, u3.max + 1), + ) + ) + volume = len(space) + for v1, v2, v3 in random.sample(space, min(100, volume)): + v = [v1, v2, v3] + rn = 0 + for t, r in zip(tape, rngs): + rn, _ = t(rn, r) + num = eval(expr.render()) + assert ( + num == rn + ), f"mismatched {expr.render()} at {v1=} {v2=} {v3=} = {num} != {rn}" + if DEBUG >= 1: + print(f"matched {expr.render()} at {v1=} {v2=} {v3=} = {num} == {rn}") diff --git a/test/external/graph_batchnorm.py b/test/external/graph_batchnorm.py index 59e3b7961..7d1bfcdbb 100644 --- a/test/external/graph_batchnorm.py +++ b/test/external/graph_batchnorm.py @@ -3,59 +3,69 @@ from tinygrad.nn.state import get_parameters from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d, BatchNorm2d, optim + def model_step(lm): - with Tensor.train(): - x = Tensor.ones(8,12,128,256, requires_grad=False) - optimizer = optim.SGD(get_parameters(lm), lr=0.001) - loss = lm.forward(x).sum() - optimizer.zero_grad() - loss.backward() - del x,loss - optimizer.step() + with Tensor.train(): + x = Tensor.ones(8, 12, 128, 256, requires_grad=False) + optimizer = optim.SGD(get_parameters(lm), lr=0.001) + loss = lm.forward(x).sum() + optimizer.zero_grad() + loss.backward() + del x, loss + optimizer.step() + class TestBatchnorm(unittest.TestCase): - def test_conv(self): - class LilModel: - def __init__(self): - self.c = Conv2d(12, 32, 3, padding=1, bias=False) - def forward(self, x): - return self.c(x).relu() - lm = LilModel() - model_step(lm) + def test_conv(self): + class LilModel: + def __init__(self): + self.c = Conv2d(12, 32, 3, padding=1, bias=False) - def test_two_conv(self): - class LilModel: - def __init__(self): - self.c = Conv2d(12, 32, 3, padding=1, bias=False) - self.c2 = Conv2d(32, 32, 3, padding=1, bias=False) - def forward(self, x): - return self.c2(self.c(x)).relu() - lm = LilModel() - model_step(lm) + def forward(self, x): + return self.c(x).relu() - def test_two_conv_bn(self): - class LilModel: - def __init__(self): - self.c = Conv2d(12, 24, 3, padding=1, bias=False) - self.bn = BatchNorm2d(24, track_running_stats=False) - self.c2 = Conv2d(24, 32, 3, padding=1, bias=False) - self.bn2 = BatchNorm2d(32, track_running_stats=False) - def forward(self, x): - x = self.bn(self.c(x)).relu() - return self.bn2(self.c2(x)).relu() - lm = LilModel() - model_step(lm) + lm = LilModel() + model_step(lm) - def test_conv_bn(self): - class LilModel: - def __init__(self): - self.c = Conv2d(12, 32, 3, padding=1, bias=False) - self.bn = BatchNorm2d(32, track_running_stats=False) - def forward(self, x): - return self.bn(self.c(x)).relu() - lm = LilModel() - model_step(lm) + def test_two_conv(self): + class LilModel: + def __init__(self): + self.c = Conv2d(12, 32, 3, padding=1, bias=False) + self.c2 = Conv2d(32, 32, 3, padding=1, bias=False) + + def forward(self, x): + return self.c2(self.c(x)).relu() + + lm = LilModel() + model_step(lm) + + def test_two_conv_bn(self): + class LilModel: + def __init__(self): + self.c = Conv2d(12, 24, 3, padding=1, bias=False) + self.bn = BatchNorm2d(24, track_running_stats=False) + self.c2 = Conv2d(24, 32, 3, padding=1, bias=False) + self.bn2 = BatchNorm2d(32, track_running_stats=False) + + def forward(self, x): + x = self.bn(self.c(x)).relu() + return self.bn2(self.c2(x)).relu() + + lm = LilModel() + model_step(lm) + + def test_conv_bn(self): + class LilModel: + def __init__(self): + self.c = Conv2d(12, 32, 3, padding=1, bias=False) + self.bn = BatchNorm2d(32, track_running_stats=False) + + def forward(self, x): + return self.bn(self.c(x)).relu() + + lm = LilModel() + model_step(lm) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/test/external/test_example.py b/test/external/test_example.py index f11fb74a9..418c51679 100644 --- a/test/external/test_example.py +++ b/test/external/test_example.py @@ -3,71 +3,80 @@ from tinygrad import Device from tinygrad.tensor import Tensor from tinygrad.helpers import getenv, CI + def multidevice_test(fxn): - exclude_devices = getenv("EXCLUDE_DEVICES", "").split(",") - def ret(self): - for device in Device._buffers: - if device in ["DISK", "FAKE"]: continue - if not CI: print(device) - if device in exclude_devices: - if not CI: print(f"WARNING: {device} test is excluded") - continue - with self.subTest(device=device): - try: - Device[device] - except Exception: - if not CI: print(f"WARNING: {device} test isn't running") - continue - fxn(self, device) - return ret + exclude_devices = getenv("EXCLUDE_DEVICES", "").split(",") + + def ret(self): + for device in Device._buffers: + if device in ["DISK", "FAKE"]: + continue + if not CI: + print(device) + if device in exclude_devices: + if not CI: + print(f"WARNING: {device} test is excluded") + continue + with self.subTest(device=device): + try: + Device[device] + except Exception: + if not CI: + print(f"WARNING: {device} test isn't running") + continue + fxn(self, device) + + return ret + class TestExample(unittest.TestCase): - @multidevice_test - def test_convert_to_cpu(self, device): - a = Tensor([[1,2],[3,4]], device=device) - assert a.numpy().shape == (2,2) - b = a.cpu() - assert b.numpy().shape == (2,2) + @multidevice_test + def test_convert_to_cpu(self, device): + a = Tensor([[1, 2], [3, 4]], device=device) + assert a.numpy().shape == (2, 2) + b = a.cpu() + assert b.numpy().shape == (2, 2) - @multidevice_test - def test_2_plus_3(self, device): - a = Tensor([2], device=device) - b = Tensor([3], device=device) - result = a + b - print(f"{a.numpy()} + {b.numpy()} = {result.numpy()}") - assert result.numpy()[0] == 5. + @multidevice_test + def test_2_plus_3(self, device): + a = Tensor([2], device=device) + b = Tensor([3], device=device) + result = a + b + print(f"{a.numpy()} + {b.numpy()} = {result.numpy()}") + assert result.numpy()[0] == 5.0 - @multidevice_test - def test_example_readme(self, device): - x = Tensor.eye(3, device=device, requires_grad=True) - y = Tensor([[2.0,0,-2.0]], device=device, requires_grad=True) - z = y.matmul(x).sum() - z.backward() + @multidevice_test + def test_example_readme(self, device): + x = Tensor.eye(3, device=device, requires_grad=True) + y = Tensor([[2.0, 0, -2.0]], device=device, requires_grad=True) + z = y.matmul(x).sum() + z.backward() - x.grad.numpy() # dz/dx - y.grad.numpy() # dz/dy + x.grad.numpy() # dz/dx + y.grad.numpy() # dz/dy - assert x.grad.device == device - assert y.grad.device == device + assert x.grad.device == device + assert y.grad.device == device - @multidevice_test - def test_example_matmul(self, device): - try: - Device[device] - except Exception: - print(f"WARNING: {device} test isn't running") - return + @multidevice_test + def test_example_matmul(self, device): + try: + Device[device] + except Exception: + print(f"WARNING: {device} test isn't running") + return - x = Tensor.eye(64, device=device, requires_grad=True) - y = Tensor.eye(64, device=device, requires_grad=True) - z = y.matmul(x).sum() - z.backward() + x = Tensor.eye(64, device=device, requires_grad=True) + y = Tensor.eye(64, device=device, requires_grad=True) + z = y.matmul(x).sum() + z.backward() - x.grad.numpy() # dz/dx - y.grad.numpy() # dz/dy + x.grad.numpy() # dz/dx + y.grad.numpy() # dz/dy - assert x.grad.device == device - assert y.grad.device == device + assert x.grad.device == device + assert y.grad.device == device -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/extra/test_export_model.py b/test/extra/test_export_model.py index 675e46d09..2fd9a016d 100644 --- a/test/extra/test_export_model.py +++ b/test/extra/test_export_model.py @@ -3,48 +3,66 @@ from extra.export_model import export_model, EXPORT_SUPPORTED_DEVICE from tinygrad.tensor import Tensor, Device import json + class MockMultiInputModel: - def forward(self, x1, x2, x3): - return x1 + x2 + x3 + def forward(self, x1, x2, x3): + return x1 + x2 + x3 + class MockMultiOutputModel: - def __call__(self, x1): - return x1 + 2.0, x1.pad(((0, 0), (0, 1))) + 1.0 + def __call__(self, x1): + return x1 + 2.0, x1.pad(((0, 0), (0, 1))) + 1.0 + # TODO: move compile_efficientnet tests here -@unittest.skipUnless(Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, f"Model export is not supported on {Device.DEFAULT}") +@unittest.skipUnless( + Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, + f"Model export is not supported on {Device.DEFAULT}", +) class TextModelExport(unittest.TestCase): - def test_multi_input_model_export(self): - model = MockMultiInputModel() - inputs = [Tensor.rand(2,2), Tensor.rand(2,2), Tensor.rand(2,2)] - prg, inp_sizes, _, _ = export_model(model, "", *inputs) - prg = json.loads(prg) + def test_multi_input_model_export(self): + model = MockMultiInputModel() + inputs = [Tensor.rand(2, 2), Tensor.rand(2, 2), Tensor.rand(2, 2)] + prg, inp_sizes, _, _ = export_model(model, "", *inputs) + prg = json.loads(prg) - assert len(inputs) == len(prg["inputs"]) == len(inp_sizes), f"Model and exported inputs don't match: mdl={len(inputs)}, prg={len(prg['inputs'])}, inp_sizes={len(inp_sizes)}" + assert ( + len(inputs) == len(prg["inputs"]) == len(inp_sizes) + ), f"Model and exported inputs don't match: mdl={len(inputs)}, prg={len(prg['inputs'])}, inp_sizes={len(inp_sizes)}" - for i in range(len(inputs)): - assert f"input{i}" in inp_sizes, f"input{i} not captured in inp_sizes" - assert f"input{i}" in prg["buffers"], f"input{i} not captured in exported buffers" + for i in range(len(inputs)): + assert f"input{i}" in inp_sizes, f"input{i} not captured in inp_sizes" + assert ( + f"input{i}" in prg["buffers"] + ), f"input{i} not captured in exported buffers" - for i, exported_input in enumerate(prg["inputs"]): - assert inputs[i].dtype.name == exported_input["dtype"], f"Model and exported input dtype don't match: mdl={inputs[i].dtype.name}, prg={exported_input['dtype']}" + for i, exported_input in enumerate(prg["inputs"]): + assert ( + inputs[i].dtype.name == exported_input["dtype"] + ), f"Model and exported input dtype don't match: mdl={inputs[i].dtype.name}, prg={exported_input['dtype']}" - def test_multi_output_model_export(self): - model = MockMultiOutputModel() - input = Tensor.rand(2,2) - outputs = model(input) - prg, _, out_sizes, _ = export_model(model, "", input) - prg = json.loads(prg) + def test_multi_output_model_export(self): + model = MockMultiOutputModel() + input = Tensor.rand(2, 2) + outputs = model(input) + prg, _, out_sizes, _ = export_model(model, "", input) + prg = json.loads(prg) - assert len(outputs) == len(prg["outputs"]) == len(out_sizes), f"Model and exported outputs don't match: mdl={len(outputs)}, prg={len(prg['outputs'])}, inp_sizes={len(out_sizes)}" + assert ( + len(outputs) == len(prg["outputs"]) == len(out_sizes) + ), f"Model and exported outputs don't match: mdl={len(outputs)}, prg={len(prg['outputs'])}, inp_sizes={len(out_sizes)}" - for i in range(len(outputs)): - assert f"output{i}" in out_sizes, f"output{i} not captured in out_sizes" - assert f"output{i}" in prg["buffers"], f"output{i} not captured in exported buffers" + for i in range(len(outputs)): + assert f"output{i}" in out_sizes, f"output{i} not captured in out_sizes" + assert ( + f"output{i}" in prg["buffers"] + ), f"output{i} not captured in exported buffers" - for i, exported_output in enumerate(prg["outputs"]): - assert outputs[i].dtype.name == exported_output["dtype"], f"Model and exported output dtype don't match: mdl={outputs[i].dtype.name}, prg={exported_output['dtype']}" + for i, exported_output in enumerate(prg["outputs"]): + assert ( + outputs[i].dtype.name == exported_output["dtype"] + ), f"Model and exported output dtype don't match: mdl={outputs[i].dtype.name}, prg={exported_output['dtype']}" -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/test/extra/test_extra_helpers.py b/test/extra/test_extra_helpers.py index 6832b973c..9fc69f84e 100644 --- a/test/extra/test_extra_helpers.py +++ b/test/extra/test_extra_helpers.py @@ -2,56 +2,74 @@ import os, cloudpickle, tempfile, unittest, subprocess from extra.helpers import enable_early_exec, cross_process, _CloudpickleFunctionWrapper -def normalize_line_endings(s): return s.replace(b'\r\n', b'\n') + +def normalize_line_endings(s): + return s.replace(b"\r\n", b"\n") + class TestEarlyExec(unittest.TestCase): - def setUp(self) -> None: - self.early_exec = enable_early_exec() + def setUp(self) -> None: + self.early_exec = enable_early_exec() - def early_exec_py_file(self, file_content, exec_args): - with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp: - temp.write(file_content) - temp_path = temp.name - try: - output = self.early_exec((["python3", temp_path] + exec_args, None)) - return output - finally: - os.remove(temp_path) + def early_exec_py_file(self, file_content, exec_args): + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp: + temp.write(file_content) + temp_path = temp.name + try: + output = self.early_exec((["python3", temp_path] + exec_args, None)) + return output + finally: + os.remove(temp_path) - def test_enable_early_exec(self): - output = self.early_exec_py_file(b'print("Hello, world!")', []) - self.assertEqual(b"Hello, world!\n", normalize_line_endings(output)) + def test_enable_early_exec(self): + output = self.early_exec_py_file(b'print("Hello, world!")', []) + self.assertEqual(b"Hello, world!\n", normalize_line_endings(output)) - def test_enable_early_exec_with_arg(self): - output = self.early_exec_py_file(b'import sys\nprint("Hello, " + sys.argv[1] + "!")', ["world"]) - self.assertEqual(b"Hello, world!\n", normalize_line_endings(output)) + def test_enable_early_exec_with_arg(self): + output = self.early_exec_py_file( + b'import sys\nprint("Hello, " + sys.argv[1] + "!")', ["world"] + ) + self.assertEqual(b"Hello, world!\n", normalize_line_endings(output)) - def test_enable_early_exec_process_exception(self): - with self.assertRaises(subprocess.CalledProcessError): - self.early_exec_py_file(b'raise Exception("Test exception")', []) + def test_enable_early_exec_process_exception(self): + with self.assertRaises(subprocess.CalledProcessError): + self.early_exec_py_file(b'raise Exception("Test exception")', []) + + def test_enable_early_exec_type_exception(self): + with self.assertRaises(TypeError): + self.early_exec((["python3"], "print('Hello, world!')")) - def test_enable_early_exec_type_exception(self): - with self.assertRaises(TypeError): - self.early_exec((["python3"], "print('Hello, world!')")) class TestCrossProcess(unittest.TestCase): + def test_cross_process(self): + def _iterate(): + for i in range(10): + yield i - def test_cross_process(self): - def _iterate(): - for i in range(10): yield i - results = list(cross_process(_iterate)) - self.assertEqual(list(range(10)), results) + results = list(cross_process(_iterate)) + self.assertEqual(list(range(10)), results) - def test_cross_process_exception(self): - def _iterate(): - for i in range(10): - if i == 5: raise ValueError("Test exception") - yield i - with self.assertRaises(ValueError): list(cross_process(_iterate)) + def test_cross_process_exception(self): + def _iterate(): + for i in range(10): + if i == 5: + raise ValueError("Test exception") + yield i - def test_CloudpickleFunctionWrapper(self): - def add(x, y): return x + y - self.assertEqual(7, cloudpickle.loads(cloudpickle.dumps(_CloudpickleFunctionWrapper(add)))(3, 4)) + with self.assertRaises(ValueError): + list(cross_process(_iterate)) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + def test_CloudpickleFunctionWrapper(self): + def add(x, y): + return x + y + + self.assertEqual( + 7, + cloudpickle.loads(cloudpickle.dumps(_CloudpickleFunctionWrapper(add)))( + 3, 4 + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/extra/test_lr_scheduler.py b/test/extra/test_lr_scheduler.py index 3bff36600..65f397315 100644 --- a/test/extra/test_lr_scheduler.py +++ b/test/extra/test_lr_scheduler.py @@ -4,7 +4,12 @@ import unittest from tinygrad.tensor import Tensor from tinygrad.nn.state import get_parameters from tinygrad.nn.optim import Adam -from extra.lr_scheduler import MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR +from extra.lr_scheduler import ( + MultiStepLR, + ReduceLROnPlateau, + CosineAnnealingLR, + OneCycleLR, +) from extra.training import train, evaluate from extra.datasets import fetch_mnist import pytest @@ -16,94 +21,175 @@ Tensor.manual_seed(1337) X_train, Y_train, X_test, Y_test = fetch_mnist() + class TinyBobNet: - def __init__(self): - self.l1 = Tensor.scaled_uniform(784, 128) - self.l2 = Tensor.scaled_uniform(128, 10) + def __init__(self): + self.l1 = Tensor.scaled_uniform(784, 128) + self.l2 = Tensor.scaled_uniform(128, 10) - def parameters(self): - return get_parameters(self) + def parameters(self): + return get_parameters(self) + + def forward(self, x): + return x.dot(self.l1).relu().dot(self.l2).log_softmax() - def forward(self, x): - return x.dot(self.l1).relu().dot(self.l2).log_softmax() def lr_scheduler_training(sched_fn=None, args=None): - model = TinyBobNet() - optim = Adam(model.parameters(), lr=0.01) - if sched_fn is not None: sched = sched_fn(optim, **args) - for _ in range(25): - train(model, X_train, Y_train, optim, 100) + model = TinyBobNet() + optim = Adam(model.parameters(), lr=0.01) if sched_fn is not None: - if isinstance(sched, ReduceLROnPlateau): - sched.step(evaluate(model, X_test, Y_test)) - else: - sched.step() - return evaluate(model, X_test, Y_test) + sched = sched_fn(optim, **args) + for _ in range(25): + train(model, X_train, Y_train, optim, 100) + if sched_fn is not None: + if isinstance(sched, ReduceLROnPlateau): + sched.step(evaluate(model, X_test, Y_test)) + else: + sched.step() + return evaluate(model, X_test, Y_test) + + +def current_lr(optim): + return optim.param_groups[0]["lr"] if hasattr(optim, "param_groups") else optim.lr + -def current_lr(optim): return optim.param_groups[0]['lr'] if hasattr(optim, 'param_groups') else optim.lr def get_lrs(optim, sched, epochs, steps=1, accs=None): - lr = current_lr(optim) - if not isinstance(lr, float): lr = lr.numpy()[0] - lrs = [lr] - for e in range(epochs): - for _ in range(steps): - optim.step() - sched.step() if accs is None else sched.step(accs[e]) lr = current_lr(optim) - if not isinstance(lr, float): lr = lr.numpy()[0] - lrs.append(lr) - return lrs + if not isinstance(lr, float): + lr = lr.numpy()[0] + lrs = [lr] + for e in range(epochs): + for _ in range(steps): + optim.step() + sched.step() if accs is None else sched.step(accs[e]) + lr = current_lr(optim) + if not isinstance(lr, float): + lr = lr.numpy()[0] + lrs.append(lr) + return lrs + class TestLrScheduler(unittest.TestCase): - def _test_lr_scheduler(self, tinygrad_sched, torch_sched, epochs, opts, atol, rtol): - accs = opts.pop('accs', None) - test_tensor = Tensor([0], requires_grad=True) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr] - test_tensor.mean().backward() - tinygrad_optim, torch_optim = Adam([test_tensor], lr=0.01), torch.optim.Adam([torch.tensor([0.], requires_grad=True)], lr=0.01) - tinygrad_sched, torch_sched = tinygrad_sched(tinygrad_optim, **opts), torch_sched(torch_optim, **opts) + def _test_lr_scheduler(self, tinygrad_sched, torch_sched, epochs, opts, atol, rtol): + accs = opts.pop("accs", None) + test_tensor = Tensor( + [0], requires_grad=True + ) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr] + test_tensor.mean().backward() + tinygrad_optim, torch_optim = Adam([test_tensor], lr=0.01), torch.optim.Adam( + [torch.tensor([0.0], requires_grad=True)], lr=0.01 + ) + tinygrad_sched, torch_sched = tinygrad_sched( + tinygrad_optim, **opts + ), torch_sched(torch_optim, **opts) - tinygrad_lrs = get_lrs(tinygrad_optim, tinygrad_sched, epochs, accs=accs) - torch_lrs = get_lrs(torch_optim, torch_sched, epochs, accs=accs) + tinygrad_lrs = get_lrs(tinygrad_optim, tinygrad_sched, epochs, accs=accs) + torch_lrs = get_lrs(torch_optim, torch_sched, epochs, accs=accs) - np.testing.assert_allclose(tinygrad_lrs, torch_lrs, atol=atol, rtol=rtol) + np.testing.assert_allclose(tinygrad_lrs, torch_lrs, atol=atol, rtol=rtol) - def _test_multisteplr(self, epochs, opts, atol, rtol): - self._test_lr_scheduler(MultiStepLR, torch.optim.lr_scheduler.MultiStepLR, epochs, opts, atol, rtol) - def _test_reducelronplateau(self, epochs, opts, atol, rtol): - opts['accs'] = np.random.randn(epochs) - self._test_lr_scheduler(ReduceLROnPlateau, torch.optim.lr_scheduler.ReduceLROnPlateau, epochs, opts, atol, rtol) - def _test_cosineannealinglr(self, epochs, opts, atol, rtol): - opts['T_max'] = epochs - self._test_lr_scheduler(CosineAnnealingLR, torch.optim.lr_scheduler.CosineAnnealingLR, epochs, opts, atol, rtol) - def _test_onecyclelr(self, epochs, opts, atol, rtol): - opts['total_steps'] = epochs - self._test_lr_scheduler(OneCycleLR, torch.optim.lr_scheduler.OneCycleLR, epochs, opts, atol, rtol) + def _test_multisteplr(self, epochs, opts, atol, rtol): + self._test_lr_scheduler( + MultiStepLR, torch.optim.lr_scheduler.MultiStepLR, epochs, opts, atol, rtol + ) - def test_multisteplr(self): self._test_multisteplr(10, {'milestones': [1, 2, 7]}, 1e-6, 1e-6) - def test_multisteplr_gamma(self): self._test_multisteplr(10, {'milestones': [1, 2, 7], 'gamma': 0.1337}, 1e-6, 1e-6) + def _test_reducelronplateau(self, epochs, opts, atol, rtol): + opts["accs"] = np.random.randn(epochs) + self._test_lr_scheduler( + ReduceLROnPlateau, + torch.optim.lr_scheduler.ReduceLROnPlateau, + epochs, + opts, + atol, + rtol, + ) - def test_reducelronplateau(self): self._test_reducelronplateau(100, {}, 1e-6, 1e-6) - def test_reducelronplateau_max(self): self._test_reducelronplateau(100, {'mode': 'max'}, 1e-6, 1e-6) - def test_reducelronplateau_factor(self): self._test_reducelronplateau(100, {'factor': 0.1337}, 1e-6, 1e-6) - def test_reducelronplateau_patience(self): self._test_reducelronplateau(100, {'patience': 3}, 1e-6, 1e-6) - def test_reducelronplateau_threshold(self): self._test_reducelronplateau(100, {'threshold': 1e-6}, 1e-6, 1e-6) - def test_reducelronplateau_threshold_mode(self): self._test_reducelronplateau(100, {'threshold_mode': 'abs'}, 1e-6, 1e-6) + def _test_cosineannealinglr(self, epochs, opts, atol, rtol): + opts["T_max"] = epochs + self._test_lr_scheduler( + CosineAnnealingLR, + torch.optim.lr_scheduler.CosineAnnealingLR, + epochs, + opts, + atol, + rtol, + ) - def test_cosineannealinglr(self): self._test_cosineannealinglr(100, {}, 1e-6, 1e-6) - def test_cosineannealinglr_eta_min(self): self._test_cosineannealinglr(100, {'eta_min': 0.001}, 1e-6, 1e-6) + def _test_onecyclelr(self, epochs, opts, atol, rtol): + opts["total_steps"] = epochs + self._test_lr_scheduler( + OneCycleLR, torch.optim.lr_scheduler.OneCycleLR, epochs, opts, atol, rtol + ) - def test_onecyclelr(self): self._test_onecyclelr(1000, {'pct_start': 0.3, 'anneal_strategy': 'linear', - 'cycle_momentum': False, 'div_factor': 25.0, - 'final_div_factor': 10000.0, 'max_lr':1e-5}, 1e-6, 1e-6) - @unittest.skip("slow") - def test_training(self): - without = lr_scheduler_training() - sched_fns = [MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR] - argss = [{'milestones': [5, 7, 10, 15], 'gamma': 0.5}, {'factor': 0.5, 'patience': 2}, {'T_max': 25, 'eta_min': 0.001}, - {'pct_start': 0.3, 'anneal_strategy': 'linear', 'cycle_momentum': False, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'max_lr':1e-5, 'total_steps': 25}] - for sched_fn, args in zip(sched_fns, argss): - with_sched = lr_scheduler_training(sched_fn, args) - assert with_sched > without + def test_multisteplr(self): + self._test_multisteplr(10, {"milestones": [1, 2, 7]}, 1e-6, 1e-6) -if __name__ == '__main__': - unittest.main() + def test_multisteplr_gamma(self): + self._test_multisteplr( + 10, {"milestones": [1, 2, 7], "gamma": 0.1337}, 1e-6, 1e-6 + ) + + def test_reducelronplateau(self): + self._test_reducelronplateau(100, {}, 1e-6, 1e-6) + + def test_reducelronplateau_max(self): + self._test_reducelronplateau(100, {"mode": "max"}, 1e-6, 1e-6) + + def test_reducelronplateau_factor(self): + self._test_reducelronplateau(100, {"factor": 0.1337}, 1e-6, 1e-6) + + def test_reducelronplateau_patience(self): + self._test_reducelronplateau(100, {"patience": 3}, 1e-6, 1e-6) + + def test_reducelronplateau_threshold(self): + self._test_reducelronplateau(100, {"threshold": 1e-6}, 1e-6, 1e-6) + + def test_reducelronplateau_threshold_mode(self): + self._test_reducelronplateau(100, {"threshold_mode": "abs"}, 1e-6, 1e-6) + + def test_cosineannealinglr(self): + self._test_cosineannealinglr(100, {}, 1e-6, 1e-6) + + def test_cosineannealinglr_eta_min(self): + self._test_cosineannealinglr(100, {"eta_min": 0.001}, 1e-6, 1e-6) + + def test_onecyclelr(self): + self._test_onecyclelr( + 1000, + { + "pct_start": 0.3, + "anneal_strategy": "linear", + "cycle_momentum": False, + "div_factor": 25.0, + "final_div_factor": 10000.0, + "max_lr": 1e-5, + }, + 1e-6, + 1e-6, + ) + + @unittest.skip("slow") + def test_training(self): + without = lr_scheduler_training() + sched_fns = [MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR] + argss = [ + {"milestones": [5, 7, 10, 15], "gamma": 0.5}, + {"factor": 0.5, "patience": 2}, + {"T_max": 25, "eta_min": 0.001}, + { + "pct_start": 0.3, + "anneal_strategy": "linear", + "cycle_momentum": False, + "div_factor": 25.0, + "final_div_factor": 10000.0, + "max_lr": 1e-5, + "total_steps": 25, + }, + ] + for sched_fn, args in zip(sched_fns, argss): + with_sched = lr_scheduler_training(sched_fn, args) + assert with_sched > without + + +if __name__ == "__main__": + unittest.main() diff --git a/test/helpers.py b/test/helpers.py index 6f31dd194..555bff39c 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -2,25 +2,32 @@ from tinygrad.device import JITRunner from tinygrad.ops import LazyOp, LoadOps from tinygrad.nn.state import get_parameters + # for speed def derandomize(x): - if isinstance(x, LazyOp): - new_op = LoadOps.EMPTY if x.op == LoadOps.CUSTOM else x.op - return LazyOp(new_op, tuple([derandomize(s) for s in x.src]), None if x.op == LoadOps.CUSTOM else x.arg) - x.op = derandomize(x.op) - return x + if isinstance(x, LazyOp): + new_op = LoadOps.EMPTY if x.op == LoadOps.CUSTOM else x.op + return LazyOp( + new_op, + tuple([derandomize(s) for s in x.src]), + None if x.op == LoadOps.CUSTOM else x.arg, + ) + x.op = derandomize(x.op) + return x + def derandomize_model(model): - for p in get_parameters(model): - p.lazydata = derandomize(p.lazydata) - p.realize() + for p in get_parameters(model): + p.lazydata = derandomize(p.lazydata) + p.realize() + def assert_jit_cache_len(fxn, expected_len): - assert len(fxn.jit_cache) > 0 - if issubclass(type(fxn.jit_cache[0].prg), JITRunner): - assert len(fxn.jit_cache) == expected_len - else: - assert len(fxn.jit_cache) == 1 - # until we have a better way of typing the prg in JitItem - assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph') - assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len \ No newline at end of file + assert len(fxn.jit_cache) > 0 + if issubclass(type(fxn.jit_cache[0].prg), JITRunner): + assert len(fxn.jit_cache) == expected_len + else: + assert len(fxn.jit_cache) == 1 + # until we have a better way of typing the prg in JitItem + assert type(fxn.jit_cache[0].prg).__name__.endswith("Graph") + assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len diff --git a/test/imported/test_indexing.py b/test/imported/test_indexing.py index 3bed0c47c..df9ced33f 100644 --- a/test/imported/test_indexing.py +++ b/test/imported/test_indexing.py @@ -7,1409 +7,1467 @@ from tinygrad.tensor import Tensor random.seed(42) + def numpy_testing_assert_equal_helper(a, b): - if isinstance(a, Tensor): a = a.numpy() - if isinstance(b, Tensor): b = b.numpy() - np.testing.assert_equal(a, b) + if isinstance(a, Tensor): + a = a.numpy() + if isinstance(b, Tensor): + b = b.numpy() + np.testing.assert_equal(a, b) + def consec(shape, start=1): - return Tensor(np.arange(math.prod(shape)).reshape(shape)+start) + return Tensor(np.arange(math.prod(shape)).reshape(shape) + start) + class TestIndexing(unittest.TestCase): - def test_index(self): - - reference = consec((3, 3, 3)) - - numpy_testing_assert_equal_helper(reference[0], consec((3, 3))) - numpy_testing_assert_equal_helper(reference[1], consec((3, 3), 10)) - numpy_testing_assert_equal_helper(reference[2], consec((3, 3), 19)) - numpy_testing_assert_equal_helper(reference[0, 1], consec((3,), 4)) - numpy_testing_assert_equal_helper(reference[0:2], consec((2, 3, 3))) - numpy_testing_assert_equal_helper(reference[2, 2, 2], 27) - numpy_testing_assert_equal_helper(reference[:], consec((3, 3, 3))) - - # indexing with Ellipsis - numpy_testing_assert_equal_helper(reference[..., 2], np.array([[3., 6., 9.],[12., 15., 18.],[21., 24., 27.]])) - numpy_testing_assert_equal_helper(reference[0, ..., 2], np.array([3., 6., 9.])) - numpy_testing_assert_equal_helper(reference[..., 2], reference[:, :, 2]) - numpy_testing_assert_equal_helper(reference[0, ..., 2], reference[0, :, 2]) - numpy_testing_assert_equal_helper(reference[0, 2, ...], reference[0, 2]) - numpy_testing_assert_equal_helper(reference[..., 2, 2, 2], 27) - numpy_testing_assert_equal_helper(reference[2, ..., 2, 2], 27) - numpy_testing_assert_equal_helper(reference[2, 2, ..., 2], 27) - numpy_testing_assert_equal_helper(reference[2, 2, 2, ...], 27) - numpy_testing_assert_equal_helper(reference[...], reference) - - reference_5d = consec((3, 3, 3, 3, 3)) - numpy_testing_assert_equal_helper(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0]) - numpy_testing_assert_equal_helper(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0]) - numpy_testing_assert_equal_helper(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1]) - numpy_testing_assert_equal_helper(reference_5d[...], reference_5d) - - # None indexing - numpy_testing_assert_equal_helper(reference[2, None], reference[2].unsqueeze(0)) - numpy_testing_assert_equal_helper(reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0)) - numpy_testing_assert_equal_helper(reference[2:4, None], reference[2:4].unsqueeze(1)) - numpy_testing_assert_equal_helper(reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0)) - numpy_testing_assert_equal_helper(reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2)) - - # indexing 0-length slice - numpy_testing_assert_equal_helper(np.empty((0, 3, 3)), reference[slice(0)]) - numpy_testing_assert_equal_helper(np.empty((0, 3)), reference[slice(0), 2]) - numpy_testing_assert_equal_helper(np.empty((0, 3)), reference[2, slice(0)]) - numpy_testing_assert_equal_helper(np.empty([]), reference[2, 1:1, 2]) - - # indexing with step - reference = consec((10, 10, 10)) - numpy_testing_assert_equal_helper(reference[1:5:2], Tensor.stack([reference[1], reference[3]], 0)) - numpy_testing_assert_equal_helper(reference[1:6:2], Tensor.stack([reference[1], reference[3], reference[5]], 0)) - numpy_testing_assert_equal_helper(reference[1:9:4], Tensor.stack([reference[1], reference[5]], 0)) - numpy_testing_assert_equal_helper(reference[2:4, 1:5:2], Tensor.stack([reference[2:4, 1], reference[2:4, 3]], 1)) - numpy_testing_assert_equal_helper(reference[3, 1:6:2], Tensor.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0)) - numpy_testing_assert_equal_helper(reference[None, 2, 1:9:4], Tensor.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0)) - numpy_testing_assert_equal_helper(reference[:, 2, 1:6:2], Tensor.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1)) - - lst = [list(range(i, i+10)) for i in range(0, 100, 10)] - tensor = Tensor(lst) - for _ in range(100): - idx1_start = random.randrange(10) - idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1) - idx1_step = random.randrange(1, 8) - idx1 = slice(idx1_start, idx1_end, idx1_step) - if random.randrange(2) == 0: - idx2_start = random.randrange(10) - idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1) - idx2_step = random.randrange(1, 8) - idx2 = slice(idx2_start, idx2_end, idx2_step) - lst_indexed = [l[idx2] for l in lst[idx1]] - tensor_indexed = tensor[idx1, idx2] - else: - lst_indexed = lst[idx1] - tensor_indexed = tensor[idx1] - numpy_testing_assert_equal_helper(tensor_indexed, np.array(lst_indexed)) - - # self.assertRaises(ValueError, lambda: reference[1:9:0]) - # self.assertRaises(ValueError, lambda: reference[1:9:-1]) - - # self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1]) - # self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1]) - # self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3]) - - # self.assertRaises(IndexError, lambda: reference[0.0]) - # self.assertRaises(TypeError, lambda: reference[0.0:2.0]) - # self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0]) - # self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0]) - # self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0]) - # self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0]) - - # def delitem(): del reference[0] - # self.assertRaises(TypeError, delitem) - - def test_advancedindex(self): - # integer array indexing - - # pick a random valid indexer type - def ri(indices): - choice = random.randint(0, 1) - # TODO: we do not support tuple of list for index now - if choice == 0: return Tensor(indices) - if choice == 1: return list(indices) - return tuple(indices) - - def validate_indexing(x): - numpy_testing_assert_equal_helper(x[[0]], consec((1,))) - numpy_testing_assert_equal_helper(x[ri([0]),], consec((1,))) - numpy_testing_assert_equal_helper(x[ri([3]),], consec((1,), 4)) - numpy_testing_assert_equal_helper(x[[2, 3, 4]], consec((3,), 3)) - numpy_testing_assert_equal_helper(x[ri([2, 3, 4]),], consec((3,), 3)) - numpy_testing_assert_equal_helper(x[ri([0, 2, 4]),], np.array([1, 3, 5])) - - def validate_setting(x): - pass - # # TODO: we don't support setitem now - # x[[0]] = -2 - # numpy_testing_assert_equal_helper(x[[0]], np.array([-2])) - # x[[0]] = -1 - # numpy_testing_assert_equal_helper(x[ri([0]), ], np.array([-1])) - # x[[2, 3, 4]] = 4 - # numpy_testing_assert_equal_helper(x[[2, 3, 4]], np.array([4, 4, 4])) - # x[ri([2, 3, 4]), ] = 3 - # numpy_testing_assert_equal_helper(x[ri([2, 3, 4]), ], np.array([3, 3, 3])) - # x[ri([0, 2, 4]), ] = np.array([5, 4, 3]) - # numpy_testing_assert_equal_helper(x[ri([0, 2, 4]), ], np.array([5, 4, 3])) - - # Case 1: Purely Integer Array Indexing - reference = consec((10,)) - validate_indexing(reference) - - # setting values - validate_setting(reference) - - # # Tensor with stride != 1 - # # strided is [1, 3, 5, 7] - # reference = consec((10,)) - # strided = np.array(()) - # strided.set_(reference.storage(), storage_offset=0, - # size=torch.Size([4]), stride=[2]) - - # numpy_testing_assert_equal_helper(strided[[0]], np.array([1])) - # numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([1])) - # numpy_testing_assert_equal_helper(strided[ri([3]), ], np.array([7])) - # numpy_testing_assert_equal_helper(strided[[1, 2]], np.array([3, 5])) - # numpy_testing_assert_equal_helper(strided[ri([1, 2]), ], np.array([3, 5])) - # numpy_testing_assert_equal_helper(strided[ri([[2, 1], [0, 3]]), ], - # np.array([[5, 3], [1, 7]])) - - # # stride is [4, 8] - # strided = np.array(()) - # strided.set_(reference.storage(), storage_offset=4, - # size=torch.Size([2]), stride=[4]) - # numpy_testing_assert_equal_helper(strided[[0]], np.array([5])) - # numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([5])) - # numpy_testing_assert_equal_helper(strided[ri([1]), ], np.array([9])) - # numpy_testing_assert_equal_helper(strided[[0, 1]], np.array([5, 9])) - # numpy_testing_assert_equal_helper(strided[ri([0, 1]), ], np.array([5, 9])) - # numpy_testing_assert_equal_helper(strided[ri([[0, 1], [1, 0]]), ], - # np.array([[5, 9], [9, 5]])) - - # reference is 1 2 - # 3 4 - # 5 6 - reference = consec((3, 2)) - numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], np.array([1, 3, 5])) - numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([2, 4, 6])) - numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], consec((1,))) - numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], consec((1,), 6)) - # # TODO: we don't support list of Tensors as index - # numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([1, 2])) - # numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], np.array([2, 4, 4, 2, 6])) - # numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 2, 3, 3])) - - # rows = ri([[0, 0], - # [1, 2]]) - # columns = [0], - # numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 1], - # [3, 5]])) - - rows = ri([[0, 0], - [1, 2]]) - columns = ri([1, 0]) - numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[2, 1], - [4, 5]])) - rows = ri([[0, 0], - [1, 2]]) - columns = ri([[0, 1], - [1, 0]]) - numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 2], - [4, 5]])) - - # # setting values - # reference[ri([0]), ri([1])] = -1 - # numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], np.array([-1])) - # reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4]) - # numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], - # np.array([-1, 2, -4])) - # reference[rows, columns] = np.array([[4, 6], [2, 3]]) - # numpy_testing_assert_equal_helper(reference[rows, columns], - # np.array([[4, 6], [2, 3]])) - - # Verify still works with Transposed (i.e. non-contiguous) Tensors - - reference = Tensor([[0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11]]).T - - # Transposed: [[0, 4, 8], - # [1, 5, 9], - # [2, 6, 10], - # [3, 7, 11]] - - numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], np.array([0, 1, 2])) - numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([4, 5, 6])) - numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], np.array([0])) - numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], np.array([6])) - # # TODO: we don't support list of Tensors as index - # numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([0, 4])) - # numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], np.array([4, 5, 5, 4, 7])) - # numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([0, 4, 1, 1])) - - # rows = ri([[0, 0], - # [1, 2]]) - # columns = [0], - # numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 0], [1, 2]])) - - rows = ri([[0, 0], - [1, 2]]) - columns = ri([1, 0]) - numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[4, 0], [5, 2]])) - rows = ri([[0, 0], - [1, 3]]) - columns = ri([[0, 1], - [1, 2]]) - numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 4], [5, 11]])) - - # # setting values - # reference[ri([0]), ri([1])] = -1 - # numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], - # np.array([-1])) - # reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4]) - # numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], - # np.array([-1, 2, -4])) - # reference[rows, columns] = np.array([[4, 6], [2, 3]]) - # numpy_testing_assert_equal_helper(reference[rows, columns], - # np.array([[4, 6], [2, 3]])) - - # # stride != 1 - - # # strided is [[1 3 5 7], - # # [9 11 13 15]] - - # reference = torch.arange(0., 24).view(3, 8) - # strided = np.array(()) - # strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), - # stride=[8, 2]) - - # numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([0])], - # np.array([1, 9])) - # numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1])], - # np.array([3, 11])) - # numpy_testing_assert_equal_helper(strided[ri([0]), ri([0])], - # np.array([1])) - # numpy_testing_assert_equal_helper(strided[ri([1]), ri([3])], - # np.array([15])) - # numpy_testing_assert_equal_helper(strided[[ri([0, 0]), ri([0, 3])]], - # np.array([1, 7])) - # numpy_testing_assert_equal_helper(strided[[ri([1]), ri([0, 1, 1, 0, 3])]], - # np.array([9, 11, 11, 9, 15])) - # numpy_testing_assert_equal_helper(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], - # np.array([1, 3, 9, 9])) - - # rows = ri([[0, 0], - # [1, 1]]) - # columns = [0], - # numpy_testing_assert_equal_helper(strided[rows, columns], - # np.array([[1, 1], [9, 9]])) - - # rows = ri([[0, 1], - # [1, 0]]) - # columns = ri([1, 2]) - # numpy_testing_assert_equal_helper(strided[rows, columns], - # np.array([[3, 13], [11, 5]])) - # rows = ri([[0, 0], - # [1, 1]]) - # columns = ri([[0, 1], - # [1, 2]]) - # numpy_testing_assert_equal_helper(strided[rows, columns], - # np.array([[1, 3], [11, 13]])) - - # # setting values - - # # strided is [[10, 11], - # # [17, 18]] - - # reference = torch.arange(0., 24).view(3, 8) - # strided = np.array(()) - # strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), - # stride=[7, 1]) - # numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])], - # np.array([11])) - # strided[ri([0]), ri([1])] = -1 - # numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])], - # np.array([-1])) - - # reference = torch.arange(0., 24).view(3, 8) - # strided = np.array(()) - # strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), - # stride=[7, 1]) - # numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])], - # np.array([11, 17])) - # strided[ri([0, 1]), ri([1, 0])] = np.array([-1, 2]) - # numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])], - # np.array([-1, 2])) - - # reference = torch.arange(0., 24).view(3, 8) - # strided = np.array(()) - # strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), - # stride=[7, 1]) - - # rows = ri([[0], - # [1]]) - # columns = ri([[0, 1], - # [0, 1]]) - # numpy_testing_assert_equal_helper(strided[rows, columns], - # np.array([[10, 11], [17, 18]])) - # strided[rows, columns] = np.array([[4, 6], [2, 3]]) - # numpy_testing_assert_equal_helper(strided[rows, columns], - # np.array([[4, 6], [2, 3]])) - - # Tests using less than the number of dims, and ellipsis - - # reference is 1 2 - # 3 4 - # 5 6 - reference = consec((3, 2)) - numpy_testing_assert_equal_helper(reference[ri([0, 2]),], np.array([[1, 2], [5, 6]])) - numpy_testing_assert_equal_helper(reference[ri([1]), ...], np.array([[3, 4]])) - numpy_testing_assert_equal_helper(reference[..., ri([1])], np.array([[2], [4], [6]])) - - # verify too many indices fails - with self.assertRaises(IndexError): reference[ri([1]), ri([0, 2]), ri([3])] - - # # test invalid index fails - # reference = torch.empty(10) - # # can't test cuda because it is a device assert - # if not reference.is_cuda: - # for err_idx in (10, -11): - # with self.assertRaisesRegex(IndexError, r'out of'): - # reference[err_idx] - # with self.assertRaisesRegex(IndexError, r'out of'): - # reference[torch.LongTensor([err_idx]).to(device)] - # with self.assertRaisesRegex(IndexError, r'out of'): - # reference[[err_idx]] - - # def tensor_indices_to_np(tensor, indices): - # # convert the Torch Tensor to a numpy array - # tensor = tensor.to(device='cpu') - # npt = tensor.numpy() - - # # convert indices - # idxs = tuple(i.tolist() if isinstance(i, torch.LongTensor) else - # i for i in indices) - - # return npt, idxs - - # def get_numpy(tensor, indices): - # npt, idxs = tensor_indices_to_np(tensor, indices) - - # # index and return as a Torch Tensor - # return np.array(npt[idxs]) - - # def set_numpy(tensor, indices, value): - # if not isinstance(value, int): - # if self.device_type != 'cpu': - # value = value.cpu() - # value = value.numpy() - - # npt, idxs = tensor_indices_to_np(tensor, indices) - # npt[idxs] = value - # return npt - - # def assert_get_eq(tensor, indexer): - # numpy_testing_assert_equal_helper(tensor[indexer], get_numpy(tensor, indexer)) - - # def assert_set_eq(tensor, indexer, val): - # pyt = tensor.clone() - # numt = tensor.clone() - # pyt[indexer] = val - # numt = np.array(set_numpy(numt, indexer, val)) - # numpy_testing_assert_equal_helper(pyt, numt) - - # def assert_backward_eq(tensor, indexer): - # cpu = tensor.float().clone().detach().requires_grad_(True) - # outcpu = cpu[indexer] - # gOcpu = torch.rand_like(outcpu) - # outcpu.backward(gOcpu) - # dev = cpu.to(device).detach().requires_grad_(True) - # outdev = dev[indexer] - # outdev.backward(gOcpu.to(device)) - # numpy_testing_assert_equal_helper(cpu.grad, dev.grad) - - # def get_set_tensor(indexed, indexer): - # set_size = indexed[indexer].size() - # set_count = indexed[indexer].numel() - # set_tensor = torch.randperm(set_count).view(set_size).double().to(device) - # return set_tensor - - # # Tensor is 0 1 2 3 4 - # # 5 6 7 8 9 - # # 10 11 12 13 14 - # # 15 16 17 18 19 - # reference = torch.arange(0., 20).view(4, 5) - - # indices_to_test = [ - # # grab the second, fourth columns - # [slice(None), [1, 3]], - - # # first, third rows, - # [[0, 2], slice(None)], - - # # weird shape - # [slice(None), [[0, 1], - # [2, 3]]], - # # negatives - # [[-1], [0]], - # [[0, 2], [-1]], - # [slice(None), [-1]], - # ] - - # # only test dupes on gets - # get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] - - # for indexer in get_indices_to_test: - # assert_get_eq(reference, indexer) - # if self.device_type != 'cpu': - # assert_backward_eq(reference, indexer) - - # for indexer in indices_to_test: - # assert_set_eq(reference, indexer, 44) - # assert_set_eq(reference, - # indexer, - # get_set_tensor(reference, indexer)) - - # reference = torch.arange(0., 160).view(4, 8, 5) - - # indices_to_test = [ - # [slice(None), slice(None), [0, 3, 4]], - # [slice(None), [2, 4, 5, 7], slice(None)], - # [[2, 3], slice(None), slice(None)], - # [slice(None), [0, 2, 3], [1, 3, 4]], - # [slice(None), [0], [1, 2, 4]], - # [slice(None), [0, 1, 3], [4]], - # [slice(None), [[0, 1], [1, 0]], [[2, 3]]], - # [slice(None), [[0, 1], [2, 3]], [[0]]], - # [slice(None), [[5, 6]], [[0, 3], [4, 4]]], - # [[0, 2, 3], [1, 3, 4], slice(None)], - # [[0], [1, 2, 4], slice(None)], - # [[0, 1, 3], [4], slice(None)], - # [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], - # [[[0, 1], [1, 0]], [[2, 3]], slice(None)], - # [[[0, 1], [2, 3]], [[0]], slice(None)], - # [[[2, 1]], [[0, 3], [4, 4]], slice(None)], - # [[[2]], [[0, 3], [4, 1]], slice(None)], - # # non-contiguous indexing subspace - # [[0, 2, 3], slice(None), [1, 3, 4]], - - # # less dim, ellipsis - # [[0, 2], ], - # [[0, 2], slice(None)], - # [[0, 2], Ellipsis], - # [[0, 2], slice(None), Ellipsis], - # [[0, 2], Ellipsis, slice(None)], - # [[0, 2], [1, 3]], - # [[0, 2], [1, 3], Ellipsis], - # [Ellipsis, [1, 3], [2, 3]], - # [Ellipsis, [2, 3, 4]], - # [Ellipsis, slice(None), [2, 3, 4]], - # [slice(None), Ellipsis, [2, 3, 4]], - - # # ellipsis counts for nothing - # [Ellipsis, slice(None), slice(None), [0, 3, 4]], - # [slice(None), Ellipsis, slice(None), [0, 3, 4]], - # [slice(None), slice(None), Ellipsis, [0, 3, 4]], - # [slice(None), slice(None), [0, 3, 4], Ellipsis], - # [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], - # [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], - # [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], - # ] - - # for indexer in indices_to_test: - # assert_get_eq(reference, indexer) - # assert_set_eq(reference, indexer, 212) - # assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) - # if torch.cuda.is_available(): - # assert_backward_eq(reference, indexer) - - # reference = torch.arange(0., 1296).view(3, 9, 8, 6) - - # indices_to_test = [ - # [slice(None), slice(None), slice(None), [0, 3, 4]], - # [slice(None), slice(None), [2, 4, 5, 7], slice(None)], - # [slice(None), [2, 3], slice(None), slice(None)], - # [[1, 2], slice(None), slice(None), slice(None)], - # [slice(None), slice(None), [0, 2, 3], [1, 3, 4]], - # [slice(None), slice(None), [0], [1, 2, 4]], - # [slice(None), slice(None), [0, 1, 3], [4]], - # [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], - # [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], - # [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], - # [slice(None), [0, 2, 3], [1, 3, 4], slice(None)], - # [slice(None), [0], [1, 2, 4], slice(None)], - # [slice(None), [0, 1, 3], [4], slice(None)], - # [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], - # [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], - # [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], - # [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], - # [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], - # [[0, 1, 2], [1, 3, 4], slice(None), slice(None)], - # [[0], [1, 2, 4], slice(None), slice(None)], - # [[0, 1, 2], [4], slice(None), slice(None)], - # [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], - # [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], - # [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], - # [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)], - # [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], - # [slice(None), [2, 3, 4], [1, 3, 4], [4]], - # [slice(None), [0, 1, 3], [4], [1, 3, 4]], - # [slice(None), [6], [0, 2, 3], [1, 3, 4]], - # [slice(None), [2, 3, 5], [3], [4]], - # [slice(None), [0], [4], [1, 3, 4]], - # [slice(None), [6], [0, 2, 3], [1]], - # [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], - # [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], - # [[2, 0, 1], [1, 2, 3], [4], slice(None)], - # [[0, 1, 2], [4], [1, 3, 4], slice(None)], - # [[0], [0, 2, 3], [1, 3, 4], slice(None)], - # [[0, 2, 1], [3], [4], slice(None)], - # [[0], [4], [1, 3, 4], slice(None)], - # [[1], [0, 2, 3], [1], slice(None)], - # [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], - - # # less dim, ellipsis - # [Ellipsis, [0, 3, 4]], - # [Ellipsis, slice(None), [0, 3, 4]], - # [Ellipsis, slice(None), slice(None), [0, 3, 4]], - # [slice(None), Ellipsis, [0, 3, 4]], - # [slice(None), slice(None), Ellipsis, [0, 3, 4]], - # [slice(None), [0, 2, 3], [1, 3, 4]], - # [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], - # [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], - # [[0], [1, 2, 4]], - # [[0], [1, 2, 4], slice(None)], - # [[0], [1, 2, 4], Ellipsis], - # [[0], [1, 2, 4], Ellipsis, slice(None)], - # [[1], ], - # [[0, 2, 1], [3], [4]], - # [[0, 2, 1], [3], [4], slice(None)], - # [[0, 2, 1], [3], [4], Ellipsis], - # [Ellipsis, [0, 2, 1], [3], [4]], - # ] - - # for indexer in indices_to_test: - # assert_get_eq(reference, indexer) - # assert_set_eq(reference, indexer, 1333) - # assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) - # indices_to_test += [ - # [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], - # [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], - # ] - # for indexer in indices_to_test: - # assert_get_eq(reference, indexer) - # assert_set_eq(reference, indexer, 1333) - # if self.device_type != 'cpu': - # assert_backward_eq(reference, indexer) - - # def test_advancedindex_big(self): - # reference = Tensor.arange(123344) - # numpy_testing_assert_equal_helper(reference[[0, 123, 44488, 68807, 123343],], np.array([0, 123, 44488, 68807, 123343])) - - # def test_set_item_to_scalar_tensor(self): - # m = random.randint(1, 10) - # n = random.randint(1, 10) - # z = torch.randn([m, n]) - # a = 1.0 - # w = np.array(a, requires_grad=True) - # z[:, 0] = w - # z.sum().backward() - # numpy_testing_assert_equal_helper(w.grad, m * a) - - def test_single_int(self): - v = Tensor.randn(5, 7, 3) - numpy_testing_assert_equal_helper(v[4].shape, (7, 3)) - - def test_multiple_int(self): - v = Tensor.randn(5, 7, 3) - numpy_testing_assert_equal_helper(v[4].shape, (7, 3)) - numpy_testing_assert_equal_helper(v[4, :, 1].shape, (7,)) - - def test_none(self): - v = Tensor.randn(5, 7, 3) - numpy_testing_assert_equal_helper(v[None].shape, (1, 5, 7, 3)) - numpy_testing_assert_equal_helper(v[:, None].shape, (5, 1, 7, 3)) - numpy_testing_assert_equal_helper(v[:, None, None].shape, (5, 1, 1, 7, 3)) - numpy_testing_assert_equal_helper(v[..., None].shape, (5, 7, 3, 1)) - - def test_step(self): - v = Tensor.arange(10) - numpy_testing_assert_equal_helper(v[::1], v) - numpy_testing_assert_equal_helper(v[::2], [0, 2, 4, 6, 8]) - numpy_testing_assert_equal_helper(v[::3], [0, 3, 6, 9]) - numpy_testing_assert_equal_helper(v[::11], [0]) - numpy_testing_assert_equal_helper(v[1:6:2], [1, 3, 5]) - - # def test_step_assignment(self): - # v = torch.zeros(4, 4) - # v[0, 1::2] = np.array([3., 4.]) - # numpy_testing_assert_equal_helper(v[0].tolist(), [0, 3, 0, 4]) - # numpy_testing_assert_equal_helper(v[1:].sum(), 0) - - # def test_bool_indices(self): - # v = Tensor.randn(5, 7, 3) - # boolIndices = np.array([True, False, True, True, False], dtype=bool) - # numpy_testing_assert_equal_helper(v[boolIndices].shape, (3, 7, 3)) - # numpy_testing_assert_equal_helper(v[boolIndices], Tensor.stack([v[0], v[2], v[3]])) - - # v = np.array([True, False, True], dtype=torch.bool) - # boolIndices = np.array([True, False, False], dtype=torch.bool) - # uint8Indices = np.array([1, 0, 0], dtype=torch.uint8) - # with warnings.catch_warnings(record=True) as w: - # numpy_testing_assert_equal_helper(v[boolIndices].shape, v[uint8Indices].shape) - # numpy_testing_assert_equal_helper(v[boolIndices], v[uint8Indices]) - # numpy_testing_assert_equal_helper(v[boolIndices], tensor([True], dtype=torch.bool)) - # numpy_testing_assert_equal_helper(len(w), 2) - - # def test_bool_indices_accumulate(self): - # mask = torch.zeros(size=(10, ), dtype=torch.bool) - # y = torch.ones(size=(10, 10)) - # y.index_put_((mask, ), y[mask], accumulate=True) - # numpy_testing_assert_equal_helper(y, torch.ones(size=(10, 10))) - - # def test_multiple_bool_indices(self): - # v = torch.randn(5, 7, 3) - # # note: these broadcast together and are transposed to the first dim - # mask1 = np.array([1, 0, 1, 1, 0], dtype=torch.bool) - # mask2 = np.array([1, 1, 1], dtype=torch.bool) - # numpy_testing_assert_equal_helper(v[mask1, :, mask2].shape, (3, 7)) - - # def test_byte_mask(self): - # v = torch.randn(5, 7, 3) - # mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device) - # with warnings.catch_warnings(record=True) as w: - # numpy_testing_assert_equal_helper(v[mask].shape, (3, 7, 3)) - # numpy_testing_assert_equal_helper(v[mask], torch.stack([v[0], v[2], v[3]])) - # numpy_testing_assert_equal_helper(len(w), 2) - - # v = np.array([1.]) - # numpy_testing_assert_equal_helper(v[v == 0], np.array([])) - - # def test_byte_mask_accumulate(self): - # mask = torch.zeros(size=(10, ), dtype=torch.uint8) - # y = torch.ones(size=(10, 10)) - # with warnings.catch_warnings(record=True) as w: - # warnings.simplefilter("always") - # y.index_put_((mask, ), y[mask], accumulate=True) - # numpy_testing_assert_equal_helper(y, torch.ones(size=(10, 10))) - # numpy_testing_assert_equal_helper(len(w), 2) - - # def test_index_put_accumulate_large_tensor(self): - # # This test is for tensors with number of elements >= INT_MAX (2^31 - 1). - # N = (1 << 31) + 5 - # dt = torch.int8 - # a = torch.ones(N, dtype=dt) - # indices = np.array([-2, 0, -2, -1, 0, -1, 1], dtype=torch.long) - # values = np.array([6, 5, 6, 6, 5, 7, 11], dtype=dt) - - # a.index_put_((indices, ), values, accumulate=True) - - # numpy_testing_assert_equal_helper(a[0], 11) - # numpy_testing_assert_equal_helper(a[1], 12) - # numpy_testing_assert_equal_helper(a[2], 1) - # numpy_testing_assert_equal_helper(a[-3], 1) - # numpy_testing_assert_equal_helper(a[-2], 13) - # numpy_testing_assert_equal_helper(a[-1], 14) - - # a = torch.ones((2, N), dtype=dt) - # indices0 = np.array([0, -1, 0, 1], dtype=torch.long) - # indices1 = np.array([-2, -1, 0, 1], dtype=torch.long) - # values = np.array([12, 13, 10, 11], dtype=dt) - - # a.index_put_((indices0, indices1), values, accumulate=True) - - # numpy_testing_assert_equal_helper(a[0, 0], 11) - # numpy_testing_assert_equal_helper(a[0, 1], 1) - # numpy_testing_assert_equal_helper(a[1, 0], 1) - # numpy_testing_assert_equal_helper(a[1, 1], 12) - # numpy_testing_assert_equal_helper(a[:, 2], torch.ones(2, dtype=torch.int8)) - # numpy_testing_assert_equal_helper(a[:, -3], torch.ones(2, dtype=torch.int8)) - # numpy_testing_assert_equal_helper(a[0, -2], 13) - # numpy_testing_assert_equal_helper(a[1, -2], 1) - # numpy_testing_assert_equal_helper(a[-1, -1], 14) - # numpy_testing_assert_equal_helper(a[0, -1], 1) - - # def test_index_put_accumulate_expanded_values(self): - # # checks the issue with cuda: https://github.com/pytorch/pytorch/issues/39227 - # # and verifies consistency with CPU result - # t = torch.zeros((5, 2)) - # t_dev = t.to(device) - # indices = [ - # np.array([0, 1, 2, 3]), - # np.array([1, ]), - # ] - # indices_dev = [i.to(device) for i in indices] - # values0d = np.array(1.0) - # values1d = np.array([1.0, ]) - - # out_cuda = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True) - # out_cpu = t.index_put_(indices, values0d, accumulate=True) - # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) - - # out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True) - # out_cpu = t.index_put_(indices, values1d, accumulate=True) - # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) - - # t = torch.zeros(4, 3, 2) - # t_dev = t.to(device) - - # indices = [ - # np.array([0, ]), - # torch.arange(3)[:, None], - # torch.arange(2)[None, :], - # ] - # indices_dev = [i.to(device) for i in indices] - # values1d = np.array([-1.0, -2.0]) - # values2d = np.array([[-1.0, -2.0], ]) - - # out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True) - # out_cpu = t.index_put_(indices, values1d, accumulate=True) - # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) - - # out_cuda = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True) - # out_cpu = t.index_put_(indices, values2d, accumulate=True) - # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) - - # def test_index_put_accumulate_non_contiguous(self): - # t = torch.zeros((5, 2, 2)) - # t_dev = t.to(device) - # t1 = t_dev[:, 0, :] - # t2 = t[:, 0, :] - # self.assertTrue(not t1.is_contiguous()) - # self.assertTrue(not t2.is_contiguous()) - - # indices = [np.array([0, 1]), ] - # indices_dev = [i.to(device) for i in indices] - # value = torch.randn(2, 2) - # out_cuda = t1.index_put_(indices_dev, value.to(device), accumulate=True) - # out_cpu = t2.index_put_(indices, value, accumulate=True) - # self.assertTrue(not t1.is_contiguous()) - # self.assertTrue(not t2.is_contiguous()) - - # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) - - # def test_index_put_accumulate_with_optional_tensors(self): - # # TODO: replace with a better solution. - # # Currently, here using torchscript to put None into indices. - # # on C++ it gives indices as a list of 2 optional tensors: first is null and - # # the second is a valid tensor. - # @torch.jit.script - # def func(x, i, v): - # idx = [None, i] - # x.index_put_(idx, v, accumulate=True) - # return x - - # n = 4 - # t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2) - # t_dev = t.to(device) - # indices = np.array([1, 0]) - # indices_dev = indices.to(device) - # value0d = np.array(10.0) - # value1d = np.array([1.0, 2.0]) - - # out_cuda = func(t_dev, indices_dev, value0d.cuda()) - # out_cpu = func(t, indices, value0d) - # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) - - # out_cuda = func(t_dev, indices_dev, value1d.cuda()) - # out_cpu = func(t, indices, value1d) - # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) - - # def test_index_put_accumulate_duplicate_indices(self): - # for i in range(1, 512): - # # generate indices by random walk, this will create indices with - # # lots of duplicates interleaved with each other - # delta = torch.empty(i, dtype=torch.double).uniform_(-1, 1) - # indices = delta.cumsum(0).long() - - # input = torch.randn(indices.abs().max() + 1) - # values = torch.randn(indices.size(0)) - # output = input.index_put((indices,), values, accumulate=True) - - # input_list = input.tolist() - # indices_list = indices.tolist() - # values_list = values.tolist() - # for i, v in zip(indices_list, values_list): - # input_list[i] += v - - # numpy_testing_assert_equal_helper(output, input_list) - - # def test_index_ind_dtype(self): - # x = torch.randn(4, 4) - # ind_long = torch.randint(4, (4,), dtype=torch.long) - # ind_int = ind_long.int() - # src = torch.randn(4) - # ref = x[ind_long, ind_long] - # res = x[ind_int, ind_int] - # numpy_testing_assert_equal_helper(ref, res) - # ref = x[ind_long, :] - # res = x[ind_int, :] - # numpy_testing_assert_equal_helper(ref, res) - # ref = x[:, ind_long] - # res = x[:, ind_int] - # numpy_testing_assert_equal_helper(ref, res) - # # no repeating indices for index_put - # ind_long = torch.arange(4, dtype=torch.long) - # ind_int = ind_long.int() - # for accum in (True, False): - # inp_ref = x.clone() - # inp_res = x.clone() - # torch.index_put_(inp_ref, (ind_long, ind_long), src, accum) - # torch.index_put_(inp_res, (ind_int, ind_int), src, accum) - # numpy_testing_assert_equal_helper(inp_ref, inp_res) - - # def test_index_put_accumulate_empty(self): - # # Regression test for https://github.com/pytorch/pytorch/issues/94667 - # input = torch.rand([], dtype=torch.float32) - # with self.assertRaises(RuntimeError): - # input.index_put([], np.array([1.0]), True) - - # def test_multiple_byte_mask(self): - # v = torch.randn(5, 7, 3) - # # note: these broadcast together and are transposed to the first dim - # mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device) - # mask2 = torch.ByteTensor([1, 1, 1]).to(device) - # with warnings.catch_warnings(record=True) as w: - # warnings.simplefilter("always") - # numpy_testing_assert_equal_helper(v[mask1, :, mask2].shape, (3, 7)) - # numpy_testing_assert_equal_helper(len(w), 2) - - # def test_byte_mask2d(self): - # v = torch.randn(5, 7, 3) - # c = torch.randn(5, 7) - # num_ones = (c > 0).sum() - # r = v[c > 0] - # numpy_testing_assert_equal_helper(r.shape, (num_ones, 3)) - - # def test_jit_indexing(self): - # def fn1(x): - # x[x < 50] = 1.0 - # return x - - # def fn2(x): - # x[0:50] = 1.0 - # return x - - # scripted_fn1 = torch.jit.script(fn1) - # scripted_fn2 = torch.jit.script(fn2) - # data = torch.arange(100, dtype=torch.float) - # out = scripted_fn1(data.detach().clone()) - # ref = np.array(np.concatenate((np.ones(50), np.arange(50, 100))), dtype=torch.float) - # numpy_testing_assert_equal_helper(out, ref) - # out = scripted_fn2(data.detach().clone()) - # numpy_testing_assert_equal_helper(out, ref) - - # def test_int_indices(self): - # v = torch.randn(5, 7, 3) - # numpy_testing_assert_equal_helper(v[[0, 4, 2]].shape, (3, 7, 3)) - # numpy_testing_assert_equal_helper(v[:, [0, 4, 2]].shape, (5, 3, 3)) - # numpy_testing_assert_equal_helper(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3)) - - # def test_index_put_src_datatype(self, dtype): - # src = torch.ones(3, 2, 4, dtype=dtype) - # vals = torch.ones(3, 2, 4, dtype=dtype) - # indices = (np.array([0, 2, 1]),) - # res = src.index_put_(indices, vals, accumulate=True) - # numpy_testing_assert_equal_helper(res.shape, src.shape) - - # def test_index_src_datatype(self, dtype): - # src = torch.ones(3, 2, 4, dtype=dtype) - # # test index - # res = src[[0, 2, 1], :, :] - # numpy_testing_assert_equal_helper(res.shape, src.shape) - # # test index_put, no accum - # src[[0, 2, 1], :, :] = res - # numpy_testing_assert_equal_helper(res.shape, src.shape) - - # def test_int_indices2d(self): - # # From the NumPy indexing example - # x = torch.arange(0, 12).view(4, 3) - # rows = np.array([[0, 0], [3, 3]]) - # columns = np.array([[0, 2], [0, 2]]) - # numpy_testing_assert_equal_helper(x[rows, columns].tolist(), [[0, 2], [9, 11]]) - - # def test_int_indices_broadcast(self): - # # From the NumPy indexing example - # x = torch.arange(0, 12).view(4, 3) - # rows = np.array([0, 3]) - # columns = np.array([0, 2]) - # result = x[rows[:, None], columns] - # numpy_testing_assert_equal_helper(result.tolist(), [[0, 2], [9, 11]]) - - # def test_empty_index(self): - # x = torch.arange(0, 12).view(4, 3) - # idx = np.array([], dtype=torch.long) - # numpy_testing_assert_equal_helper(x[idx].numel(), 0) - - # # empty assignment should have no effect but not throw an exception - # y = x.clone() - # y[idx] = -1 - # numpy_testing_assert_equal_helper(x, y) - - # mask = torch.zeros(4, 3).bool() - # y[mask] = -1 - # numpy_testing_assert_equal_helper(x, y) - - # def test_empty_ndim_index(self): - # x = torch.randn(5) - # numpy_testing_assert_equal_helper(torch.empty(0, 2), x[torch.empty(0, 2, dtype=torch.int64)]) - - # x = torch.randn(2, 3, 4, 5) - # numpy_testing_assert_equal_helper(torch.empty(2, 0, 6, 4, 5), - # x[:, torch.empty(0, 6, dtype=torch.int64)]) - - # x = torch.empty(10, 0) - # numpy_testing_assert_equal_helper(x[[1, 2]].shape, (2, 0)) - # numpy_testing_assert_equal_helper(x[[], []].shape, (0,)) - # with self.assertRaisesRegex(IndexError, 'for dimension with size 0'): - # x[:, [0, 1]] - - # def test_empty_ndim_index_bool(self): - # x = torch.randn(5) - # self.assertRaises(IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8)]) - - # def test_empty_slice(self): - # x = torch.randn(2, 3, 4, 5) - # y = x[:, :, :, 1] - # z = y[:, 1:1, :] - # numpy_testing_assert_equal_helper((2, 0, 4), z.shape) - # # this isn't technically necessary, but matches NumPy stride calculations. - # numpy_testing_assert_equal_helper((60, 20, 5), z.stride()) - # self.assertTrue(z.is_contiguous()) - - # def test_index_getitem_copy_bools_slices(self): - # true = np.array(1, dtype=torch.uint8) - # false = np.array(0, dtype=torch.uint8) - - # tensors = [torch.randn(2, 3), np.array(3.)] - - # for a in tensors: - # self.assertNotEqual(a.data_ptr(), a[True].data_ptr()) - # numpy_testing_assert_equal_helper(torch.empty(0, *a.shape), a[False]) - # self.assertNotEqual(a.data_ptr(), a[true].data_ptr()) - # numpy_testing_assert_equal_helper(torch.empty(0, *a.shape), a[false]) - # numpy_testing_assert_equal_helper(a.data_ptr(), a[None].data_ptr()) - # numpy_testing_assert_equal_helper(a.data_ptr(), a[...].data_ptr()) - - # def test_index_setitem_bools_slices(self): - # true = np.array(1, dtype=torch.uint8) - # false = np.array(0, dtype=torch.uint8) - - # tensors = [torch.randn(2, 3), np.array(3)] - - # for a in tensors: - # # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s - # # (some of these ops already prefix a 1 to the size) - # neg_ones = torch.ones_like(a) * -1 - # neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0) - # a[True] = neg_ones_expanded - # numpy_testing_assert_equal_helper(a, neg_ones) - # a[False] = 5 - # numpy_testing_assert_equal_helper(a, neg_ones) - # a[true] = neg_ones_expanded * 2 - # numpy_testing_assert_equal_helper(a, neg_ones * 2) - # a[false] = 5 - # numpy_testing_assert_equal_helper(a, neg_ones * 2) - # a[None] = neg_ones_expanded * 3 - # numpy_testing_assert_equal_helper(a, neg_ones * 3) - # a[...] = neg_ones_expanded * 4 - # numpy_testing_assert_equal_helper(a, neg_ones * 4) - # if a.dim() == 0: - # with self.assertRaises(IndexError): - # a[:] = neg_ones_expanded * 5 - - # def test_index_scalar_with_bool_mask(self): - # a = np.array(1) - # uintMask = np.array(True, dtype=torch.uint8) - # boolMask = np.array(True, dtype=torch.bool) - # numpy_testing_assert_equal_helper(a[uintMask], a[boolMask]) - # numpy_testing_assert_equal_helper(a[uintMask].dtype, a[boolMask].dtype) - - # a = np.array(True, dtype=torch.bool) - # numpy_testing_assert_equal_helper(a[uintMask], a[boolMask]) - # numpy_testing_assert_equal_helper(a[uintMask].dtype, a[boolMask].dtype) - - # def test_setitem_expansion_error(self): - # true = np.array(True) - # a = torch.randn(2, 3) - # # check prefix with non-1s doesn't work - # a_expanded = a.expand(torch.Size([5, 1]) + a.size()) - # # NumPy: ValueError - # with self.assertRaises(RuntimeError): - # a[True] = a_expanded - # with self.assertRaises(RuntimeError): - # a[true] = a_expanded - - # def test_getitem_scalars(self): - # zero = np.array(0, dtype=torch.int64) - # one = np.array(1, dtype=torch.int64) - - # # non-scalar indexed with scalars - # a = torch.randn(2, 3) - # numpy_testing_assert_equal_helper(a[0], a[zero]) - # numpy_testing_assert_equal_helper(a[0][1], a[zero][one]) - # numpy_testing_assert_equal_helper(a[0, 1], a[zero, one]) - # numpy_testing_assert_equal_helper(a[0, one], a[zero, 1]) - - # # indexing by a scalar should slice (not copy) - # numpy_testing_assert_equal_helper(a[0, 1].data_ptr(), a[zero, one].data_ptr()) - # numpy_testing_assert_equal_helper(a[1].data_ptr(), a[one.int()].data_ptr()) - # numpy_testing_assert_equal_helper(a[1].data_ptr(), a[one.short()].data_ptr()) - - # # scalar indexed with scalar - # r = torch.randn(()) - # with self.assertRaises(IndexError): - # r[:] - # with self.assertRaises(IndexError): - # r[zero] - # numpy_testing_assert_equal_helper(r, r[...]) - - # def test_setitem_scalars(self): - # zero = np.array(0, dtype=torch.int64) - - # # non-scalar indexed with scalars - # a = torch.randn(2, 3) - # a_set_with_number = a.clone() - # a_set_with_scalar = a.clone() - # b = torch.randn(3) - - # a_set_with_number[0] = b - # a_set_with_scalar[zero] = b - # numpy_testing_assert_equal_helper(a_set_with_number, a_set_with_scalar) - # a[1, zero] = 7.7 - # numpy_testing_assert_equal_helper(7.7, a[1, 0]) - - # # scalar indexed with scalars - # r = torch.randn(()) - # with self.assertRaises(IndexError): - # r[:] = 8.8 - # with self.assertRaises(IndexError): - # r[zero] = 8.8 - # r[...] = 9.9 - # numpy_testing_assert_equal_helper(9.9, r) - - # def test_basic_advanced_combined(self): - # # From the NumPy indexing example - # x = torch.arange(0, 12).view(4, 3) - # numpy_testing_assert_equal_helper(x[1:2, 1:3], x[1:2, [1, 2]]) - # numpy_testing_assert_equal_helper(x[1:2, 1:3].tolist(), [[4, 5]]) - - # # Check that it is a copy - # unmodified = x.clone() - # x[1:2, [1, 2]].zero_() - # numpy_testing_assert_equal_helper(x, unmodified) - - # # But assignment should modify the original - # unmodified = x.clone() - # x[1:2, [1, 2]] = 0 - # self.assertNotEqual(x, unmodified) - - # def test_int_assignment(self): - # x = torch.arange(0, 4).view(2, 2) - # x[1] = 5 - # numpy_testing_assert_equal_helper(x.tolist(), [[0, 1], [5, 5]]) - - # x = torch.arange(0, 4).view(2, 2) - # x[1] = torch.arange(5, 7) - # numpy_testing_assert_equal_helper(x.tolist(), [[0, 1], [5, 6]]) - - # def test_byte_tensor_assignment(self): - # x = torch.arange(0., 16).view(4, 4) - # b = torch.ByteTensor([True, False, True, False]).to(device) - # value = np.array([3., 4., 5., 6.]) - - # with warnings.catch_warnings(record=True) as w: - # x[b] = value - # numpy_testing_assert_equal_helper(len(w), 1) - - # numpy_testing_assert_equal_helper(x[0], value) - # numpy_testing_assert_equal_helper(x[1], torch.arange(4., 8)) - # numpy_testing_assert_equal_helper(x[2], value) - # numpy_testing_assert_equal_helper(x[3], torch.arange(12., 16)) - - # def test_variable_slicing(self): - # x = torch.arange(0, 16).view(4, 4) - # indices = torch.IntTensor([0, 1]).to(device) - # i, j = indices - # numpy_testing_assert_equal_helper(x[i:j], x[0:1]) - - # def test_ellipsis_tensor(self): - # x = torch.arange(0, 9).view(3, 3) - # idx = np.array([0, 2]) - # numpy_testing_assert_equal_helper(x[..., idx].tolist(), [[0, 2], - # [3, 5], - # [6, 8]]) - # numpy_testing_assert_equal_helper(x[idx, ...].tolist(), [[0, 1, 2], - # [6, 7, 8]]) - - # def test_unravel_index_errors(self): - # with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"): - # torch.unravel_index( - # np.array(0.5), - # (2, 2)) - - # with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"): - # torch.unravel_index( - # np.array([]), - # (10, 3, 5)) - - # with self.assertRaisesRegex(TypeError, r"expected 'shape' to be int or sequence"): - # torch.unravel_index( - # np.array([1], dtype=torch.int64), - # np.array([1, 2, 3])) - - # with self.assertRaisesRegex(TypeError, r"expected 'shape' sequence to only contain ints"): - # torch.unravel_index( - # np.array([1], dtype=torch.int64), - # (1, 2, 2.0)) - - # with self.assertRaisesRegex(ValueError, r"'shape' cannot have negative values, but got \(2, -3\)"): - # torch.unravel_index( - # np.array(0), - # (2, -3)) - - # def test_invalid_index(self): - # x = torch.arange(0, 16).view(4, 4) - # self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"]) - - # def test_out_of_bound_index(self): - # x = torch.arange(0, 100).view(2, 5, 10) - # self.assertRaisesRegex(IndexError, 'index 5 is out of bounds for dimension 1 with size 5', lambda: x[0, 5]) - # self.assertRaisesRegex(IndexError, 'index 4 is out of bounds for dimension 0 with size 2', lambda: x[4, 5]) - # self.assertRaisesRegex(IndexError, 'index 15 is out of bounds for dimension 2 with size 10', - # lambda: x[0, 1, 15]) - # self.assertRaisesRegex(IndexError, 'index 12 is out of bounds for dimension 2 with size 10', - # lambda: x[:, :, 12]) - - # def test_zero_dim_index(self): - # x = np.array(10) - # numpy_testing_assert_equal_helper(x, x.item()) - - # def runner(): - # print(x[0]) - # return x[0] - - # self.assertRaisesRegex(IndexError, 'invalid index', runner) - - # def test_invalid_device(self): - # idx = np.array([0, 1]) - # b = torch.zeros(5) - # c = np.array([1., 2.], device="cpu") - - # for accumulate in [True, False]: - # self.assertRaises(RuntimeError, lambda: torch.index_put_(b, (idx,), c, accumulate=accumulate)) - - # def test_cpu_indices(self): - # idx = np.array([0, 1]) - # b = torch.zeros(2) - # x = torch.ones(10) - # x[idx] = b # index_put_ - # ref = torch.ones(10) - # ref[:2] = 0 - # numpy_testing_assert_equal_helper(x, ref) - # out = x[idx] # index - # numpy_testing_assert_equal_helper(out, torch.zeros(2)) - - # def test_take_along_dim(self, dtype): - # def _test_against_numpy(t, indices, dim): - # actual = torch.take_along_dim(t, indices, dim=dim) - # t_np = t.cpu().numpy() - # indices_np = indices.cpu().numpy() - # expected = np.take_along_axis(t_np, indices_np, axis=dim) - # numpy_testing_assert_equal_helper(actual, expected) - - # for shape in [(3, 2), (2, 3, 5), (2, 4, 0), (2, 3, 1, 4)]: - # for noncontiguous in [True, False]: - # t = make_tensor(shape, dtype=dtype, noncontiguous=noncontiguous) - # for dim in list(range(t.ndim)) + [None]: - # if dim is None: - # indices = torch.argsort(t.view(-1)) - # else: - # indices = torch.argsort(t, dim=dim) - - # _test_against_numpy(t, indices, dim) - - # # test broadcasting - # t = torch.ones((3, 4, 1)) - # indices = torch.ones((1, 2, 5), dtype=torch.long) - - # _test_against_numpy(t, indices, 1) - - # # test empty indices - # t = torch.ones((3, 4, 5)) - # indices = torch.ones((3, 0, 5), dtype=torch.long) - - # _test_against_numpy(t, indices, 1) - - # def test_take_along_dim_invalid(self, dtype): - # shape = (2, 3, 1, 4) - # dim = 0 - # t = make_tensor(shape, dtype=dtype) - # indices = torch.argsort(t, dim=dim) - - # # dim of `t` and `indices` does not match - # with self.assertRaisesRegex(RuntimeError, - # "input and indices should have the same number of dimensions"): - # torch.take_along_dim(t, indices[0], dim=0) - - # # invalid `indices` dtype - # with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"): - # torch.take_along_dim(t, indices.to(torch.bool), dim=0) - - # with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"): - # torch.take_along_dim(t, indices.to(torch.float), dim=0) - - # with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"): - # torch.take_along_dim(t, indices.to(torch.int32), dim=0) - - # # invalid axis - # with self.assertRaisesRegex(IndexError, "Dimension out of range"): - # torch.take_along_dim(t, indices, dim=-7) - - # with self.assertRaisesRegex(IndexError, "Dimension out of range"): - # torch.take_along_dim(t, indices, dim=7) - - # def test_gather_take_along_dim_cross_device(self, dtype): - # shape = (2, 3, 1, 4) - # dim = 0 - # t = make_tensor(shape, dtype=dtype) - # indices = torch.argsort(t, dim=dim) - - # with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): - # torch.gather(t, 0, indices.cpu()) - - # with self.assertRaisesRegex(RuntimeError, - # r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()"): - # torch.take_along_dim(t, indices.cpu(), dim=0) - - # with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): - # torch.gather(t.cpu(), 0, indices) - - # with self.assertRaisesRegex(RuntimeError, - # r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()"): - # torch.take_along_dim(t.cpu(), indices, dim=0) - - # def test_cuda_broadcast_index_use_deterministic_algorithms(self): - # with DeterministicGuard(True): - # idx1 = np.array([0]) - # idx2 = np.array([2, 6]) - # idx3 = np.array([1, 5, 7]) - - # tensor_a = torch.rand(13, 11, 12, 13, 12).cpu() - # tensor_b = tensor_a.to(device=device) - # tensor_a[idx1] = 1.0 - # tensor_a[idx1, :, idx2, idx2, :] = 2.0 - # tensor_a[:, idx1, idx3, :, idx3] = 3.0 - # tensor_b[idx1] = 1.0 - # tensor_b[idx1, :, idx2, idx2, :] = 2.0 - # tensor_b[:, idx1, idx3, :, idx3] = 3.0 - # numpy_testing_assert_equal_helper(tensor_a, tensor_b.cpu()) - - # tensor_a = torch.rand(10, 11).cpu() - # tensor_b = tensor_a.to(device=device) - # tensor_a[idx3] = 1.0 - # tensor_a[idx2, :] = 2.0 - # tensor_a[:, idx2] = 3.0 - # tensor_a[:, idx1] = 4.0 - # tensor_b[idx3] = 1.0 - # tensor_b[idx2, :] = 2.0 - # tensor_b[:, idx2] = 3.0 - # tensor_b[:, idx1] = 4.0 - # numpy_testing_assert_equal_helper(tensor_a, tensor_b.cpu()) - - # tensor_a = torch.rand(10, 10).cpu() - # tensor_b = tensor_a.to(device=device) - # tensor_a[[8]] = 1.0 - # tensor_b[[8]] = 1.0 - # numpy_testing_assert_equal_helper(tensor_a, tensor_b.cpu()) - - # tensor_a = torch.rand(10).cpu() - # tensor_b = tensor_a.to(device=device) - # tensor_a[6] = 1.0 - # tensor_b[6] = 1.0 - # numpy_testing_assert_equal_helper(tensor_a, tensor_b.cpu()) + def test_index(self): + reference = consec((3, 3, 3)) + + numpy_testing_assert_equal_helper(reference[0], consec((3, 3))) + numpy_testing_assert_equal_helper(reference[1], consec((3, 3), 10)) + numpy_testing_assert_equal_helper(reference[2], consec((3, 3), 19)) + numpy_testing_assert_equal_helper(reference[0, 1], consec((3,), 4)) + numpy_testing_assert_equal_helper(reference[0:2], consec((2, 3, 3))) + numpy_testing_assert_equal_helper(reference[2, 2, 2], 27) + numpy_testing_assert_equal_helper(reference[:], consec((3, 3, 3))) + + # indexing with Ellipsis + numpy_testing_assert_equal_helper( + reference[..., 2], + np.array([[3.0, 6.0, 9.0], [12.0, 15.0, 18.0], [21.0, 24.0, 27.0]]), + ) + numpy_testing_assert_equal_helper( + reference[0, ..., 2], np.array([3.0, 6.0, 9.0]) + ) + numpy_testing_assert_equal_helper(reference[..., 2], reference[:, :, 2]) + numpy_testing_assert_equal_helper(reference[0, ..., 2], reference[0, :, 2]) + numpy_testing_assert_equal_helper(reference[0, 2, ...], reference[0, 2]) + numpy_testing_assert_equal_helper(reference[..., 2, 2, 2], 27) + numpy_testing_assert_equal_helper(reference[2, ..., 2, 2], 27) + numpy_testing_assert_equal_helper(reference[2, 2, ..., 2], 27) + numpy_testing_assert_equal_helper(reference[2, 2, 2, ...], 27) + numpy_testing_assert_equal_helper(reference[...], reference) + + reference_5d = consec((3, 3, 3, 3, 3)) + numpy_testing_assert_equal_helper( + reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0] + ) + numpy_testing_assert_equal_helper( + reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0] + ) + numpy_testing_assert_equal_helper( + reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1] + ) + numpy_testing_assert_equal_helper(reference_5d[...], reference_5d) + + # None indexing + numpy_testing_assert_equal_helper(reference[2, None], reference[2].unsqueeze(0)) + numpy_testing_assert_equal_helper( + reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0) + ) + numpy_testing_assert_equal_helper( + reference[2:4, None], reference[2:4].unsqueeze(1) + ) + numpy_testing_assert_equal_helper( + reference[None, 2, None, None], + reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0), + ) + numpy_testing_assert_equal_helper( + reference[None, 2:5, None, None], + reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2), + ) + + # indexing 0-length slice + numpy_testing_assert_equal_helper(np.empty((0, 3, 3)), reference[slice(0)]) + numpy_testing_assert_equal_helper(np.empty((0, 3)), reference[slice(0), 2]) + numpy_testing_assert_equal_helper(np.empty((0, 3)), reference[2, slice(0)]) + numpy_testing_assert_equal_helper(np.empty([]), reference[2, 1:1, 2]) + + # indexing with step + reference = consec((10, 10, 10)) + numpy_testing_assert_equal_helper( + reference[1:5:2], Tensor.stack([reference[1], reference[3]], 0) + ) + numpy_testing_assert_equal_helper( + reference[1:6:2], + Tensor.stack([reference[1], reference[3], reference[5]], 0), + ) + numpy_testing_assert_equal_helper( + reference[1:9:4], Tensor.stack([reference[1], reference[5]], 0) + ) + numpy_testing_assert_equal_helper( + reference[2:4, 1:5:2], + Tensor.stack([reference[2:4, 1], reference[2:4, 3]], 1), + ) + numpy_testing_assert_equal_helper( + reference[3, 1:6:2], + Tensor.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0), + ) + numpy_testing_assert_equal_helper( + reference[None, 2, 1:9:4], + Tensor.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0), + ) + numpy_testing_assert_equal_helper( + reference[:, 2, 1:6:2], + Tensor.stack( + [reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1 + ), + ) + + lst = [list(range(i, i + 10)) for i in range(0, 100, 10)] + tensor = Tensor(lst) + for _ in range(100): + idx1_start = random.randrange(10) + idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1) + idx1_step = random.randrange(1, 8) + idx1 = slice(idx1_start, idx1_end, idx1_step) + if random.randrange(2) == 0: + idx2_start = random.randrange(10) + idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1) + idx2_step = random.randrange(1, 8) + idx2 = slice(idx2_start, idx2_end, idx2_step) + lst_indexed = [l[idx2] for l in lst[idx1]] + tensor_indexed = tensor[idx1, idx2] + else: + lst_indexed = lst[idx1] + tensor_indexed = tensor[idx1] + numpy_testing_assert_equal_helper(tensor_indexed, np.array(lst_indexed)) + + # self.assertRaises(ValueError, lambda: reference[1:9:0]) + # self.assertRaises(ValueError, lambda: reference[1:9:-1]) + + # self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1]) + # self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1]) + # self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3]) + + # self.assertRaises(IndexError, lambda: reference[0.0]) + # self.assertRaises(TypeError, lambda: reference[0.0:2.0]) + # self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0]) + # self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0]) + # self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0]) + # self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0]) + + # def delitem(): del reference[0] + # self.assertRaises(TypeError, delitem) + + def test_advancedindex(self): + # integer array indexing + + # pick a random valid indexer type + def ri(indices): + choice = random.randint(0, 1) + # TODO: we do not support tuple of list for index now + if choice == 0: + return Tensor(indices) + if choice == 1: + return list(indices) + return tuple(indices) + + def validate_indexing(x): + numpy_testing_assert_equal_helper(x[[0]], consec((1,))) + numpy_testing_assert_equal_helper(x[ri([0]),], consec((1,))) + numpy_testing_assert_equal_helper(x[ri([3]),], consec((1,), 4)) + numpy_testing_assert_equal_helper(x[[2, 3, 4]], consec((3,), 3)) + numpy_testing_assert_equal_helper(x[ri([2, 3, 4]),], consec((3,), 3)) + numpy_testing_assert_equal_helper(x[ri([0, 2, 4]),], np.array([1, 3, 5])) + + def validate_setting(x): + pass + # # TODO: we don't support setitem now + # x[[0]] = -2 + # numpy_testing_assert_equal_helper(x[[0]], np.array([-2])) + # x[[0]] = -1 + # numpy_testing_assert_equal_helper(x[ri([0]), ], np.array([-1])) + # x[[2, 3, 4]] = 4 + # numpy_testing_assert_equal_helper(x[[2, 3, 4]], np.array([4, 4, 4])) + # x[ri([2, 3, 4]), ] = 3 + # numpy_testing_assert_equal_helper(x[ri([2, 3, 4]), ], np.array([3, 3, 3])) + # x[ri([0, 2, 4]), ] = np.array([5, 4, 3]) + # numpy_testing_assert_equal_helper(x[ri([0, 2, 4]), ], np.array([5, 4, 3])) + + # Case 1: Purely Integer Array Indexing + reference = consec((10,)) + validate_indexing(reference) + + # setting values + validate_setting(reference) + + # # Tensor with stride != 1 + # # strided is [1, 3, 5, 7] + # reference = consec((10,)) + # strided = np.array(()) + # strided.set_(reference.storage(), storage_offset=0, + # size=torch.Size([4]), stride=[2]) + + # numpy_testing_assert_equal_helper(strided[[0]], np.array([1])) + # numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([1])) + # numpy_testing_assert_equal_helper(strided[ri([3]), ], np.array([7])) + # numpy_testing_assert_equal_helper(strided[[1, 2]], np.array([3, 5])) + # numpy_testing_assert_equal_helper(strided[ri([1, 2]), ], np.array([3, 5])) + # numpy_testing_assert_equal_helper(strided[ri([[2, 1], [0, 3]]), ], + # np.array([[5, 3], [1, 7]])) + + # # stride is [4, 8] + # strided = np.array(()) + # strided.set_(reference.storage(), storage_offset=4, + # size=torch.Size([2]), stride=[4]) + # numpy_testing_assert_equal_helper(strided[[0]], np.array([5])) + # numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([5])) + # numpy_testing_assert_equal_helper(strided[ri([1]), ], np.array([9])) + # numpy_testing_assert_equal_helper(strided[[0, 1]], np.array([5, 9])) + # numpy_testing_assert_equal_helper(strided[ri([0, 1]), ], np.array([5, 9])) + # numpy_testing_assert_equal_helper(strided[ri([[0, 1], [1, 0]]), ], + # np.array([[5, 9], [9, 5]])) + + # reference is 1 2 + # 3 4 + # 5 6 + reference = consec((3, 2)) + numpy_testing_assert_equal_helper( + reference[ri([0, 1, 2]), ri([0])], np.array([1, 3, 5]) + ) + numpy_testing_assert_equal_helper( + reference[ri([0, 1, 2]), ri([1])], np.array([2, 4, 6]) + ) + numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], consec((1,))) + numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], consec((1,), 6)) + # # TODO: we don't support list of Tensors as index + # numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([1, 2])) + # numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], np.array([2, 4, 4, 2, 6])) + # numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 2, 3, 3])) + + # rows = ri([[0, 0], + # [1, 2]]) + # columns = [0], + # numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 1], + # [3, 5]])) + + rows = ri([[0, 0], [1, 2]]) + columns = ri([1, 0]) + numpy_testing_assert_equal_helper( + reference[rows, columns], np.array([[2, 1], [4, 5]]) + ) + rows = ri([[0, 0], [1, 2]]) + columns = ri([[0, 1], [1, 0]]) + numpy_testing_assert_equal_helper( + reference[rows, columns], np.array([[1, 2], [4, 5]]) + ) + + # # setting values + # reference[ri([0]), ri([1])] = -1 + # numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], np.array([-1])) + # reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4]) + # numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], + # np.array([-1, 2, -4])) + # reference[rows, columns] = np.array([[4, 6], [2, 3]]) + # numpy_testing_assert_equal_helper(reference[rows, columns], + # np.array([[4, 6], [2, 3]])) + + # Verify still works with Transposed (i.e. non-contiguous) Tensors + + reference = Tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]).T + + # Transposed: [[0, 4, 8], + # [1, 5, 9], + # [2, 6, 10], + # [3, 7, 11]] + + numpy_testing_assert_equal_helper( + reference[ri([0, 1, 2]), ri([0])], np.array([0, 1, 2]) + ) + numpy_testing_assert_equal_helper( + reference[ri([0, 1, 2]), ri([1])], np.array([4, 5, 6]) + ) + numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], np.array([0])) + numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], np.array([6])) + # # TODO: we don't support list of Tensors as index + # numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([0, 4])) + # numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], np.array([4, 5, 5, 4, 7])) + # numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([0, 4, 1, 1])) + + # rows = ri([[0, 0], + # [1, 2]]) + # columns = [0], + # numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 0], [1, 2]])) + + rows = ri([[0, 0], [1, 2]]) + columns = ri([1, 0]) + numpy_testing_assert_equal_helper( + reference[rows, columns], np.array([[4, 0], [5, 2]]) + ) + rows = ri([[0, 0], [1, 3]]) + columns = ri([[0, 1], [1, 2]]) + numpy_testing_assert_equal_helper( + reference[rows, columns], np.array([[0, 4], [5, 11]]) + ) + + # # setting values + # reference[ri([0]), ri([1])] = -1 + # numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], + # np.array([-1])) + # reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4]) + # numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], + # np.array([-1, 2, -4])) + # reference[rows, columns] = np.array([[4, 6], [2, 3]]) + # numpy_testing_assert_equal_helper(reference[rows, columns], + # np.array([[4, 6], [2, 3]])) + + # # stride != 1 + + # # strided is [[1 3 5 7], + # # [9 11 13 15]] + + # reference = torch.arange(0., 24).view(3, 8) + # strided = np.array(()) + # strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), + # stride=[8, 2]) + + # numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([0])], + # np.array([1, 9])) + # numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1])], + # np.array([3, 11])) + # numpy_testing_assert_equal_helper(strided[ri([0]), ri([0])], + # np.array([1])) + # numpy_testing_assert_equal_helper(strided[ri([1]), ri([3])], + # np.array([15])) + # numpy_testing_assert_equal_helper(strided[[ri([0, 0]), ri([0, 3])]], + # np.array([1, 7])) + # numpy_testing_assert_equal_helper(strided[[ri([1]), ri([0, 1, 1, 0, 3])]], + # np.array([9, 11, 11, 9, 15])) + # numpy_testing_assert_equal_helper(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], + # np.array([1, 3, 9, 9])) + + # rows = ri([[0, 0], + # [1, 1]]) + # columns = [0], + # numpy_testing_assert_equal_helper(strided[rows, columns], + # np.array([[1, 1], [9, 9]])) + + # rows = ri([[0, 1], + # [1, 0]]) + # columns = ri([1, 2]) + # numpy_testing_assert_equal_helper(strided[rows, columns], + # np.array([[3, 13], [11, 5]])) + # rows = ri([[0, 0], + # [1, 1]]) + # columns = ri([[0, 1], + # [1, 2]]) + # numpy_testing_assert_equal_helper(strided[rows, columns], + # np.array([[1, 3], [11, 13]])) + + # # setting values + + # # strided is [[10, 11], + # # [17, 18]] + + # reference = torch.arange(0., 24).view(3, 8) + # strided = np.array(()) + # strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), + # stride=[7, 1]) + # numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])], + # np.array([11])) + # strided[ri([0]), ri([1])] = -1 + # numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])], + # np.array([-1])) + + # reference = torch.arange(0., 24).view(3, 8) + # strided = np.array(()) + # strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), + # stride=[7, 1]) + # numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])], + # np.array([11, 17])) + # strided[ri([0, 1]), ri([1, 0])] = np.array([-1, 2]) + # numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])], + # np.array([-1, 2])) + + # reference = torch.arange(0., 24).view(3, 8) + # strided = np.array(()) + # strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), + # stride=[7, 1]) + + # rows = ri([[0], + # [1]]) + # columns = ri([[0, 1], + # [0, 1]]) + # numpy_testing_assert_equal_helper(strided[rows, columns], + # np.array([[10, 11], [17, 18]])) + # strided[rows, columns] = np.array([[4, 6], [2, 3]]) + # numpy_testing_assert_equal_helper(strided[rows, columns], + # np.array([[4, 6], [2, 3]])) + + # Tests using less than the number of dims, and ellipsis + + # reference is 1 2 + # 3 4 + # 5 6 + reference = consec((3, 2)) + numpy_testing_assert_equal_helper( + reference[ri([0, 2]),], np.array([[1, 2], [5, 6]]) + ) + numpy_testing_assert_equal_helper(reference[ri([1]), ...], np.array([[3, 4]])) + numpy_testing_assert_equal_helper( + reference[..., ri([1])], np.array([[2], [4], [6]]) + ) + + # verify too many indices fails + with self.assertRaises(IndexError): + reference[ri([1]), ri([0, 2]), ri([3])] + + # # test invalid index fails + # reference = torch.empty(10) + # # can't test cuda because it is a device assert + # if not reference.is_cuda: + # for err_idx in (10, -11): + # with self.assertRaisesRegex(IndexError, r'out of'): + # reference[err_idx] + # with self.assertRaisesRegex(IndexError, r'out of'): + # reference[torch.LongTensor([err_idx]).to(device)] + # with self.assertRaisesRegex(IndexError, r'out of'): + # reference[[err_idx]] + + # def tensor_indices_to_np(tensor, indices): + # # convert the Torch Tensor to a numpy array + # tensor = tensor.to(device='cpu') + # npt = tensor.numpy() + + # # convert indices + # idxs = tuple(i.tolist() if isinstance(i, torch.LongTensor) else + # i for i in indices) + + # return npt, idxs + + # def get_numpy(tensor, indices): + # npt, idxs = tensor_indices_to_np(tensor, indices) + + # # index and return as a Torch Tensor + # return np.array(npt[idxs]) + + # def set_numpy(tensor, indices, value): + # if not isinstance(value, int): + # if self.device_type != 'cpu': + # value = value.cpu() + # value = value.numpy() + + # npt, idxs = tensor_indices_to_np(tensor, indices) + # npt[idxs] = value + # return npt + + # def assert_get_eq(tensor, indexer): + # numpy_testing_assert_equal_helper(tensor[indexer], get_numpy(tensor, indexer)) + + # def assert_set_eq(tensor, indexer, val): + # pyt = tensor.clone() + # numt = tensor.clone() + # pyt[indexer] = val + # numt = np.array(set_numpy(numt, indexer, val)) + # numpy_testing_assert_equal_helper(pyt, numt) + + # def assert_backward_eq(tensor, indexer): + # cpu = tensor.float().clone().detach().requires_grad_(True) + # outcpu = cpu[indexer] + # gOcpu = torch.rand_like(outcpu) + # outcpu.backward(gOcpu) + # dev = cpu.to(device).detach().requires_grad_(True) + # outdev = dev[indexer] + # outdev.backward(gOcpu.to(device)) + # numpy_testing_assert_equal_helper(cpu.grad, dev.grad) + + # def get_set_tensor(indexed, indexer): + # set_size = indexed[indexer].size() + # set_count = indexed[indexer].numel() + # set_tensor = torch.randperm(set_count).view(set_size).double().to(device) + # return set_tensor + + # # Tensor is 0 1 2 3 4 + # # 5 6 7 8 9 + # # 10 11 12 13 14 + # # 15 16 17 18 19 + # reference = torch.arange(0., 20).view(4, 5) + + # indices_to_test = [ + # # grab the second, fourth columns + # [slice(None), [1, 3]], + + # # first, third rows, + # [[0, 2], slice(None)], + + # # weird shape + # [slice(None), [[0, 1], + # [2, 3]]], + # # negatives + # [[-1], [0]], + # [[0, 2], [-1]], + # [slice(None), [-1]], + # ] + + # # only test dupes on gets + # get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] + + # for indexer in get_indices_to_test: + # assert_get_eq(reference, indexer) + # if self.device_type != 'cpu': + # assert_backward_eq(reference, indexer) + + # for indexer in indices_to_test: + # assert_set_eq(reference, indexer, 44) + # assert_set_eq(reference, + # indexer, + # get_set_tensor(reference, indexer)) + + # reference = torch.arange(0., 160).view(4, 8, 5) + + # indices_to_test = [ + # [slice(None), slice(None), [0, 3, 4]], + # [slice(None), [2, 4, 5, 7], slice(None)], + # [[2, 3], slice(None), slice(None)], + # [slice(None), [0, 2, 3], [1, 3, 4]], + # [slice(None), [0], [1, 2, 4]], + # [slice(None), [0, 1, 3], [4]], + # [slice(None), [[0, 1], [1, 0]], [[2, 3]]], + # [slice(None), [[0, 1], [2, 3]], [[0]]], + # [slice(None), [[5, 6]], [[0, 3], [4, 4]]], + # [[0, 2, 3], [1, 3, 4], slice(None)], + # [[0], [1, 2, 4], slice(None)], + # [[0, 1, 3], [4], slice(None)], + # [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], + # [[[0, 1], [1, 0]], [[2, 3]], slice(None)], + # [[[0, 1], [2, 3]], [[0]], slice(None)], + # [[[2, 1]], [[0, 3], [4, 4]], slice(None)], + # [[[2]], [[0, 3], [4, 1]], slice(None)], + # # non-contiguous indexing subspace + # [[0, 2, 3], slice(None), [1, 3, 4]], + + # # less dim, ellipsis + # [[0, 2], ], + # [[0, 2], slice(None)], + # [[0, 2], Ellipsis], + # [[0, 2], slice(None), Ellipsis], + # [[0, 2], Ellipsis, slice(None)], + # [[0, 2], [1, 3]], + # [[0, 2], [1, 3], Ellipsis], + # [Ellipsis, [1, 3], [2, 3]], + # [Ellipsis, [2, 3, 4]], + # [Ellipsis, slice(None), [2, 3, 4]], + # [slice(None), Ellipsis, [2, 3, 4]], + + # # ellipsis counts for nothing + # [Ellipsis, slice(None), slice(None), [0, 3, 4]], + # [slice(None), Ellipsis, slice(None), [0, 3, 4]], + # [slice(None), slice(None), Ellipsis, [0, 3, 4]], + # [slice(None), slice(None), [0, 3, 4], Ellipsis], + # [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], + # [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], + # [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], + # ] + + # for indexer in indices_to_test: + # assert_get_eq(reference, indexer) + # assert_set_eq(reference, indexer, 212) + # assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) + # if torch.cuda.is_available(): + # assert_backward_eq(reference, indexer) + + # reference = torch.arange(0., 1296).view(3, 9, 8, 6) + + # indices_to_test = [ + # [slice(None), slice(None), slice(None), [0, 3, 4]], + # [slice(None), slice(None), [2, 4, 5, 7], slice(None)], + # [slice(None), [2, 3], slice(None), slice(None)], + # [[1, 2], slice(None), slice(None), slice(None)], + # [slice(None), slice(None), [0, 2, 3], [1, 3, 4]], + # [slice(None), slice(None), [0], [1, 2, 4]], + # [slice(None), slice(None), [0, 1, 3], [4]], + # [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], + # [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], + # [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], + # [slice(None), [0, 2, 3], [1, 3, 4], slice(None)], + # [slice(None), [0], [1, 2, 4], slice(None)], + # [slice(None), [0, 1, 3], [4], slice(None)], + # [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], + # [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], + # [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], + # [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], + # [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], + # [[0, 1, 2], [1, 3, 4], slice(None), slice(None)], + # [[0], [1, 2, 4], slice(None), slice(None)], + # [[0, 1, 2], [4], slice(None), slice(None)], + # [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], + # [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], + # [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], + # [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)], + # [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], + # [slice(None), [2, 3, 4], [1, 3, 4], [4]], + # [slice(None), [0, 1, 3], [4], [1, 3, 4]], + # [slice(None), [6], [0, 2, 3], [1, 3, 4]], + # [slice(None), [2, 3, 5], [3], [4]], + # [slice(None), [0], [4], [1, 3, 4]], + # [slice(None), [6], [0, 2, 3], [1]], + # [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], + # [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], + # [[2, 0, 1], [1, 2, 3], [4], slice(None)], + # [[0, 1, 2], [4], [1, 3, 4], slice(None)], + # [[0], [0, 2, 3], [1, 3, 4], slice(None)], + # [[0, 2, 1], [3], [4], slice(None)], + # [[0], [4], [1, 3, 4], slice(None)], + # [[1], [0, 2, 3], [1], slice(None)], + # [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], + + # # less dim, ellipsis + # [Ellipsis, [0, 3, 4]], + # [Ellipsis, slice(None), [0, 3, 4]], + # [Ellipsis, slice(None), slice(None), [0, 3, 4]], + # [slice(None), Ellipsis, [0, 3, 4]], + # [slice(None), slice(None), Ellipsis, [0, 3, 4]], + # [slice(None), [0, 2, 3], [1, 3, 4]], + # [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], + # [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], + # [[0], [1, 2, 4]], + # [[0], [1, 2, 4], slice(None)], + # [[0], [1, 2, 4], Ellipsis], + # [[0], [1, 2, 4], Ellipsis, slice(None)], + # [[1], ], + # [[0, 2, 1], [3], [4]], + # [[0, 2, 1], [3], [4], slice(None)], + # [[0, 2, 1], [3], [4], Ellipsis], + # [Ellipsis, [0, 2, 1], [3], [4]], + # ] + + # for indexer in indices_to_test: + # assert_get_eq(reference, indexer) + # assert_set_eq(reference, indexer, 1333) + # assert_set_eq(reference, indexer, get_set_tensor(reference, indexer)) + # indices_to_test += [ + # [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], + # [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], + # ] + # for indexer in indices_to_test: + # assert_get_eq(reference, indexer) + # assert_set_eq(reference, indexer, 1333) + # if self.device_type != 'cpu': + # assert_backward_eq(reference, indexer) + + # def test_advancedindex_big(self): + # reference = Tensor.arange(123344) + # numpy_testing_assert_equal_helper(reference[[0, 123, 44488, 68807, 123343],], np.array([0, 123, 44488, 68807, 123343])) + + # def test_set_item_to_scalar_tensor(self): + # m = random.randint(1, 10) + # n = random.randint(1, 10) + # z = torch.randn([m, n]) + # a = 1.0 + # w = np.array(a, requires_grad=True) + # z[:, 0] = w + # z.sum().backward() + # numpy_testing_assert_equal_helper(w.grad, m * a) + + def test_single_int(self): + v = Tensor.randn(5, 7, 3) + numpy_testing_assert_equal_helper(v[4].shape, (7, 3)) + + def test_multiple_int(self): + v = Tensor.randn(5, 7, 3) + numpy_testing_assert_equal_helper(v[4].shape, (7, 3)) + numpy_testing_assert_equal_helper(v[4, :, 1].shape, (7,)) + + def test_none(self): + v = Tensor.randn(5, 7, 3) + numpy_testing_assert_equal_helper(v[None].shape, (1, 5, 7, 3)) + numpy_testing_assert_equal_helper(v[:, None].shape, (5, 1, 7, 3)) + numpy_testing_assert_equal_helper(v[:, None, None].shape, (5, 1, 1, 7, 3)) + numpy_testing_assert_equal_helper(v[..., None].shape, (5, 7, 3, 1)) + + def test_step(self): + v = Tensor.arange(10) + numpy_testing_assert_equal_helper(v[::1], v) + numpy_testing_assert_equal_helper(v[::2], [0, 2, 4, 6, 8]) + numpy_testing_assert_equal_helper(v[::3], [0, 3, 6, 9]) + numpy_testing_assert_equal_helper(v[::11], [0]) + numpy_testing_assert_equal_helper(v[1:6:2], [1, 3, 5]) + + # def test_step_assignment(self): + # v = torch.zeros(4, 4) + # v[0, 1::2] = np.array([3., 4.]) + # numpy_testing_assert_equal_helper(v[0].tolist(), [0, 3, 0, 4]) + # numpy_testing_assert_equal_helper(v[1:].sum(), 0) + + # def test_bool_indices(self): + # v = Tensor.randn(5, 7, 3) + # boolIndices = np.array([True, False, True, True, False], dtype=bool) + # numpy_testing_assert_equal_helper(v[boolIndices].shape, (3, 7, 3)) + # numpy_testing_assert_equal_helper(v[boolIndices], Tensor.stack([v[0], v[2], v[3]])) + + # v = np.array([True, False, True], dtype=torch.bool) + # boolIndices = np.array([True, False, False], dtype=torch.bool) + # uint8Indices = np.array([1, 0, 0], dtype=torch.uint8) + # with warnings.catch_warnings(record=True) as w: + # numpy_testing_assert_equal_helper(v[boolIndices].shape, v[uint8Indices].shape) + # numpy_testing_assert_equal_helper(v[boolIndices], v[uint8Indices]) + # numpy_testing_assert_equal_helper(v[boolIndices], tensor([True], dtype=torch.bool)) + # numpy_testing_assert_equal_helper(len(w), 2) + + # def test_bool_indices_accumulate(self): + # mask = torch.zeros(size=(10, ), dtype=torch.bool) + # y = torch.ones(size=(10, 10)) + # y.index_put_((mask, ), y[mask], accumulate=True) + # numpy_testing_assert_equal_helper(y, torch.ones(size=(10, 10))) + + # def test_multiple_bool_indices(self): + # v = torch.randn(5, 7, 3) + # # note: these broadcast together and are transposed to the first dim + # mask1 = np.array([1, 0, 1, 1, 0], dtype=torch.bool) + # mask2 = np.array([1, 1, 1], dtype=torch.bool) + # numpy_testing_assert_equal_helper(v[mask1, :, mask2].shape, (3, 7)) + + # def test_byte_mask(self): + # v = torch.randn(5, 7, 3) + # mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device) + # with warnings.catch_warnings(record=True) as w: + # numpy_testing_assert_equal_helper(v[mask].shape, (3, 7, 3)) + # numpy_testing_assert_equal_helper(v[mask], torch.stack([v[0], v[2], v[3]])) + # numpy_testing_assert_equal_helper(len(w), 2) + + # v = np.array([1.]) + # numpy_testing_assert_equal_helper(v[v == 0], np.array([])) + + # def test_byte_mask_accumulate(self): + # mask = torch.zeros(size=(10, ), dtype=torch.uint8) + # y = torch.ones(size=(10, 10)) + # with warnings.catch_warnings(record=True) as w: + # warnings.simplefilter("always") + # y.index_put_((mask, ), y[mask], accumulate=True) + # numpy_testing_assert_equal_helper(y, torch.ones(size=(10, 10))) + # numpy_testing_assert_equal_helper(len(w), 2) + + # def test_index_put_accumulate_large_tensor(self): + # # This test is for tensors with number of elements >= INT_MAX (2^31 - 1). + # N = (1 << 31) + 5 + # dt = torch.int8 + # a = torch.ones(N, dtype=dt) + # indices = np.array([-2, 0, -2, -1, 0, -1, 1], dtype=torch.long) + # values = np.array([6, 5, 6, 6, 5, 7, 11], dtype=dt) + + # a.index_put_((indices, ), values, accumulate=True) + + # numpy_testing_assert_equal_helper(a[0], 11) + # numpy_testing_assert_equal_helper(a[1], 12) + # numpy_testing_assert_equal_helper(a[2], 1) + # numpy_testing_assert_equal_helper(a[-3], 1) + # numpy_testing_assert_equal_helper(a[-2], 13) + # numpy_testing_assert_equal_helper(a[-1], 14) + + # a = torch.ones((2, N), dtype=dt) + # indices0 = np.array([0, -1, 0, 1], dtype=torch.long) + # indices1 = np.array([-2, -1, 0, 1], dtype=torch.long) + # values = np.array([12, 13, 10, 11], dtype=dt) + + # a.index_put_((indices0, indices1), values, accumulate=True) + + # numpy_testing_assert_equal_helper(a[0, 0], 11) + # numpy_testing_assert_equal_helper(a[0, 1], 1) + # numpy_testing_assert_equal_helper(a[1, 0], 1) + # numpy_testing_assert_equal_helper(a[1, 1], 12) + # numpy_testing_assert_equal_helper(a[:, 2], torch.ones(2, dtype=torch.int8)) + # numpy_testing_assert_equal_helper(a[:, -3], torch.ones(2, dtype=torch.int8)) + # numpy_testing_assert_equal_helper(a[0, -2], 13) + # numpy_testing_assert_equal_helper(a[1, -2], 1) + # numpy_testing_assert_equal_helper(a[-1, -1], 14) + # numpy_testing_assert_equal_helper(a[0, -1], 1) + + # def test_index_put_accumulate_expanded_values(self): + # # checks the issue with cuda: https://github.com/pytorch/pytorch/issues/39227 + # # and verifies consistency with CPU result + # t = torch.zeros((5, 2)) + # t_dev = t.to(device) + # indices = [ + # np.array([0, 1, 2, 3]), + # np.array([1, ]), + # ] + # indices_dev = [i.to(device) for i in indices] + # values0d = np.array(1.0) + # values1d = np.array([1.0, ]) + + # out_cuda = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True) + # out_cpu = t.index_put_(indices, values0d, accumulate=True) + # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) + + # out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True) + # out_cpu = t.index_put_(indices, values1d, accumulate=True) + # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) + + # t = torch.zeros(4, 3, 2) + # t_dev = t.to(device) + + # indices = [ + # np.array([0, ]), + # torch.arange(3)[:, None], + # torch.arange(2)[None, :], + # ] + # indices_dev = [i.to(device) for i in indices] + # values1d = np.array([-1.0, -2.0]) + # values2d = np.array([[-1.0, -2.0], ]) + + # out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True) + # out_cpu = t.index_put_(indices, values1d, accumulate=True) + # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) + + # out_cuda = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True) + # out_cpu = t.index_put_(indices, values2d, accumulate=True) + # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) + + # def test_index_put_accumulate_non_contiguous(self): + # t = torch.zeros((5, 2, 2)) + # t_dev = t.to(device) + # t1 = t_dev[:, 0, :] + # t2 = t[:, 0, :] + # self.assertTrue(not t1.is_contiguous()) + # self.assertTrue(not t2.is_contiguous()) + + # indices = [np.array([0, 1]), ] + # indices_dev = [i.to(device) for i in indices] + # value = torch.randn(2, 2) + # out_cuda = t1.index_put_(indices_dev, value.to(device), accumulate=True) + # out_cpu = t2.index_put_(indices, value, accumulate=True) + # self.assertTrue(not t1.is_contiguous()) + # self.assertTrue(not t2.is_contiguous()) + + # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) + + # def test_index_put_accumulate_with_optional_tensors(self): + # # TODO: replace with a better solution. + # # Currently, here using torchscript to put None into indices. + # # on C++ it gives indices as a list of 2 optional tensors: first is null and + # # the second is a valid tensor. + # @torch.jit.script + # def func(x, i, v): + # idx = [None, i] + # x.index_put_(idx, v, accumulate=True) + # return x + + # n = 4 + # t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2) + # t_dev = t.to(device) + # indices = np.array([1, 0]) + # indices_dev = indices.to(device) + # value0d = np.array(10.0) + # value1d = np.array([1.0, 2.0]) + + # out_cuda = func(t_dev, indices_dev, value0d.cuda()) + # out_cpu = func(t, indices, value0d) + # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) + + # out_cuda = func(t_dev, indices_dev, value1d.cuda()) + # out_cpu = func(t, indices, value1d) + # numpy_testing_assert_equal_helper(out_cuda.cpu(), out_cpu) + + # def test_index_put_accumulate_duplicate_indices(self): + # for i in range(1, 512): + # # generate indices by random walk, this will create indices with + # # lots of duplicates interleaved with each other + # delta = torch.empty(i, dtype=torch.double).uniform_(-1, 1) + # indices = delta.cumsum(0).long() + + # input = torch.randn(indices.abs().max() + 1) + # values = torch.randn(indices.size(0)) + # output = input.index_put((indices,), values, accumulate=True) + + # input_list = input.tolist() + # indices_list = indices.tolist() + # values_list = values.tolist() + # for i, v in zip(indices_list, values_list): + # input_list[i] += v + + # numpy_testing_assert_equal_helper(output, input_list) + + # def test_index_ind_dtype(self): + # x = torch.randn(4, 4) + # ind_long = torch.randint(4, (4,), dtype=torch.long) + # ind_int = ind_long.int() + # src = torch.randn(4) + # ref = x[ind_long, ind_long] + # res = x[ind_int, ind_int] + # numpy_testing_assert_equal_helper(ref, res) + # ref = x[ind_long, :] + # res = x[ind_int, :] + # numpy_testing_assert_equal_helper(ref, res) + # ref = x[:, ind_long] + # res = x[:, ind_int] + # numpy_testing_assert_equal_helper(ref, res) + # # no repeating indices for index_put + # ind_long = torch.arange(4, dtype=torch.long) + # ind_int = ind_long.int() + # for accum in (True, False): + # inp_ref = x.clone() + # inp_res = x.clone() + # torch.index_put_(inp_ref, (ind_long, ind_long), src, accum) + # torch.index_put_(inp_res, (ind_int, ind_int), src, accum) + # numpy_testing_assert_equal_helper(inp_ref, inp_res) + + # def test_index_put_accumulate_empty(self): + # # Regression test for https://github.com/pytorch/pytorch/issues/94667 + # input = torch.rand([], dtype=torch.float32) + # with self.assertRaises(RuntimeError): + # input.index_put([], np.array([1.0]), True) + + # def test_multiple_byte_mask(self): + # v = torch.randn(5, 7, 3) + # # note: these broadcast together and are transposed to the first dim + # mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device) + # mask2 = torch.ByteTensor([1, 1, 1]).to(device) + # with warnings.catch_warnings(record=True) as w: + # warnings.simplefilter("always") + # numpy_testing_assert_equal_helper(v[mask1, :, mask2].shape, (3, 7)) + # numpy_testing_assert_equal_helper(len(w), 2) + + # def test_byte_mask2d(self): + # v = torch.randn(5, 7, 3) + # c = torch.randn(5, 7) + # num_ones = (c > 0).sum() + # r = v[c > 0] + # numpy_testing_assert_equal_helper(r.shape, (num_ones, 3)) + + # def test_jit_indexing(self): + # def fn1(x): + # x[x < 50] = 1.0 + # return x + + # def fn2(x): + # x[0:50] = 1.0 + # return x + + # scripted_fn1 = torch.jit.script(fn1) + # scripted_fn2 = torch.jit.script(fn2) + # data = torch.arange(100, dtype=torch.float) + # out = scripted_fn1(data.detach().clone()) + # ref = np.array(np.concatenate((np.ones(50), np.arange(50, 100))), dtype=torch.float) + # numpy_testing_assert_equal_helper(out, ref) + # out = scripted_fn2(data.detach().clone()) + # numpy_testing_assert_equal_helper(out, ref) + + # def test_int_indices(self): + # v = torch.randn(5, 7, 3) + # numpy_testing_assert_equal_helper(v[[0, 4, 2]].shape, (3, 7, 3)) + # numpy_testing_assert_equal_helper(v[:, [0, 4, 2]].shape, (5, 3, 3)) + # numpy_testing_assert_equal_helper(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3)) + + # def test_index_put_src_datatype(self, dtype): + # src = torch.ones(3, 2, 4, dtype=dtype) + # vals = torch.ones(3, 2, 4, dtype=dtype) + # indices = (np.array([0, 2, 1]),) + # res = src.index_put_(indices, vals, accumulate=True) + # numpy_testing_assert_equal_helper(res.shape, src.shape) + + # def test_index_src_datatype(self, dtype): + # src = torch.ones(3, 2, 4, dtype=dtype) + # # test index + # res = src[[0, 2, 1], :, :] + # numpy_testing_assert_equal_helper(res.shape, src.shape) + # # test index_put, no accum + # src[[0, 2, 1], :, :] = res + # numpy_testing_assert_equal_helper(res.shape, src.shape) + + # def test_int_indices2d(self): + # # From the NumPy indexing example + # x = torch.arange(0, 12).view(4, 3) + # rows = np.array([[0, 0], [3, 3]]) + # columns = np.array([[0, 2], [0, 2]]) + # numpy_testing_assert_equal_helper(x[rows, columns].tolist(), [[0, 2], [9, 11]]) + + # def test_int_indices_broadcast(self): + # # From the NumPy indexing example + # x = torch.arange(0, 12).view(4, 3) + # rows = np.array([0, 3]) + # columns = np.array([0, 2]) + # result = x[rows[:, None], columns] + # numpy_testing_assert_equal_helper(result.tolist(), [[0, 2], [9, 11]]) + + # def test_empty_index(self): + # x = torch.arange(0, 12).view(4, 3) + # idx = np.array([], dtype=torch.long) + # numpy_testing_assert_equal_helper(x[idx].numel(), 0) + + # # empty assignment should have no effect but not throw an exception + # y = x.clone() + # y[idx] = -1 + # numpy_testing_assert_equal_helper(x, y) + + # mask = torch.zeros(4, 3).bool() + # y[mask] = -1 + # numpy_testing_assert_equal_helper(x, y) + + # def test_empty_ndim_index(self): + # x = torch.randn(5) + # numpy_testing_assert_equal_helper(torch.empty(0, 2), x[torch.empty(0, 2, dtype=torch.int64)]) + + # x = torch.randn(2, 3, 4, 5) + # numpy_testing_assert_equal_helper(torch.empty(2, 0, 6, 4, 5), + # x[:, torch.empty(0, 6, dtype=torch.int64)]) + + # x = torch.empty(10, 0) + # numpy_testing_assert_equal_helper(x[[1, 2]].shape, (2, 0)) + # numpy_testing_assert_equal_helper(x[[], []].shape, (0,)) + # with self.assertRaisesRegex(IndexError, 'for dimension with size 0'): + # x[:, [0, 1]] + + # def test_empty_ndim_index_bool(self): + # x = torch.randn(5) + # self.assertRaises(IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8)]) + + # def test_empty_slice(self): + # x = torch.randn(2, 3, 4, 5) + # y = x[:, :, :, 1] + # z = y[:, 1:1, :] + # numpy_testing_assert_equal_helper((2, 0, 4), z.shape) + # # this isn't technically necessary, but matches NumPy stride calculations. + # numpy_testing_assert_equal_helper((60, 20, 5), z.stride()) + # self.assertTrue(z.is_contiguous()) + + # def test_index_getitem_copy_bools_slices(self): + # true = np.array(1, dtype=torch.uint8) + # false = np.array(0, dtype=torch.uint8) + + # tensors = [torch.randn(2, 3), np.array(3.)] + + # for a in tensors: + # self.assertNotEqual(a.data_ptr(), a[True].data_ptr()) + # numpy_testing_assert_equal_helper(torch.empty(0, *a.shape), a[False]) + # self.assertNotEqual(a.data_ptr(), a[true].data_ptr()) + # numpy_testing_assert_equal_helper(torch.empty(0, *a.shape), a[false]) + # numpy_testing_assert_equal_helper(a.data_ptr(), a[None].data_ptr()) + # numpy_testing_assert_equal_helper(a.data_ptr(), a[...].data_ptr()) + + # def test_index_setitem_bools_slices(self): + # true = np.array(1, dtype=torch.uint8) + # false = np.array(0, dtype=torch.uint8) + + # tensors = [torch.randn(2, 3), np.array(3)] + + # for a in tensors: + # # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s + # # (some of these ops already prefix a 1 to the size) + # neg_ones = torch.ones_like(a) * -1 + # neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0) + # a[True] = neg_ones_expanded + # numpy_testing_assert_equal_helper(a, neg_ones) + # a[False] = 5 + # numpy_testing_assert_equal_helper(a, neg_ones) + # a[true] = neg_ones_expanded * 2 + # numpy_testing_assert_equal_helper(a, neg_ones * 2) + # a[false] = 5 + # numpy_testing_assert_equal_helper(a, neg_ones * 2) + # a[None] = neg_ones_expanded * 3 + # numpy_testing_assert_equal_helper(a, neg_ones * 3) + # a[...] = neg_ones_expanded * 4 + # numpy_testing_assert_equal_helper(a, neg_ones * 4) + # if a.dim() == 0: + # with self.assertRaises(IndexError): + # a[:] = neg_ones_expanded * 5 + + # def test_index_scalar_with_bool_mask(self): + # a = np.array(1) + # uintMask = np.array(True, dtype=torch.uint8) + # boolMask = np.array(True, dtype=torch.bool) + # numpy_testing_assert_equal_helper(a[uintMask], a[boolMask]) + # numpy_testing_assert_equal_helper(a[uintMask].dtype, a[boolMask].dtype) + + # a = np.array(True, dtype=torch.bool) + # numpy_testing_assert_equal_helper(a[uintMask], a[boolMask]) + # numpy_testing_assert_equal_helper(a[uintMask].dtype, a[boolMask].dtype) + + # def test_setitem_expansion_error(self): + # true = np.array(True) + # a = torch.randn(2, 3) + # # check prefix with non-1s doesn't work + # a_expanded = a.expand(torch.Size([5, 1]) + a.size()) + # # NumPy: ValueError + # with self.assertRaises(RuntimeError): + # a[True] = a_expanded + # with self.assertRaises(RuntimeError): + # a[true] = a_expanded + + # def test_getitem_scalars(self): + # zero = np.array(0, dtype=torch.int64) + # one = np.array(1, dtype=torch.int64) + + # # non-scalar indexed with scalars + # a = torch.randn(2, 3) + # numpy_testing_assert_equal_helper(a[0], a[zero]) + # numpy_testing_assert_equal_helper(a[0][1], a[zero][one]) + # numpy_testing_assert_equal_helper(a[0, 1], a[zero, one]) + # numpy_testing_assert_equal_helper(a[0, one], a[zero, 1]) + + # # indexing by a scalar should slice (not copy) + # numpy_testing_assert_equal_helper(a[0, 1].data_ptr(), a[zero, one].data_ptr()) + # numpy_testing_assert_equal_helper(a[1].data_ptr(), a[one.int()].data_ptr()) + # numpy_testing_assert_equal_helper(a[1].data_ptr(), a[one.short()].data_ptr()) + + # # scalar indexed with scalar + # r = torch.randn(()) + # with self.assertRaises(IndexError): + # r[:] + # with self.assertRaises(IndexError): + # r[zero] + # numpy_testing_assert_equal_helper(r, r[...]) + + # def test_setitem_scalars(self): + # zero = np.array(0, dtype=torch.int64) + + # # non-scalar indexed with scalars + # a = torch.randn(2, 3) + # a_set_with_number = a.clone() + # a_set_with_scalar = a.clone() + # b = torch.randn(3) + + # a_set_with_number[0] = b + # a_set_with_scalar[zero] = b + # numpy_testing_assert_equal_helper(a_set_with_number, a_set_with_scalar) + # a[1, zero] = 7.7 + # numpy_testing_assert_equal_helper(7.7, a[1, 0]) + + # # scalar indexed with scalars + # r = torch.randn(()) + # with self.assertRaises(IndexError): + # r[:] = 8.8 + # with self.assertRaises(IndexError): + # r[zero] = 8.8 + # r[...] = 9.9 + # numpy_testing_assert_equal_helper(9.9, r) + + # def test_basic_advanced_combined(self): + # # From the NumPy indexing example + # x = torch.arange(0, 12).view(4, 3) + # numpy_testing_assert_equal_helper(x[1:2, 1:3], x[1:2, [1, 2]]) + # numpy_testing_assert_equal_helper(x[1:2, 1:3].tolist(), [[4, 5]]) + + # # Check that it is a copy + # unmodified = x.clone() + # x[1:2, [1, 2]].zero_() + # numpy_testing_assert_equal_helper(x, unmodified) + + # # But assignment should modify the original + # unmodified = x.clone() + # x[1:2, [1, 2]] = 0 + # self.assertNotEqual(x, unmodified) + + # def test_int_assignment(self): + # x = torch.arange(0, 4).view(2, 2) + # x[1] = 5 + # numpy_testing_assert_equal_helper(x.tolist(), [[0, 1], [5, 5]]) + + # x = torch.arange(0, 4).view(2, 2) + # x[1] = torch.arange(5, 7) + # numpy_testing_assert_equal_helper(x.tolist(), [[0, 1], [5, 6]]) + + # def test_byte_tensor_assignment(self): + # x = torch.arange(0., 16).view(4, 4) + # b = torch.ByteTensor([True, False, True, False]).to(device) + # value = np.array([3., 4., 5., 6.]) + + # with warnings.catch_warnings(record=True) as w: + # x[b] = value + # numpy_testing_assert_equal_helper(len(w), 1) + + # numpy_testing_assert_equal_helper(x[0], value) + # numpy_testing_assert_equal_helper(x[1], torch.arange(4., 8)) + # numpy_testing_assert_equal_helper(x[2], value) + # numpy_testing_assert_equal_helper(x[3], torch.arange(12., 16)) + + # def test_variable_slicing(self): + # x = torch.arange(0, 16).view(4, 4) + # indices = torch.IntTensor([0, 1]).to(device) + # i, j = indices + # numpy_testing_assert_equal_helper(x[i:j], x[0:1]) + + # def test_ellipsis_tensor(self): + # x = torch.arange(0, 9).view(3, 3) + # idx = np.array([0, 2]) + # numpy_testing_assert_equal_helper(x[..., idx].tolist(), [[0, 2], + # [3, 5], + # [6, 8]]) + # numpy_testing_assert_equal_helper(x[idx, ...].tolist(), [[0, 1, 2], + # [6, 7, 8]]) + + # def test_unravel_index_errors(self): + # with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"): + # torch.unravel_index( + # np.array(0.5), + # (2, 2)) + + # with self.assertRaisesRegex(TypeError, r"expected 'indices' to be integer"): + # torch.unravel_index( + # np.array([]), + # (10, 3, 5)) + + # with self.assertRaisesRegex(TypeError, r"expected 'shape' to be int or sequence"): + # torch.unravel_index( + # np.array([1], dtype=torch.int64), + # np.array([1, 2, 3])) + + # with self.assertRaisesRegex(TypeError, r"expected 'shape' sequence to only contain ints"): + # torch.unravel_index( + # np.array([1], dtype=torch.int64), + # (1, 2, 2.0)) + + # with self.assertRaisesRegex(ValueError, r"'shape' cannot have negative values, but got \(2, -3\)"): + # torch.unravel_index( + # np.array(0), + # (2, -3)) + + # def test_invalid_index(self): + # x = torch.arange(0, 16).view(4, 4) + # self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"]) + + # def test_out_of_bound_index(self): + # x = torch.arange(0, 100).view(2, 5, 10) + # self.assertRaisesRegex(IndexError, 'index 5 is out of bounds for dimension 1 with size 5', lambda: x[0, 5]) + # self.assertRaisesRegex(IndexError, 'index 4 is out of bounds for dimension 0 with size 2', lambda: x[4, 5]) + # self.assertRaisesRegex(IndexError, 'index 15 is out of bounds for dimension 2 with size 10', + # lambda: x[0, 1, 15]) + # self.assertRaisesRegex(IndexError, 'index 12 is out of bounds for dimension 2 with size 10', + # lambda: x[:, :, 12]) + + # def test_zero_dim_index(self): + # x = np.array(10) + # numpy_testing_assert_equal_helper(x, x.item()) + + # def runner(): + # print(x[0]) + # return x[0] + + # self.assertRaisesRegex(IndexError, 'invalid index', runner) + + # def test_invalid_device(self): + # idx = np.array([0, 1]) + # b = torch.zeros(5) + # c = np.array([1., 2.], device="cpu") + + # for accumulate in [True, False]: + # self.assertRaises(RuntimeError, lambda: torch.index_put_(b, (idx,), c, accumulate=accumulate)) + + # def test_cpu_indices(self): + # idx = np.array([0, 1]) + # b = torch.zeros(2) + # x = torch.ones(10) + # x[idx] = b # index_put_ + # ref = torch.ones(10) + # ref[:2] = 0 + # numpy_testing_assert_equal_helper(x, ref) + # out = x[idx] # index + # numpy_testing_assert_equal_helper(out, torch.zeros(2)) + + # def test_take_along_dim(self, dtype): + # def _test_against_numpy(t, indices, dim): + # actual = torch.take_along_dim(t, indices, dim=dim) + # t_np = t.cpu().numpy() + # indices_np = indices.cpu().numpy() + # expected = np.take_along_axis(t_np, indices_np, axis=dim) + # numpy_testing_assert_equal_helper(actual, expected) + + # for shape in [(3, 2), (2, 3, 5), (2, 4, 0), (2, 3, 1, 4)]: + # for noncontiguous in [True, False]: + # t = make_tensor(shape, dtype=dtype, noncontiguous=noncontiguous) + # for dim in list(range(t.ndim)) + [None]: + # if dim is None: + # indices = torch.argsort(t.view(-1)) + # else: + # indices = torch.argsort(t, dim=dim) + + # _test_against_numpy(t, indices, dim) + + # # test broadcasting + # t = torch.ones((3, 4, 1)) + # indices = torch.ones((1, 2, 5), dtype=torch.long) + + # _test_against_numpy(t, indices, 1) + + # # test empty indices + # t = torch.ones((3, 4, 5)) + # indices = torch.ones((3, 0, 5), dtype=torch.long) + + # _test_against_numpy(t, indices, 1) + + # def test_take_along_dim_invalid(self, dtype): + # shape = (2, 3, 1, 4) + # dim = 0 + # t = make_tensor(shape, dtype=dtype) + # indices = torch.argsort(t, dim=dim) + + # # dim of `t` and `indices` does not match + # with self.assertRaisesRegex(RuntimeError, + # "input and indices should have the same number of dimensions"): + # torch.take_along_dim(t, indices[0], dim=0) + + # # invalid `indices` dtype + # with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"): + # torch.take_along_dim(t, indices.to(torch.bool), dim=0) + + # with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"): + # torch.take_along_dim(t, indices.to(torch.float), dim=0) + + # with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"): + # torch.take_along_dim(t, indices.to(torch.int32), dim=0) + + # # invalid axis + # with self.assertRaisesRegex(IndexError, "Dimension out of range"): + # torch.take_along_dim(t, indices, dim=-7) + + # with self.assertRaisesRegex(IndexError, "Dimension out of range"): + # torch.take_along_dim(t, indices, dim=7) + + # def test_gather_take_along_dim_cross_device(self, dtype): + # shape = (2, 3, 1, 4) + # dim = 0 + # t = make_tensor(shape, dtype=dtype) + # indices = torch.argsort(t, dim=dim) + + # with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): + # torch.gather(t, 0, indices.cpu()) + + # with self.assertRaisesRegex(RuntimeError, + # r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()"): + # torch.take_along_dim(t, indices.cpu(), dim=0) + + # with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): + # torch.gather(t.cpu(), 0, indices) + + # with self.assertRaisesRegex(RuntimeError, + # r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()"): + # torch.take_along_dim(t.cpu(), indices, dim=0) + + # def test_cuda_broadcast_index_use_deterministic_algorithms(self): + # with DeterministicGuard(True): + # idx1 = np.array([0]) + # idx2 = np.array([2, 6]) + # idx3 = np.array([1, 5, 7]) + + # tensor_a = torch.rand(13, 11, 12, 13, 12).cpu() + # tensor_b = tensor_a.to(device=device) + # tensor_a[idx1] = 1.0 + # tensor_a[idx1, :, idx2, idx2, :] = 2.0 + # tensor_a[:, idx1, idx3, :, idx3] = 3.0 + # tensor_b[idx1] = 1.0 + # tensor_b[idx1, :, idx2, idx2, :] = 2.0 + # tensor_b[:, idx1, idx3, :, idx3] = 3.0 + # numpy_testing_assert_equal_helper(tensor_a, tensor_b.cpu()) + + # tensor_a = torch.rand(10, 11).cpu() + # tensor_b = tensor_a.to(device=device) + # tensor_a[idx3] = 1.0 + # tensor_a[idx2, :] = 2.0 + # tensor_a[:, idx2] = 3.0 + # tensor_a[:, idx1] = 4.0 + # tensor_b[idx3] = 1.0 + # tensor_b[idx2, :] = 2.0 + # tensor_b[:, idx2] = 3.0 + # tensor_b[:, idx1] = 4.0 + # numpy_testing_assert_equal_helper(tensor_a, tensor_b.cpu()) + + # tensor_a = torch.rand(10, 10).cpu() + # tensor_b = tensor_a.to(device=device) + # tensor_a[[8]] = 1.0 + # tensor_b[[8]] = 1.0 + # numpy_testing_assert_equal_helper(tensor_a, tensor_b.cpu()) + + # tensor_a = torch.rand(10).cpu() + # tensor_b = tensor_a.to(device=device) + # tensor_a[6] = 1.0 + # tensor_b[6] = 1.0 + # numpy_testing_assert_equal_helper(tensor_a, tensor_b.cpu()) class TestNumpy(unittest.TestCase): - # def test_index_no_floats(self): - # a = Tensor([[[5.]]]) + # def test_index_no_floats(self): + # a = Tensor([[[5.]]]) - # self.assertRaises(IndexError, lambda: a[0.0]) - # self.assertRaises(IndexError, lambda: a[0, 0.0]) - # self.assertRaises(IndexError, lambda: a[0.0, 0]) - # self.assertRaises(IndexError, lambda: a[0.0, :]) - # self.assertRaises(IndexError, lambda: a[:, 0.0]) - # self.assertRaises(IndexError, lambda: a[:, 0.0, :]) - # self.assertRaises(IndexError, lambda: a[0.0, :, :]) - # self.assertRaises(IndexError, lambda: a[0, 0, 0.0]) - # self.assertRaises(IndexError, lambda: a[0.0, 0, 0]) - # self.assertRaises(IndexError, lambda: a[0, 0.0, 0]) - # self.assertRaises(IndexError, lambda: a[-1.4]) - # self.assertRaises(IndexError, lambda: a[0, -1.4]) - # self.assertRaises(IndexError, lambda: a[-1.4, 0]) - # self.assertRaises(IndexError, lambda: a[-1.4, :]) - # self.assertRaises(IndexError, lambda: a[:, -1.4]) - # self.assertRaises(IndexError, lambda: a[:, -1.4, :]) - # self.assertRaises(IndexError, lambda: a[-1.4, :, :]) - # self.assertRaises(IndexError, lambda: a[0, 0, -1.4]) - # self.assertRaises(IndexError, lambda: a[-1.4, 0, 0]) - # self.assertRaises(IndexError, lambda: a[0, -1.4, 0]) - # # self.assertRaises(IndexError, lambda: a[0.0:, 0.0]) - # # self.assertRaises(IndexError, lambda: a[0.0:, 0.0,:]) + # self.assertRaises(IndexError, lambda: a[0.0]) + # self.assertRaises(IndexError, lambda: a[0, 0.0]) + # self.assertRaises(IndexError, lambda: a[0.0, 0]) + # self.assertRaises(IndexError, lambda: a[0.0, :]) + # self.assertRaises(IndexError, lambda: a[:, 0.0]) + # self.assertRaises(IndexError, lambda: a[:, 0.0, :]) + # self.assertRaises(IndexError, lambda: a[0.0, :, :]) + # self.assertRaises(IndexError, lambda: a[0, 0, 0.0]) + # self.assertRaises(IndexError, lambda: a[0.0, 0, 0]) + # self.assertRaises(IndexError, lambda: a[0, 0.0, 0]) + # self.assertRaises(IndexError, lambda: a[-1.4]) + # self.assertRaises(IndexError, lambda: a[0, -1.4]) + # self.assertRaises(IndexError, lambda: a[-1.4, 0]) + # self.assertRaises(IndexError, lambda: a[-1.4, :]) + # self.assertRaises(IndexError, lambda: a[:, -1.4]) + # self.assertRaises(IndexError, lambda: a[:, -1.4, :]) + # self.assertRaises(IndexError, lambda: a[-1.4, :, :]) + # self.assertRaises(IndexError, lambda: a[0, 0, -1.4]) + # self.assertRaises(IndexError, lambda: a[-1.4, 0, 0]) + # self.assertRaises(IndexError, lambda: a[0, -1.4, 0]) + # # self.assertRaises(IndexError, lambda: a[0.0:, 0.0]) + # # self.assertRaises(IndexError, lambda: a[0.0:, 0.0,:]) - def test_none_index(self): - # `None` index adds newaxis - a = Tensor([1, 2, 3]) - numpy_testing_assert_equal_helper(a[None].ndim, a.ndim+1) + def test_none_index(self): + # `None` index adds newaxis + a = Tensor([1, 2, 3]) + numpy_testing_assert_equal_helper(a[None].ndim, a.ndim + 1) - def test_empty_tuple_index(self): - # Empty tuple index creates a view - a = Tensor([1, 2, 3]) - numpy_testing_assert_equal_helper(a[()], a) - # # TODO: what's our equivalent test? just is? - # numpy_testing_assert_equal_helper(a[()].data_ptr(), a.data_ptr()) + def test_empty_tuple_index(self): + # Empty tuple index creates a view + a = Tensor([1, 2, 3]) + numpy_testing_assert_equal_helper(a[()], a) + # # TODO: what's our equivalent test? just is? + # numpy_testing_assert_equal_helper(a[()].data_ptr(), a.data_ptr()) - # def test_empty_fancy_index(self): - # # Empty list index creates an empty array - # a = Tensor([1, 2, 3]) - # numpy_testing_assert_equal_helper(a[[]], np.array([])) + # def test_empty_fancy_index(self): + # # Empty list index creates an empty array + # a = Tensor([1, 2, 3]) + # numpy_testing_assert_equal_helper(a[[]], np.array([])) - # b = Tensor([]).long() - # numpy_testing_assert_equal_helper(a[[]], np.array([])) + # b = Tensor([]).long() + # numpy_testing_assert_equal_helper(a[[]], np.array([])) - # b = Tensor([]).float() - # self.assertRaises(IndexError, lambda: a[b]) + # b = Tensor([]).float() + # self.assertRaises(IndexError, lambda: a[b]) -# def test_ellipsis_index(self): -# a = tensor([[1, 2, 3], -# [4, 5, 6], -# [7, 8, 9]]) -# self.assertIsNot(a[...], a) -# numpy_testing_assert_equal_helper(a[...], a) -# # `a[...]` was `a` in numpy <1.9. -# numpy_testing_assert_equal_helper(a[...].data_ptr(), a.data_ptr()) + # def test_ellipsis_index(self): + # a = tensor([[1, 2, 3], + # [4, 5, 6], + # [7, 8, 9]]) + # self.assertIsNot(a[...], a) + # numpy_testing_assert_equal_helper(a[...], a) + # # `a[...]` was `a` in numpy <1.9. + # numpy_testing_assert_equal_helper(a[...].data_ptr(), a.data_ptr()) -# # Slicing with ellipsis can skip an -# # arbitrary number of dimensions -# numpy_testing_assert_equal_helper(a[0, ...], a[0]) -# numpy_testing_assert_equal_helper(a[0, ...], a[0, :]) -# numpy_testing_assert_equal_helper(a[..., 0], a[:, 0]) + # # Slicing with ellipsis can skip an + # # arbitrary number of dimensions + # numpy_testing_assert_equal_helper(a[0, ...], a[0]) + # numpy_testing_assert_equal_helper(a[0, ...], a[0, :]) + # numpy_testing_assert_equal_helper(a[..., 0], a[:, 0]) -# # In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch -# # we don't have separate 0-dim arrays and scalars. -# numpy_testing_assert_equal_helper(a[0, ..., 1], np.array(2)) + # # In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch + # # we don't have separate 0-dim arrays and scalars. + # numpy_testing_assert_equal_helper(a[0, ..., 1], np.array(2)) -# # Assignment with `(Ellipsis,)` on 0-d arrays -# b = np.array(1) -# b[(Ellipsis,)] = 2 -# numpy_testing_assert_equal_helper(b, 2) + # # Assignment with `(Ellipsis,)` on 0-d arrays + # b = np.array(1) + # b[(Ellipsis,)] = 2 + # numpy_testing_assert_equal_helper(b, 2) - def test_single_int_index(self): - # Single integer index selects one row - a = Tensor([[1, 2, 3], - [4, 5, 6], - [7, 8, 9]]) + def test_single_int_index(self): + # Single integer index selects one row + a = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - numpy_testing_assert_equal_helper(a[0], [1, 2, 3]) - numpy_testing_assert_equal_helper(a[-1], [7, 8, 9]) + numpy_testing_assert_equal_helper(a[0], [1, 2, 3]) + numpy_testing_assert_equal_helper(a[-1], [7, 8, 9]) - self.assertRaises(IndexError, a.__getitem__, 1 << 30) - self.assertRaises(IndexError, a.__getitem__, 1 << 64) + self.assertRaises(IndexError, a.__getitem__, 1 << 30) + self.assertRaises(IndexError, a.__getitem__, 1 << 64) - # def test_single_bool_index(self): - # # Single boolean index - # a = Tensor([[1, 2, 3], - # [4, 5, 6], - # [7, 8, 9]]) + # def test_single_bool_index(self): + # # Single boolean index + # a = Tensor([[1, 2, 3], + # [4, 5, 6], + # [7, 8, 9]]) + + # numpy_testing_assert_equal_helper(a[True], a[None]) + # numpy_testing_assert_equal_helper(a[False], a[None][0:0]) - # numpy_testing_assert_equal_helper(a[True], a[None]) - # numpy_testing_assert_equal_helper(a[False], a[None][0:0]) # def test_boolean_shape_mismatch(self): # arr = torch.ones((5, 4, 3)) @@ -1551,5 +1609,5 @@ class TestNumpy(unittest.TestCase): # numpy_testing_assert_equal_helper(kernel, kernel2) -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/test/models/test_bert.py b/test/models/test_bert.py index 0e42ff9dd..6cce535c0 100644 --- a/test/models/test_bert.py +++ b/test/models/test_bert.py @@ -4,53 +4,74 @@ import numpy as np from tinygrad.tensor import Tensor import torch + def get_question_samp(bsz, seq_len, vocab_size, seed): - np.random.seed(seed) - in_ids= np.random.randint(vocab_size, size=(bsz, seq_len)) - mask = np.random.choice([True, False], size=(bsz, seq_len)) - seg_ids = np.random.randint(1, size=(bsz, seq_len)) - return in_ids, mask, seg_ids + np.random.seed(seed) + in_ids = np.random.randint(vocab_size, size=(bsz, seq_len)) + mask = np.random.choice([True, False], size=(bsz, seq_len)) + seg_ids = np.random.randint(1, size=(bsz, seq_len)) + return in_ids, mask, seg_ids + def set_equal_weights(mdl, torch_mdl): - from tinygrad.nn.state import get_state_dict - state, torch_state = get_state_dict(mdl), torch_mdl.state_dict() - assert len(state) == len(torch_state) - for k, v in state.items(): - assert k in torch_state - torch_state[k].copy_(torch.from_numpy(v.numpy())) - torch_mdl.eval() + from tinygrad.nn.state import get_state_dict + + state, torch_state = get_state_dict(mdl), torch_mdl.state_dict() + assert len(state) == len(torch_state) + for k, v in state.items(): + assert k in torch_state + torch_state[k].copy_(torch.from_numpy(v.numpy())) + torch_mdl.eval() + class TestBert(unittest.TestCase): - def test_questions(self): - from extra.models.bert import BertForQuestionAnswering - from transformers import BertForQuestionAnswering as TorchBertForQuestionAnswering - from transformers import BertConfig + def test_questions(self): + from extra.models.bert import BertForQuestionAnswering + from transformers import ( + BertForQuestionAnswering as TorchBertForQuestionAnswering, + ) + from transformers import BertConfig - # small - config = { - 'vocab_size':24, 'hidden_size':2, 'num_hidden_layers':2, 'num_attention_heads':2, - 'intermediate_size':32, 'hidden_dropout_prob':0.1, 'attention_probs_dropout_prob':0.1, - 'max_position_embeddings':512, 'type_vocab_size':2 - } + # small + config = { + "vocab_size": 24, + "hidden_size": 2, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "intermediate_size": 32, + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 2, + } - # Create in tinygrad - Tensor.manual_seed(1337) - mdl = BertForQuestionAnswering(**config) + # Create in tinygrad + Tensor.manual_seed(1337) + mdl = BertForQuestionAnswering(**config) - # Create in torch - with torch.no_grad(): - torch_mdl = TorchBertForQuestionAnswering(BertConfig(**config)) + # Create in torch + with torch.no_grad(): + torch_mdl = TorchBertForQuestionAnswering(BertConfig(**config)) - set_equal_weights(mdl, torch_mdl) + set_equal_weights(mdl, torch_mdl) - seeds = (1337, 3141) - bsz, seq_len = 1, 16 - for _, seed in enumerate(seeds): - in_ids, mask, seg_ids = get_question_samp(bsz, seq_len, config['vocab_size'], seed) - out = mdl(Tensor(in_ids), Tensor(mask), Tensor(seg_ids)) - torch_out = torch_mdl.forward(torch.from_numpy(in_ids).long(), torch.from_numpy(mask), torch.from_numpy(seg_ids).long())[:2] - torch_out = torch.cat(torch_out).unsqueeze(2) - np.testing.assert_allclose(out.numpy(), torch_out.detach().numpy(), atol=5e-4, rtol=5e-4) + seeds = (1337, 3141) + bsz, seq_len = 1, 16 + for _, seed in enumerate(seeds): + in_ids, mask, seg_ids = get_question_samp( + bsz, seq_len, config["vocab_size"], seed + ) + out = mdl(Tensor(in_ids), Tensor(mask), Tensor(seg_ids)) + torch_out = torch_mdl.forward( + torch.from_numpy(in_ids).long(), + torch.from_numpy(mask), + torch.from_numpy(seg_ids).long(), + )[:2] + torch_out = torch.cat(torch_out).unsqueeze(2) + np.testing.assert_allclose( + out.numpy(), torch_out.detach().numpy(), atol=5e-4, rtol=5e-4 + ) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/models/test_efficientnet.py b/test/models/test_efficientnet.py index 048791075..c196d72b6 100644 --- a/test/models/test_efficientnet.py +++ b/test/models/test_efficientnet.py @@ -11,104 +11,117 @@ from extra.models.efficientnet import EfficientNet from extra.models.vit import ViT from extra.models.resnet import ResNet50 + def _load_labels(): - labels_filename = pathlib.Path(__file__).parent / 'efficientnet/imagenet1000_clsidx_to_labels.txt' - return ast.literal_eval(labels_filename.read_text()) + labels_filename = ( + pathlib.Path(__file__).parent / "efficientnet/imagenet1000_clsidx_to_labels.txt" + ) + return ast.literal_eval(labels_filename.read_text()) + _LABELS = _load_labels() + def preprocess(img, new=False): - # preprocess image - aspect_ratio = img.size[0] / img.size[1] - img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0)))) + # preprocess image + aspect_ratio = img.size[0] / img.size[1] + img = img.resize( + (int(224 * max(aspect_ratio, 1.0)), int(224 * max(1.0 / aspect_ratio, 1.0))) + ) - img = np.array(img) - y0, x0 =(np.asarray(img.shape)[:2] - 224) // 2 - img = img[y0: y0 + 224, x0: x0 + 224] + img = np.array(img) + y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2 + img = img[y0 : y0 + 224, x0 : x0 + 224] - # low level preprocess - if new: - img = img.astype(np.float32) - img -= [127.0, 127.0, 127.0] - img /= [128.0, 128.0, 128.0] - img = img[None] - else: - img = np.moveaxis(img, [2, 0, 1], [0, 1, 2]) - img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224) - img /= 255.0 - img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1)) - img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1)) - return img + # low level preprocess + if new: + img = img.astype(np.float32) + img -= [127.0, 127.0, 127.0] + img /= [128.0, 128.0, 128.0] + img = img[None] + else: + img = np.moveaxis(img, [2, 0, 1], [0, 1, 2]) + img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224) + img /= 255.0 + img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1)) + img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1)) + return img def _infer(model: EfficientNet, img, bs=1): - Tensor.training = False - img = preprocess(img) - # run the net - if bs > 1: img = img.repeat(bs, axis=0) - out = model.forward(Tensor(img)).cpu() - return _LABELS[np.argmax(out.numpy()[0])] + Tensor.training = False + img = preprocess(img) + # run the net + if bs > 1: + img = img.repeat(bs, axis=0) + out = model.forward(Tensor(img)).cpu() + return _LABELS[np.argmax(out.numpy()[0])] + + +chicken_img = Image.open(pathlib.Path(__file__).parent / "efficientnet/Chicken.jpg") +car_img = Image.open(pathlib.Path(__file__).parent / "efficientnet/car.jpg") -chicken_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg') -car_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg') class TestEfficientNet(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = EfficientNet(number=getenv("NUM")) - cls.model.load_from_pretrained() + @classmethod + def setUpClass(cls): + cls.model = EfficientNet(number=getenv("NUM")) + cls.model.load_from_pretrained() - @classmethod - def tearDownClass(cls): - del cls.model + @classmethod + def tearDownClass(cls): + del cls.model - def test_chicken(self): - label = _infer(self.model, chicken_img) - self.assertEqual(label, "hen") + def test_chicken(self): + label = _infer(self.model, chicken_img) + self.assertEqual(label, "hen") - def test_chicken_bigbatch(self): - label = _infer(self.model, chicken_img, 2) - self.assertEqual(label, "hen") + def test_chicken_bigbatch(self): + label = _infer(self.model, chicken_img, 2) + self.assertEqual(label, "hen") + + def test_car(self): + label = _infer(self.model, car_img) + self.assertEqual(label, "sports car, sport car") - def test_car(self): - label = _infer(self.model, car_img) - self.assertEqual(label, "sports car, sport car") class TestViT(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = ViT() - cls.model.load_from_pretrained() + @classmethod + def setUpClass(cls): + cls.model = ViT() + cls.model.load_from_pretrained() - @classmethod - def tearDownClass(cls): - del cls.model + @classmethod + def tearDownClass(cls): + del cls.model - def test_chicken(self): - label = _infer(self.model, chicken_img) - self.assertEqual(label, "cock") + def test_chicken(self): + label = _infer(self.model, chicken_img) + self.assertEqual(label, "cock") + + def test_car(self): + label = _infer(self.model, car_img) + self.assertEqual(label, "racer, race car, racing car") - def test_car(self): - label = _infer(self.model, car_img) - self.assertEqual(label, "racer, race car, racing car") class TestResNet(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = ResNet50() - cls.model.load_from_pretrained() + @classmethod + def setUpClass(cls): + cls.model = ResNet50() + cls.model.load_from_pretrained() - @classmethod - def tearDownClass(cls): - del cls.model + @classmethod + def tearDownClass(cls): + del cls.model - def test_chicken(self): - label = _infer(self.model, chicken_img) - self.assertEqual(label, "hen") + def test_chicken(self): + label = _infer(self.model, chicken_img) + self.assertEqual(label, "hen") - def test_car(self): - label = _infer(self.model, car_img) - self.assertEqual(label, "sports car, sport car") + def test_car(self): + label = _infer(self.model, car_img) + self.assertEqual(label, "sports car, sport car") -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/models/test_end2end.py b/test/models/test_end2end.py index 53988917d..cf271d989 100644 --- a/test/models/test_end2end.py +++ b/test/models/test_end2end.py @@ -8,158 +8,218 @@ from tinygrad.tensor import Tensor from extra.datasets import fetch_mnist from tinygrad.helpers import CI + def compare_tiny_torch(model, model_torch, X, Y): - with Tensor.train(): - model_torch.train() - model_state_dict = get_state_dict(model) - for k,v in model_torch.named_parameters(): - if not CI: print(f"initting {k} from torch") - model_state_dict[k].assign(Tensor(v.detach().numpy())).realize() + with Tensor.train(): + model_torch.train() + model_state_dict = get_state_dict(model) + for k, v in model_torch.named_parameters(): + if not CI: + print(f"initting {k} from torch") + model_state_dict[k].assign(Tensor(v.detach().numpy())).realize() - optimizer = optim.SGD(get_parameters(model), lr=0.001) - optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.001) + optimizer = optim.SGD(get_parameters(model), lr=0.001) + optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.001) - Xt = torch.Tensor(X.numpy()) - np.testing.assert_allclose(X.numpy(), Xt.detach().numpy()) + Xt = torch.Tensor(X.numpy()) + np.testing.assert_allclose(X.numpy(), Xt.detach().numpy()) - out = model(X) - loss = (out * Y).mean() - if not CI: print(loss.realize().numpy()) + out = model(X) + loss = (out * Y).mean() + if not CI: + print(loss.realize().numpy()) - out_torch = model_torch(torch.Tensor(X.numpy())) - loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean() - if not CI: print(loss_torch.detach().numpy()) + out_torch = model_torch(torch.Tensor(X.numpy())) + loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean() + if not CI: + print(loss_torch.detach().numpy()) - # assert losses match - np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4) + # assert losses match + np.testing.assert_allclose( + loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4 + ) - # zero and backward - optimizer.zero_grad() - loss.backward() - optimizer_torch.zero_grad() - loss_torch.backward() + # zero and backward + optimizer.zero_grad() + loss.backward() + optimizer_torch.zero_grad() + loss_torch.backward() - for k,v in list(model_torch.named_parameters())[::-1]: - g = model_state_dict[k].grad.numpy() - gt = v.grad.detach().numpy() - if not CI: print("testing grads", k) - np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}') + for k, v in list(model_torch.named_parameters())[::-1]: + g = model_state_dict[k].grad.numpy() + gt = v.grad.detach().numpy() + if not CI: + print("testing grads", k) + np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f"grad mismatch {k}") - # take the steps - optimizer.step() - optimizer_torch.step() + # take the steps + optimizer.step() + optimizer_torch.step() + + # assert weights match (they don't!) + for k, v in model_torch.named_parameters(): + if not CI: + print("testing weight", k) + np.testing.assert_allclose( + model_state_dict[k].numpy(), + v.detach().numpy(), + atol=1e-3, + err_msg=f"weight mismatch {k}", + ) - # assert weights match (they don't!) - for k,v in model_torch.named_parameters(): - if not CI: print("testing weight", k) - np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}') def get_mnist_data(): - _X_train, _Y_train, X_test, Y_test = fetch_mnist() - BS = 32 - num_classes = 10 - X = Tensor(X_test[0:BS].astype(np.float32)) - Y = np.zeros((BS, num_classes), np.float32) - Y[range(BS),Y_test[0:BS]] = -1.0*num_classes - return X, Tensor(Y) + _X_train, _Y_train, X_test, Y_test = fetch_mnist() + BS = 32 + num_classes = 10 + X = Tensor(X_test[0:BS].astype(np.float32)) + Y = np.zeros((BS, num_classes), np.float32) + Y[range(BS), Y_test[0:BS]] = -1.0 * num_classes + return X, Tensor(Y) + class TestEnd2End(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.X, cls.Y = get_mnist_data() + @classmethod + def setUpClass(cls): + cls.X, cls.Y = get_mnist_data() - def setUp(self): - torch.manual_seed(123) + def setUp(self): + torch.manual_seed(123) - def test_linear_mnist(self): - class LinTiny: - def __init__(self, has_batchnorm=False): - self.l1 = Linear(784, 128) - self.l2 = Linear(128, 10) - self.bn1 = BatchNorm2d(128) if has_batchnorm else lambda x: x - def __call__(self, x): - return self.l2(self.l1(x)).relu().log_softmax(-1) - class LinTorch(nn.Module): - def __init__(self, has_batchnorm=False): - super().__init__() - self.l1 = nn.Linear(784, 128) - self.l2 = nn.Linear(128, 10) - def forward(self, x): - return self.l2(self.l1(x)).relu().log_softmax(-1) - compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y) + def test_linear_mnist(self): + class LinTiny: + def __init__(self, has_batchnorm=False): + self.l1 = Linear(784, 128) + self.l2 = Linear(128, 10) + self.bn1 = BatchNorm2d(128) if has_batchnorm else lambda x: x - def test_bn_mnist(self): - class LinTiny: - def __init__(self): - self.l1 = Linear(784, 128) - self.l2 = Linear(128, 10) - self.bn1 = BatchNorm2d(128) - def __call__(self, x): - return self.l2(self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)).reshape(x.shape[0], -1).relu()).log_softmax(-1) - class LinTorch(nn.Module): - def __init__(self): - super().__init__() - self.l1 = nn.Linear(784, 128) - self.l2 = nn.Linear(128, 10) - self.bn1 = nn.BatchNorm2d(128) - def forward(self, x): - return self.l2(self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)).reshape(x.shape[0], -1).relu()).log_softmax(-1) - compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y) + def __call__(self, x): + return self.l2(self.l1(x)).relu().log_softmax(-1) - def test_bn_alone(self): - np.random.seed(1337) - X = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32)) - Y = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32)) - compare_tiny_torch(BatchNorm2d(10), nn.BatchNorm2d(10), X, Y) + class LinTorch(nn.Module): + def __init__(self, has_batchnorm=False): + super().__init__() + self.l1 = nn.Linear(784, 128) + self.l2 = nn.Linear(128, 10) - def test_bn_linear(self): - BS, K = 2, 1 - eps = 0 - X = Tensor([1,0]).reshape(BS, K, 1, 1) - Y = Tensor([-1,0]).reshape(BS, K, 1, 1) - class LinTiny: - def __init__(self): - self.l1 = Conv2d(K, K, 1, bias=False) - self.bn1 = BatchNorm2d(K, affine=False, track_running_stats=False, eps=eps) - def __call__(self, x): return self.bn1(self.l1(x)) - class LinTorch(nn.Module): - def __init__(self): - super().__init__() - self.l1 = nn.Conv2d(K, K, 1, bias=False) - self.bn1 = nn.BatchNorm2d(K, affine=False, track_running_stats=False, eps=eps) - def forward(self, x): return self.bn1(self.l1(x)) - model_torch = LinTorch() - with torch.no_grad(): - model_torch.l1.weight[:] = 1. - compare_tiny_torch(LinTiny(), model_torch, X, Y) + def forward(self, x): + return self.l2(self.l1(x)).relu().log_softmax(-1) + + compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y) + + def test_bn_mnist(self): + class LinTiny: + def __init__(self): + self.l1 = Linear(784, 128) + self.l2 = Linear(128, 10) + self.bn1 = BatchNorm2d(128) + + def __call__(self, x): + return self.l2( + self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)) + .reshape(x.shape[0], -1) + .relu() + ).log_softmax(-1) + + class LinTorch(nn.Module): + def __init__(self): + super().__init__() + self.l1 = nn.Linear(784, 128) + self.l2 = nn.Linear(128, 10) + self.bn1 = nn.BatchNorm2d(128) + + def forward(self, x): + return self.l2( + self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)) + .reshape(x.shape[0], -1) + .relu() + ).log_softmax(-1) + + compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y) + + def test_bn_alone(self): + np.random.seed(1337) + X = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32)) + Y = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32)) + compare_tiny_torch(BatchNorm2d(10), nn.BatchNorm2d(10), X, Y) + + def test_bn_linear(self): + BS, K = 2, 1 + eps = 0 + X = Tensor([1, 0]).reshape(BS, K, 1, 1) + Y = Tensor([-1, 0]).reshape(BS, K, 1, 1) + + class LinTiny: + def __init__(self): + self.l1 = Conv2d(K, K, 1, bias=False) + self.bn1 = BatchNorm2d( + K, affine=False, track_running_stats=False, eps=eps + ) + + def __call__(self, x): + return self.bn1(self.l1(x)) + + class LinTorch(nn.Module): + def __init__(self): + super().__init__() + self.l1 = nn.Conv2d(K, K, 1, bias=False) + self.bn1 = nn.BatchNorm2d( + K, affine=False, track_running_stats=False, eps=eps + ) + + def forward(self, x): + return self.bn1(self.l1(x)) + + model_torch = LinTorch() + with torch.no_grad(): + model_torch.l1.weight[:] = 1.0 + compare_tiny_torch(LinTiny(), model_torch, X, Y) + + def test_conv_mnist(self): + class LinTiny: + def __init__(self, has_batchnorm=False): + self.c1 = Conv2d(1, 8, 3, stride=2) + self.c2 = Conv2d(8, 16, 3, stride=2) + self.l1 = Linear(16 * 6 * 6, 10) + if has_batchnorm: + self.bn1, self.bn2 = BatchNorm2d(8), BatchNorm2d(16) + else: + self.bn1, self.bn2 = lambda x: x, lambda x: x + + def __call__(self, x): + return self.l1( + self.bn2(self.c2(self.bn1(self.c1(x)).relu())) + .relu() + .reshape(x.shape[0], -1) + ).log_softmax(-1) + + class LinTorch(nn.Module): + def __init__(self, has_batchnorm=False): + super().__init__() + self.c1 = nn.Conv2d(1, 8, 3, stride=2) + self.c2 = nn.Conv2d(8, 16, 3, stride=2) + self.l1 = nn.Linear(16 * 6 * 6, 10) + if has_batchnorm: + self.bn1, self.bn2 = nn.BatchNorm2d(8), nn.BatchNorm2d(16) + else: + self.bn1, self.bn2 = lambda x: x, lambda x: x + + def forward(self, x): + return self.l1( + self.bn2(self.c2(self.bn1(self.c1(x)).relu())) + .relu() + .reshape(x.shape[0], -1) + ).log_softmax(-1) + + for has_batchnorm in [False, True]: + with self.subTest(has_batchnorm=has_batchnorm): + compare_tiny_torch( + LinTiny(has_batchnorm), + LinTorch(has_batchnorm), + self.X.reshape((-1, 1, 28, 28)), + self.Y, + ) - def test_conv_mnist(self): - class LinTiny: - def __init__(self, has_batchnorm=False): - self.c1 = Conv2d(1, 8, 3, stride=2) - self.c2 = Conv2d(8, 16, 3, stride=2) - self.l1 = Linear(16*6*6, 10) - if has_batchnorm: - self.bn1, self.bn2 = BatchNorm2d(8), BatchNorm2d(16) - else: - self.bn1, self.bn2 = lambda x: x, lambda x: x - def __call__(self, x): - return self.l1(self.bn2(self.c2(self.bn1(self.c1(x)).relu())).relu().reshape(x.shape[0], -1)).log_softmax(-1) - class LinTorch(nn.Module): - def __init__(self, has_batchnorm=False): - super().__init__() - self.c1 = nn.Conv2d(1, 8, 3, stride=2) - self.c2 = nn.Conv2d(8, 16, 3, stride=2) - self.l1 = nn.Linear(16*6*6, 10) - if has_batchnorm: - self.bn1, self.bn2 = nn.BatchNorm2d(8), nn.BatchNorm2d(16) - else: - self.bn1, self.bn2 = lambda x: x, lambda x: x - def forward(self, x): - return self.l1(self.bn2(self.c2(self.bn1(self.c1(x)).relu())).relu().reshape(x.shape[0], -1)).log_softmax(-1) - for has_batchnorm in [False, True]: - with self.subTest(has_batchnorm=has_batchnorm): - compare_tiny_torch(LinTiny(has_batchnorm), LinTorch(has_batchnorm), self.X.reshape((-1, 1, 28, 28)), self.Y) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/test/models/test_mnist.py b/test/models/test_mnist.py index d1702bd99..431cb8df8 100644 --- a/test/models/test_mnist.py +++ b/test/models/test_mnist.py @@ -13,104 +13,111 @@ pytestmark = [pytest.mark.exclude_gpu, pytest.mark.exclude_clang] # load the mnist dataset X_train, Y_train, X_test, Y_test = fetch_mnist() + # create a model class TinyBobNet: - def __init__(self): - self.l1 = Tensor.scaled_uniform(784, 128) - self.l2 = Tensor.scaled_uniform(128, 10) + def __init__(self): + self.l1 = Tensor.scaled_uniform(784, 128) + self.l2 = Tensor.scaled_uniform(128, 10) - def parameters(self): - return get_parameters(self) + def parameters(self): + return get_parameters(self) + + def forward(self, x): + return x.dot(self.l1).relu().dot(self.l2) - def forward(self, x): - return x.dot(self.l1).relu().dot(self.l2) # create a model with a conv layer class TinyConvNet: - def __init__(self, has_batchnorm=False): - # https://keras.io/examples/vision/mnist_convnet/ - conv = 3 - #inter_chan, out_chan = 32, 64 - inter_chan, out_chan = 8, 16 # for speed - self.c1 = Tensor.scaled_uniform(inter_chan,1,conv,conv) - self.c2 = Tensor.scaled_uniform(out_chan,inter_chan,conv,conv) - self.l1 = Tensor.scaled_uniform(out_chan*5*5, 10) - if has_batchnorm: - self.bn1 = BatchNorm2d(inter_chan) - self.bn2 = BatchNorm2d(out_chan) - else: - self.bn1, self.bn2 = lambda x: x, lambda x: x + def __init__(self, has_batchnorm=False): + # https://keras.io/examples/vision/mnist_convnet/ + conv = 3 + # inter_chan, out_chan = 32, 64 + inter_chan, out_chan = 8, 16 # for speed + self.c1 = Tensor.scaled_uniform(inter_chan, 1, conv, conv) + self.c2 = Tensor.scaled_uniform(out_chan, inter_chan, conv, conv) + self.l1 = Tensor.scaled_uniform(out_chan * 5 * 5, 10) + if has_batchnorm: + self.bn1 = BatchNorm2d(inter_chan) + self.bn2 = BatchNorm2d(out_chan) + else: + self.bn1, self.bn2 = lambda x: x, lambda x: x - def parameters(self): - return get_parameters(self) + def parameters(self): + return get_parameters(self) + + def forward(self, x: Tensor): + x = x.reshape(shape=(-1, 1, 28, 28)) # hacks + x = self.bn1(x.conv2d(self.c1)).relu().max_pool2d() + x = self.bn2(x.conv2d(self.c2)).relu().max_pool2d() + x = x.reshape(shape=[x.shape[0], -1]) + return x.dot(self.l1) - def forward(self, x:Tensor): - x = x.reshape(shape=(-1, 1, 28, 28)) # hacks - x = self.bn1(x.conv2d(self.c1)).relu().max_pool2d() - x = self.bn2(x.conv2d(self.c2)).relu().max_pool2d() - x = x.reshape(shape=[x.shape[0], -1]) - return x.dot(self.l1) class TestMNIST(unittest.TestCase): - def test_sgd_onestep(self): - np.random.seed(1337) - model = TinyBobNet() - optimizer = optim.SGD(model.parameters(), lr=0.001) - train(model, X_train, Y_train, optimizer, BS=69, steps=1) - for p in model.parameters(): p.realize() + def test_sgd_onestep(self): + np.random.seed(1337) + model = TinyBobNet() + optimizer = optim.SGD(model.parameters(), lr=0.001) + train(model, X_train, Y_train, optimizer, BS=69, steps=1) + for p in model.parameters(): + p.realize() - def test_sgd_threestep(self): - np.random.seed(1337) - model = TinyBobNet() - optimizer = optim.SGD(model.parameters(), lr=0.001) - train(model, X_train, Y_train, optimizer, BS=69, steps=3) + def test_sgd_threestep(self): + np.random.seed(1337) + model = TinyBobNet() + optimizer = optim.SGD(model.parameters(), lr=0.001) + train(model, X_train, Y_train, optimizer, BS=69, steps=3) - def test_sgd_sixstep(self): - np.random.seed(1337) - model = TinyBobNet() - optimizer = optim.SGD(model.parameters(), lr=0.001) - train(model, X_train, Y_train, optimizer, BS=69, steps=6, noloss=True) + def test_sgd_sixstep(self): + np.random.seed(1337) + model = TinyBobNet() + optimizer = optim.SGD(model.parameters(), lr=0.001) + train(model, X_train, Y_train, optimizer, BS=69, steps=6, noloss=True) - def test_adam_onestep(self): - np.random.seed(1337) - model = TinyBobNet() - optimizer = optim.Adam(model.parameters(), lr=0.001) - train(model, X_train, Y_train, optimizer, BS=69, steps=1) - for p in model.parameters(): p.realize() + def test_adam_onestep(self): + np.random.seed(1337) + model = TinyBobNet() + optimizer = optim.Adam(model.parameters(), lr=0.001) + train(model, X_train, Y_train, optimizer, BS=69, steps=1) + for p in model.parameters(): + p.realize() - def test_adam_threestep(self): - np.random.seed(1337) - model = TinyBobNet() - optimizer = optim.Adam(model.parameters(), lr=0.001) - train(model, X_train, Y_train, optimizer, BS=69, steps=3) + def test_adam_threestep(self): + np.random.seed(1337) + model = TinyBobNet() + optimizer = optim.Adam(model.parameters(), lr=0.001) + train(model, X_train, Y_train, optimizer, BS=69, steps=3) - def test_conv_onestep(self): - np.random.seed(1337) - model = TinyConvNet() - optimizer = optim.SGD(model.parameters(), lr=0.001) - train(model, X_train, Y_train, optimizer, BS=69, steps=1, noloss=True) - for p in model.parameters(): p.realize() + def test_conv_onestep(self): + np.random.seed(1337) + model = TinyConvNet() + optimizer = optim.SGD(model.parameters(), lr=0.001) + train(model, X_train, Y_train, optimizer, BS=69, steps=1, noloss=True) + for p in model.parameters(): + p.realize() - def test_conv(self): - np.random.seed(1337) - model = TinyConvNet() - optimizer = optim.Adam(model.parameters(), lr=0.001) - train(model, X_train, Y_train, optimizer, steps=100) - assert evaluate(model, X_test, Y_test) > 0.93 # torch gets 0.9415 sometimes + def test_conv(self): + np.random.seed(1337) + model = TinyConvNet() + optimizer = optim.Adam(model.parameters(), lr=0.001) + train(model, X_train, Y_train, optimizer, steps=100) + assert evaluate(model, X_test, Y_test) > 0.93 # torch gets 0.9415 sometimes - def test_conv_with_bn(self): - np.random.seed(1337) - model = TinyConvNet(has_batchnorm=True) - optimizer = optim.AdamW(model.parameters(), lr=0.003) - train(model, X_train, Y_train, optimizer, steps=200) - assert evaluate(model, X_test, Y_test) > 0.94 + def test_conv_with_bn(self): + np.random.seed(1337) + model = TinyConvNet(has_batchnorm=True) + optimizer = optim.AdamW(model.parameters(), lr=0.003) + train(model, X_train, Y_train, optimizer, steps=200) + assert evaluate(model, X_test, Y_test) > 0.94 - def test_sgd(self): - np.random.seed(1337) - model = TinyBobNet() - optimizer = optim.SGD(model.parameters(), lr=0.001) - train(model, X_train, Y_train, optimizer, steps=600) - assert evaluate(model, X_test, Y_test) > 0.94 # CPU gets 0.9494 sometimes + def test_sgd(self): + np.random.seed(1337) + model = TinyBobNet() + optimizer = optim.SGD(model.parameters(), lr=0.001) + train(model, X_train, Y_train, optimizer, steps=600) + assert evaluate(model, X_test, Y_test) > 0.94 # CPU gets 0.9494 sometimes -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/models/test_onnx.py b/test/models/test_onnx.py index 845dc36f2..a05d9fe41 100644 --- a/test/models/test_onnx.py +++ b/test/models/test_onnx.py @@ -11,124 +11,163 @@ import pytest pytestmark = [pytest.mark.exclude_gpu, pytest.mark.exclude_clang] + def run_onnx_torch(onnx_model, inputs): - import torch - from onnx2torch import convert - torch_model = convert(onnx_model).float() - with torch.no_grad(): - torch_out = torch_model(*[torch.tensor(x) for x in inputs.values()]) - return torch_out + import torch + from onnx2torch import convert + + torch_model = convert(onnx_model).float() + with torch.no_grad(): + torch_out = torch_model(*[torch.tensor(x) for x in inputs.values()]) + return torch_out + OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx" np.random.seed(1337) + class TestOnnxModel(unittest.TestCase): - def test_benchmark_openpilot_model(self): - onnx_model = onnx.load(fetch(OPENPILOT_MODEL)) - run_onnx = get_run_onnx(onnx_model) - def get_inputs(): - np_inputs = { - "input_imgs": np.random.randn(*(1, 12, 128, 256)), - "big_input_imgs": np.random.randn(*(1, 12, 128, 256)), - "desire": np.zeros((1, 100, 8)), - "traffic_convention": np.array([[1., 0.]]), - "nav_features": np.zeros((1, 256)), - "features_buffer": np.zeros((1, 99, 128)), - } - inputs = {k:Tensor(v.astype(np.float32), requires_grad=False) for k,v in np_inputs.items()} - return inputs + def test_benchmark_openpilot_model(self): + onnx_model = onnx.load(fetch(OPENPILOT_MODEL)) + run_onnx = get_run_onnx(onnx_model) - for _ in range(7): - inputs = get_inputs() - st = time.monotonic() - tinygrad_out = run_onnx(inputs)['outputs'] - mt = time.monotonic() - tinygrad_out.realize() - mt2 = time.monotonic() - tinygrad_out = tinygrad_out.numpy() - et = time.monotonic() - if not CI: print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue") + def get_inputs(): + np_inputs = { + "input_imgs": np.random.randn(*(1, 12, 128, 256)), + "big_input_imgs": np.random.randn(*(1, 12, 128, 256)), + "desire": np.zeros((1, 100, 8)), + "traffic_convention": np.array([[1.0, 0.0]]), + "nav_features": np.zeros((1, 256)), + "features_buffer": np.zeros((1, 99, 128)), + } + inputs = { + k: Tensor(v.astype(np.float32), requires_grad=False) + for k, v in np_inputs.items() + } + return inputs - if not CI: - import cProfile - import pstats - inputs = get_inputs() - pr = cProfile.Profile(timer=time.perf_counter_ns, timeunit=1e-6) - pr.enable() - tinygrad_out = run_onnx(inputs)['outputs'] - tinygrad_out.realize() - tinygrad_out = tinygrad_out.numpy() - if not CI: - pr.disable() - stats = pstats.Stats(pr) - stats.dump_stats(temp("net.prof")) - os.system(f"flameprof {temp('net.prof')} > {temp('prof.svg')}") - ps = stats.sort_stats(pstats.SortKey.TIME) - ps.print_stats(30) + for _ in range(7): + inputs = get_inputs() + st = time.monotonic() + tinygrad_out = run_onnx(inputs)["outputs"] + mt = time.monotonic() + tinygrad_out.realize() + mt2 = time.monotonic() + tinygrad_out = tinygrad_out.numpy() + et = time.monotonic() + if not CI: + print( + f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue" + ) - def test_openpilot_model(self): - onnx_model = onnx.load(fetch(OPENPILOT_MODEL)) - run_onnx = get_run_onnx(onnx_model) - print("got run_onnx") - inputs = { - "input_imgs": np.random.randn(*(1, 12, 128, 256)), - "big_input_imgs": np.random.randn(*(1, 12, 128, 256)), - "desire": np.zeros((1, 100, 8)), - "traffic_convention": np.array([[1., 0.]]), - "nav_features": np.zeros((1, 256)), - "features_buffer": np.zeros((1, 99, 128)), - } - inputs = {k:v.astype(np.float32) for k,v in inputs.items()} + if not CI: + import cProfile + import pstats - st = time.monotonic() - print("****** run onnx ******") - tinygrad_out = run_onnx(inputs)['outputs'] - mt = time.monotonic() - print("****** realize ******") - tinygrad_out.realize() - mt2 = time.monotonic() - tinygrad_out = tinygrad_out.numpy() - et = time.monotonic() - print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue") + inputs = get_inputs() + pr = cProfile.Profile(timer=time.perf_counter_ns, timeunit=1e-6) + pr.enable() + tinygrad_out = run_onnx(inputs)["outputs"] + tinygrad_out.realize() + tinygrad_out = tinygrad_out.numpy() + if not CI: + pr.disable() + stats = pstats.Stats(pr) + stats.dump_stats(temp("net.prof")) + os.system(f"flameprof {temp('net.prof')} > {temp('prof.svg')}") + ps = stats.sort_stats(pstats.SortKey.TIME) + ps.print_stats(30) - Tensor.no_grad = True - torch_out = run_onnx_torch(onnx_model, inputs).numpy() - Tensor.no_grad = False - print(tinygrad_out, torch_out) - np.testing.assert_allclose(torch_out, tinygrad_out, atol=1e-4, rtol=1e-2) + def test_openpilot_model(self): + onnx_model = onnx.load(fetch(OPENPILOT_MODEL)) + run_onnx = get_run_onnx(onnx_model) + print("got run_onnx") + inputs = { + "input_imgs": np.random.randn(*(1, 12, 128, 256)), + "big_input_imgs": np.random.randn(*(1, 12, 128, 256)), + "desire": np.zeros((1, 100, 8)), + "traffic_convention": np.array([[1.0, 0.0]]), + "nav_features": np.zeros((1, 256)), + "features_buffer": np.zeros((1, 99, 128)), + } + inputs = {k: v.astype(np.float32) for k, v in inputs.items()} - def test_efficientnet(self): - input_name, input_new = "images:0", True - self._test_model(fetch("https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx"), input_name, input_new) + st = time.monotonic() + print("****** run onnx ******") + tinygrad_out = run_onnx(inputs)["outputs"] + mt = time.monotonic() + print("****** realize ******") + tinygrad_out.realize() + mt2 = time.monotonic() + tinygrad_out = tinygrad_out.numpy() + et = time.monotonic() + print( + f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue" + ) - def test_shufflenet(self): - input_name, input_new = "gpu_0/data_0", False - self._test_model(fetch("https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx"), input_name, input_new) + Tensor.no_grad = True + torch_out = run_onnx_torch(onnx_model, inputs).numpy() + Tensor.no_grad = False + print(tinygrad_out, torch_out) + np.testing.assert_allclose(torch_out, tinygrad_out, atol=1e-4, rtol=1e-2) - @unittest.skip("test is very slow") - def test_resnet(self): - # NOTE: many onnx models can't be run right now due to max pool with strides != kernel_size - input_name, input_new = "data", False - self._test_model(fetch("https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx"), input_name, input_new) + def test_efficientnet(self): + input_name, input_new = "images:0", True + self._test_model( + fetch( + "https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx" + ), + input_name, + input_new, + ) - def _test_model(self, fn, input_name, input_new, debug=False): - onnx_model = onnx.load(fn) - print("onnx loaded") - from test.models.test_efficientnet import chicken_img, car_img, preprocess, _LABELS - run_onnx = get_run_onnx(onnx_model) + def test_shufflenet(self): + input_name, input_new = "gpu_0/data_0", False + self._test_model( + fetch( + "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx" + ), + input_name, + input_new, + ) - def run(img): - inputs = {input_name: preprocess(img, new=input_new)} - tinygrad_out = list(run_onnx(inputs, debug=debug).values())[0].numpy() - return tinygrad_out.argmax() + @unittest.skip("test is very slow") + def test_resnet(self): + # NOTE: many onnx models can't be run right now due to max pool with strides != kernel_size + input_name, input_new = "data", False + self._test_model( + fetch( + "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx" + ), + input_name, + input_new, + ) + + def _test_model(self, fn, input_name, input_new, debug=False): + onnx_model = onnx.load(fn) + print("onnx loaded") + from test.models.test_efficientnet import ( + chicken_img, + car_img, + preprocess, + _LABELS, + ) + + run_onnx = get_run_onnx(onnx_model) + + def run(img): + inputs = {input_name: preprocess(img, new=input_new)} + tinygrad_out = list(run_onnx(inputs, debug=debug).values())[0].numpy() + return tinygrad_out.argmax() + + cls = run(chicken_img) + print(cls, _LABELS[cls]) + assert _LABELS[cls] == "hen" or _LABELS[cls] == "cock" + cls = run(car_img) + print(cls, _LABELS[cls]) + assert "car" in _LABELS[cls] or _LABELS[cls] == "convertible" - cls = run(chicken_img) - print(cls, _LABELS[cls]) - assert _LABELS[cls] == "hen" or _LABELS[cls] == "cock" - cls = run(car_img) - print(cls, _LABELS[cls]) - assert "car" in _LABELS[cls] or _LABELS[cls] == "convertible" if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index 4b3cba677..ed049acac 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -9,97 +9,193 @@ from tinygrad.helpers import CI, dtypes from tinygrad.shape.symbolic import Variable from test.helpers import derandomize_model -from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS +from examples.gpt2 import ( + Transformer as GPT2Transformer, + MODEL_PARAMS as GPT2_MODEL_PARAMS, +) from examples.hlb_cifar10 import SpeedyResNet -from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS +from examples.llama import ( + Transformer as LLaMaTransformer, + MODEL_PARAMS as LLAMA_MODEL_PARAMS, +) from examples.stable_diffusion import UNetModel -def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted=False): - tms = [] - for _ in range(4): - GlobalCounters.reset() - GlobalCounters.mem_used = 0 - Device[Device.DEFAULT].synchronize() - st = time.perf_counter_ns() - train(*gen()) - Device[Device.DEFAULT].synchronize() - tms.append(time.perf_counter_ns() - st) - # TODO: jit should expose this correctly with graph - kernels_used = len(train.jit_cache) if hasattr(train, "jit_cache") else None - print(f"{nm}: used {GlobalCounters.mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms") - assert GlobalCounters.mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB" - assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels" - if all_jitted: - assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used == 1 and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" +def helper_test( + nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted=False +): + tms = [] + for _ in range(4): + GlobalCounters.reset() + GlobalCounters.mem_used = 0 + Device[Device.DEFAULT].synchronize() + st = time.perf_counter_ns() + train(*gen()) + Device[Device.DEFAULT].synchronize() + tms.append(time.perf_counter_ns() - st) + + # TODO: jit should expose this correctly with graph + kernels_used = len(train.jit_cache) if hasattr(train, "jit_cache") else None + print( + f"{nm}: used {GlobalCounters.mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms" + ) + assert ( + GlobalCounters.mem_used / 1e9 < max_memory_allowed + ), f"{nm} used more than {max_memory_allowed:.2f} GB" + assert ( + not kernels_used or kernels_used <= max_kernels_allowed + ), f"{nm} used more than {max_kernels_allowed} kernels" + if all_jitted: + assert ( + kernels_used > 0 + and kernels_used == GlobalCounters.kernel_count + or (kernels_used == 1 and getattr(Device[Device.DEFAULT], "graph", None)) + ), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" + class TestRealWorld(unittest.TestCase): - def setUp(self): - self.old_type = Tensor.default_type - np.random.seed(2002) + def setUp(self): + self.old_type = Tensor.default_type + np.random.seed(2002) - def tearDown(self): - Tensor.default_type = self.old_type + def tearDown(self): + Tensor.default_type = self.old_type - @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") - @unittest.skipIf(CI, "too big for CI") - def test_stable_diffusion(self): - model = UNetModel() - derandomize_model(model) - @TinyJit - def test(t, t2): return model(t, 801, t2).realize() - helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 953) + @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") + @unittest.skipIf(CI, "too big for CI") + def test_stable_diffusion(self): + model = UNetModel() + derandomize_model(model) - @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") - @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp1") - def test_llama(self): - Tensor.default_type = dtypes.float16 + @TinyJit + def test(t, t2): + return model(t, 801, t2).realize() - args_tiny = {"dim": 1024, "hidden_dim": 2048, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000} - model = LLaMaTransformer(**(args_tiny if CI else LLAMA_MODEL_PARAMS["1"]["7B"]["args"])) - derandomize_model(model) - @TinyJit - def test(t): return model(t, 0).realize() - # TODO: test first token vs rest properly, also memory test is broken with CacheCollector - helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.22 if CI else 13.5, 181 if CI else 685, all_jitted=True) + helper_test( + "test_sd", + lambda: (Tensor.randn(1, 4, 64, 64), Tensor.randn(1, 77, 768)), + test, + 18.0, + 953, + ) - @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16") - def test_gpt2(self): - Tensor.default_type = dtypes.float16 + @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") + @unittest.skipIf( + Device.DEFAULT in ["LLVM", "GPU"] and CI, + "too long on CI LLVM, GPU requires cl_khr_fp1", + ) + def test_llama(self): + Tensor.default_type = dtypes.float16 - args_tiny = {"dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-5, "vocab_size": 1000} - model = GPT2Transformer(**(args_tiny if CI else GPT2_MODEL_PARAMS["gpt2-medium"])) - derandomize_model(model) - @TinyJit - def test(t, v): return model(t, v).realize() - helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.21 if CI else 0.9, 180 if CI else 516, all_jitted=True) + args_tiny = { + "dim": 1024, + "hidden_dim": 2048, + "n_heads": 8, + "n_layers": 8, + "norm_eps": 1e-05, + "vocab_size": 1000, + } + model = LLaMaTransformer( + **(args_tiny if CI else LLAMA_MODEL_PARAMS["1"]["7B"]["args"]) + ) + derandomize_model(model) - @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") - @unittest.skipIf(Device.DEFAULT in ["LLVM", "CLANG"] and CI, "too long on CI LLVM and CLANG") - def test_train_cifar(self): - # TODO: with default device - #old_default = Device.DEFAULT - #Device.DEFAULT = "FAKE" - #Device['fake'].codegen = Device[old_default].codegen + @TinyJit + def test(t): + return model(t, 0).realize() - with Tensor.train(): - model = SpeedyResNet(Tensor.ones((12,3,2,2))) - optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15) + # TODO: test first token vs rest properly, also memory test is broken with CacheCollector + helper_test( + "test_llama", + lambda: (Tensor([[1, 2, 3, 4]]),), + test, + 0.22 if CI else 13.5, + 181 if CI else 685, + all_jitted=True, + ) - BS = 32 if CI else 512 + @unittest.skipIf( + Device.DEFAULT in ["LLVM", "GPU"] and CI, + "too long on CI LLVM, GPU requires cl_khr_fp16", + ) + def test_gpt2(self): + Tensor.default_type = dtypes.float16 - @TinyJit - def train(X): - out = model(X) - loss = out.mean() - optimizer.zero_grad() - loss.backward() - optimizer.step() + args_tiny = { + "dim": 1024, + "n_heads": 8, + "n_layers": 8, + "norm_eps": 1e-5, + "vocab_size": 1000, + } + model = GPT2Transformer( + **(args_tiny if CI else GPT2_MODEL_PARAMS["gpt2-medium"]) + ) + derandomize_model(model) - helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 142 if CI else 154) # it's 154 on metal + @TinyJit + def test(t, v): + return model(t, v).realize() - # reset device - #Device.DEFAULT = old_default + helper_test( + "test_gpt2", + lambda: ( + Tensor( + [ + [ + 1, + ] + ] + ), + Variable("pos", 1, 100).bind(1), + ), + test, + 0.21 if CI else 0.9, + 180 if CI else 516, + all_jitted=True, + ) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") + @unittest.skipIf( + Device.DEFAULT in ["LLVM", "CLANG"] and CI, "too long on CI LLVM and CLANG" + ) + def test_train_cifar(self): + # TODO: with default device + # old_default = Device.DEFAULT + # Device.DEFAULT = "FAKE" + # Device['fake'].codegen = Device[old_default].codegen + + with Tensor.train(): + model = SpeedyResNet(Tensor.ones((12, 3, 2, 2))) + optimizer = optim.SGD( + get_parameters(model), + lr=0.01, + momentum=0.8, + nesterov=True, + weight_decay=0.15, + ) + + BS = 32 if CI else 512 + + @TinyJit + def train(X): + out = model(X) + loss = out.mean() + optimizer.zero_grad() + loss.backward() + optimizer.step() + + helper_test( + "train_cifar", + lambda: (Tensor.randn(BS, 3, 32, 32),), + train, + (1.0 / 48) * BS, + 142 if CI else 154, + ) # it's 154 on metal + + # reset device + # Device.DEFAULT = old_default + + +if __name__ == "__main__": + unittest.main() diff --git a/test/models/test_rnnt.py b/test/models/test_rnnt.py index f9d5e2c9d..e053fd5a2 100644 --- a/test/models/test_rnnt.py +++ b/test/models/test_rnnt.py @@ -5,43 +5,49 @@ from tinygrad.tensor import Tensor from extra.models.rnnt import LSTM import torch + class TestRNNT(unittest.TestCase): - def test_lstm(self): - BS, SQ, IS, HS, L = 2, 20, 40, 128, 2 + def test_lstm(self): + BS, SQ, IS, HS, L = 2, 20, 40, 128, 2 - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.LSTM(IS, HS, L) + # create in torch + with torch.no_grad(): + torch_layer = torch.nn.LSTM(IS, HS, L) - # create in tinygrad - layer = LSTM(IS, HS, L, 0.0) + # create in tinygrad + layer = LSTM(IS, HS, L, 0.0) - # copy weights - with torch.no_grad(): - layer.cells[0].weights_ih.assign(Tensor(torch_layer.weight_ih_l0.numpy())) - layer.cells[0].weights_hh.assign(Tensor(torch_layer.weight_hh_l0.numpy())) - layer.cells[0].bias_ih.assign(Tensor(torch_layer.bias_ih_l0.numpy())) - layer.cells[0].bias_hh.assign(Tensor(torch_layer.bias_hh_l0.numpy())) - layer.cells[1].weights_ih.assign(Tensor(torch_layer.weight_ih_l1.numpy())) - layer.cells[1].weights_hh.assign(Tensor(torch_layer.weight_hh_l1.numpy())) - layer.cells[1].bias_ih.assign(Tensor(torch_layer.bias_ih_l1.numpy())) - layer.cells[1].bias_hh.assign(Tensor(torch_layer.bias_hh_l1.numpy())) + # copy weights + with torch.no_grad(): + layer.cells[0].weights_ih.assign(Tensor(torch_layer.weight_ih_l0.numpy())) + layer.cells[0].weights_hh.assign(Tensor(torch_layer.weight_hh_l0.numpy())) + layer.cells[0].bias_ih.assign(Tensor(torch_layer.bias_ih_l0.numpy())) + layer.cells[0].bias_hh.assign(Tensor(torch_layer.bias_hh_l0.numpy())) + layer.cells[1].weights_ih.assign(Tensor(torch_layer.weight_ih_l1.numpy())) + layer.cells[1].weights_hh.assign(Tensor(torch_layer.weight_hh_l1.numpy())) + layer.cells[1].bias_ih.assign(Tensor(torch_layer.bias_ih_l1.numpy())) + layer.cells[1].bias_hh.assign(Tensor(torch_layer.bias_hh_l1.numpy())) - # test initial hidden - for _ in range(3): - x = Tensor.randn(SQ, BS, IS) - z, hc = layer(x, None) - torch_x = torch.tensor(x.numpy()) - torch_z, torch_hc = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # test initial hidden + for _ in range(3): + x = Tensor.randn(SQ, BS, IS) + z, hc = layer(x, None) + torch_x = torch.tensor(x.numpy()) + torch_z, torch_hc = torch_layer(torch_x) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3 + ) - # test passing hidden - for _ in range(3): - x = Tensor.randn(SQ, BS, IS) - z, hc = layer(x, hc) - torch_x = torch.tensor(x.numpy()) - torch_z, torch_hc = torch_layer(torch_x, torch_hc) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # test passing hidden + for _ in range(3): + x = Tensor.randn(SQ, BS, IS) + z, hc = layer(x, hc) + torch_x = torch.tensor(x.numpy()) + torch_z, torch_hc = torch_layer(torch_x, torch_hc) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3 + ) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/models/test_train.py b/test/models/test_train.py index 7f4b9161f..4cf7f63ec 100644 --- a/test/models/test_train.py +++ b/test/models/test_train.py @@ -17,67 +17,76 @@ pytestmark = [pytest.mark.exclude_gpu, pytest.mark.exclude_clang] BS = getenv("BS", 2) -def train_one_step(model,X,Y): - params = get_parameters(model) - pcount = 0 - for p in params: - pcount += np.prod(p.shape) - optimizer = optim.SGD(params, lr=0.001) - print("stepping %r with %.1fM params bs %d" % (type(model), pcount/1e6, BS)) - st = time.time() - train(model, X, Y, optimizer, steps=1, BS=BS) - et = time.time()-st - print("done in %.2f ms" % (et*1000.)) + +def train_one_step(model, X, Y): + params = get_parameters(model) + pcount = 0 + for p in params: + pcount += np.prod(p.shape) + optimizer = optim.SGD(params, lr=0.001) + print("stepping %r with %.1fM params bs %d" % (type(model), pcount / 1e6, BS)) + st = time.time() + train(model, X, Y, optimizer, steps=1, BS=BS) + et = time.time() - st + print("done in %.2f ms" % (et * 1000.0)) + def check_gc(): - if Device.DEFAULT == "GPU": - from extra.introspection import print_objects - assert print_objects() == 0 + if Device.DEFAULT == "GPU": + from extra.introspection import print_objects + + assert print_objects() == 0 + class TestTrain(unittest.TestCase): - def test_convnext(self): - model = ConvNeXt(depths=[1], dims=[16]) - X = np.zeros((BS,3,224,224), dtype=np.float32) - Y = np.zeros((BS), dtype=np.int32) - train_one_step(model,X,Y) - check_gc() + def test_convnext(self): + model = ConvNeXt(depths=[1], dims=[16]) + X = np.zeros((BS, 3, 224, 224), dtype=np.float32) + Y = np.zeros((BS), dtype=np.int32) + train_one_step(model, X, Y) + check_gc() - def test_efficientnet(self): - model = EfficientNet(0) - X = np.zeros((BS,3,224,224), dtype=np.float32) - Y = np.zeros((BS), dtype=np.int32) - train_one_step(model,X,Y) - check_gc() + def test_efficientnet(self): + model = EfficientNet(0) + X = np.zeros((BS, 3, 224, 224), dtype=np.float32) + Y = np.zeros((BS), dtype=np.int32) + train_one_step(model, X, Y) + check_gc() - @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal") - def test_vit(self): - model = ViT() - X = np.zeros((BS,3,224,224), dtype=np.float32) - Y = np.zeros((BS,), dtype=np.int32) - train_one_step(model,X,Y) - check_gc() + @unittest.skipIf( + Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal" + ) + def test_vit(self): + model = ViT() + X = np.zeros((BS, 3, 224, 224), dtype=np.float32) + Y = np.zeros((BS,), dtype=np.int32) + train_one_step(model, X, Y) + check_gc() - def test_transformer(self): - # this should be small GPT-2, but the param count is wrong - # (real ff_dim is 768*4) - model = Transformer(syms=10, maxlen=6, layers=12, embed_dim=768, num_heads=12, ff_dim=768//4) - X = np.zeros((BS,6), dtype=np.float32) - Y = np.zeros((BS,6), dtype=np.int32) - train_one_step(model,X,Y) - check_gc() + def test_transformer(self): + # this should be small GPT-2, but the param count is wrong + # (real ff_dim is 768*4) + model = Transformer( + syms=10, maxlen=6, layers=12, embed_dim=768, num_heads=12, ff_dim=768 // 4 + ) + X = np.zeros((BS, 6), dtype=np.float32) + Y = np.zeros((BS, 6), dtype=np.int32) + train_one_step(model, X, Y) + check_gc() - def test_resnet(self): - X = np.zeros((BS, 3, 224, 224), dtype=np.float32) - Y = np.zeros((BS), dtype=np.int32) - for resnet_v in [ResNet18]: - model = resnet_v() - model.load_from_pretrained() - train_one_step(model, X, Y) - check_gc() + def test_resnet(self): + X = np.zeros((BS, 3, 224, 224), dtype=np.float32) + Y = np.zeros((BS), dtype=np.int32) + for resnet_v in [ResNet18]: + model = resnet_v() + model.load_from_pretrained() + train_one_step(model, X, Y) + check_gc() - def test_bert(self): - # TODO: write this - pass + def test_bert(self): + # TODO: write this + pass -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/models/test_waifu2x.py b/test/models/test_waifu2x.py index 957d34fd5..0ba8788a5 100644 --- a/test/models/test_waifu2x.py +++ b/test/models/test_waifu2x.py @@ -4,21 +4,23 @@ import unittest import numpy as np from tinygrad.tensor import Tensor + class TestVGG7(unittest.TestCase): - def test_vgg7(self): - from examples.vgg7_helpers.waifu2x import Vgg7, image_load + def test_vgg7(self): + from examples.vgg7_helpers.waifu2x import Vgg7, image_load - # Create in tinygrad - Tensor.manual_seed(1337) - mdl = Vgg7() - mdl.load_from_pretrained() + # Create in tinygrad + Tensor.manual_seed(1337) + mdl = Vgg7() + mdl.load_from_pretrained() - # Scale up an image - test_x = image_load(pathlib.Path(__file__).parent / 'waifu2x/input.png') - test_y = image_load(pathlib.Path(__file__).parent / 'waifu2x/output.png') - scaled = mdl.forward_tiled(test_x, 156) - scaled = np.fmax(0, np.fmin(1, scaled)) - np.testing.assert_allclose(scaled, test_y, atol=5e-3, rtol=5e-3) + # Scale up an image + test_x = image_load(pathlib.Path(__file__).parent / "waifu2x/input.png") + test_y = image_load(pathlib.Path(__file__).parent / "waifu2x/output.png") + scaled = mdl.forward_tiled(test_x, 156) + scaled = np.fmax(0, np.fmin(1, scaled)) + np.testing.assert_allclose(scaled, test_y, atol=5e-3, rtol=5e-3) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/models/test_whisper.py b/test/models/test_whisper.py index 323e4dc00..ec1e0d6c4 100644 --- a/test/models/test_whisper.py +++ b/test/models/test_whisper.py @@ -1,6 +1,11 @@ import unittest import pathlib -from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform +from examples.whisper import ( + init_whisper, + load_file_waveform, + transcribe_file, + transcribe_waveform, +) from tinygrad.helpers import CI, fetch from tinygrad import Device @@ -12,55 +17,67 @@ TRANSCRIPTION_1 = "Could you please let me out of the box?" TEST_FILE_2 = str(pathlib.Path(__file__).parent / "whisper/test2.wav") TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transcriptions of varying length." # TODO this file will possibly not survive long. find another 1-2 minute sound file online to transcribe -TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3' +TEST_FILE_3_URL = "https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3" TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time." -@unittest.skipIf(CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU", "GPU"], "Not working on LLVM, slow on others. GPU reequires cl_khr_fp16") + +@unittest.skipIf( + CI and Device.DEFAULT in ["LLVM", "CLANG", "CPU", "GPU"], + "Not working on LLVM, slow on others. GPU reequires cl_khr_fp16", +) class TestWhisper(unittest.TestCase): - @classmethod - def setUpClass(cls): - model, enc = init_whisper("tiny.en", batch_size=2) - cls.model = model - cls.enc = enc + @classmethod + def setUpClass(cls): + model, enc = init_whisper("tiny.en", batch_size=2) + cls.model = model + cls.enc = enc - @classmethod - def tearDownClass(cls): - del cls.model - del cls.enc + @classmethod + def tearDownClass(cls): + del cls.model + del cls.enc - def test_transcribe_file1(self): - self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1) + def test_transcribe_file1(self): + self.assertEqual( + transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1 + ) - @unittest.skipIf(CI, "too many tests for CI") - def test_transcribe_file2(self): - self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2) + @unittest.skipIf(CI, "too many tests for CI") + def test_transcribe_file2(self): + self.assertEqual( + transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2 + ) - @unittest.skipIf(CI, "too many tests for CI") - def test_transcribe_batch12(self): - waveforms = [load_file_waveform(TEST_FILE_1), load_file_waveform(TEST_FILE_2)] - transcriptions = transcribe_waveform(self.model, self.enc, waveforms) - self.assertEqual(2, len(transcriptions)) - self.assertEqual(TRANSCRIPTION_1, transcriptions[0]) - self.assertEqual(TRANSCRIPTION_2, transcriptions[1]) + @unittest.skipIf(CI, "too many tests for CI") + def test_transcribe_batch12(self): + waveforms = [load_file_waveform(TEST_FILE_1), load_file_waveform(TEST_FILE_2)] + transcriptions = transcribe_waveform(self.model, self.enc, waveforms) + self.assertEqual(2, len(transcriptions)) + self.assertEqual(TRANSCRIPTION_1, transcriptions[0]) + self.assertEqual(TRANSCRIPTION_2, transcriptions[1]) - def test_transcribe_batch21(self): - waveforms = [load_file_waveform(TEST_FILE_2), load_file_waveform(TEST_FILE_1)] - transcriptions = transcribe_waveform(self.model, self.enc, waveforms) - self.assertEqual(2, len(transcriptions)) - self.assertEqual(TRANSCRIPTION_2, transcriptions[0]) - self.assertEqual(TRANSCRIPTION_1, transcriptions[1]) + def test_transcribe_batch21(self): + waveforms = [load_file_waveform(TEST_FILE_2), load_file_waveform(TEST_FILE_1)] + transcriptions = transcribe_waveform(self.model, self.enc, waveforms) + self.assertEqual(2, len(transcriptions)) + self.assertEqual(TRANSCRIPTION_2, transcriptions[0]) + self.assertEqual(TRANSCRIPTION_1, transcriptions[1]) - @unittest.skipIf(CI, "too long for CI") - def test_transcribe_long(self): - waveform = [load_file_waveform(fetch(TEST_FILE_3_URL))] - transcription = transcribe_waveform(self.model, self.enc, waveform) - self.assertEqual(TRANSCRIPTION_3, transcription) + @unittest.skipIf(CI, "too long for CI") + def test_transcribe_long(self): + waveform = [load_file_waveform(fetch(TEST_FILE_3_URL))] + transcription = transcribe_waveform(self.model, self.enc, waveform) + self.assertEqual(TRANSCRIPTION_3, transcription) - @unittest.skipIf(CI, "too long for CI") - def test_transcribe_long_no_batch(self): - waveforms = [load_file_waveform(fetch(TEST_FILE_3_URL)), load_file_waveform(TEST_FILE_1)] - with self.assertRaises(Exception): - transcribe_waveform(self.model, self.enc, waveforms) + @unittest.skipIf(CI, "too long for CI") + def test_transcribe_long_no_batch(self): + waveforms = [ + load_file_waveform(fetch(TEST_FILE_3_URL)), + load_file_waveform(TEST_FILE_1), + ] + with self.assertRaises(Exception): + transcribe_waveform(self.model, self.enc, waveforms) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_assign.py b/test/test_assign.py index 8a6e13c2b..ae65a26a9 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -7,61 +7,75 @@ from tinygrad.helpers import dtypes N = 200 # has to be bigger than the cache to fail + class TestAssign(unittest.TestCase): - def test_simple_assignment(self): - a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N) - b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N) - a.realize() - b.realize() - ba1 = a.lazydata.realized - bb1 = b.lazydata.realized - a += b - a.realize() - ba2 = a.lazydata.realized - assert ba1 == ba2 and ba1 != bb1 - np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N))) + def test_simple_assignment(self): + a = Tensor(np.arange(N * N, dtype=np.float32)).reshape(N, N) + b = Tensor(np.arange(N * N, dtype=np.float32)).reshape(N, N) + a.realize() + b.realize() + ba1 = a.lazydata.realized + bb1 = b.lazydata.realized + a += b + a.realize() + ba2 = a.lazydata.realized + assert ba1 == ba2 and ba1 != bb1 + np.testing.assert_allclose(a.numpy(), (np.arange(N * N) * 2).reshape((N, N))) - @unittest.skipIf(Device.DEFAULT == "CPU" or Device.DEFAULT == "TORCH", "questionable tests") - def test_permuted_assignment(self): - a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N) - b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N) - a.realize() - b.realize() - ba1 = a.lazydata.realized - bb1 = b.lazydata.realized - a = a.permute(1,0) - a += b - a.realize() - ba2 = a.lazydata.realized - assert ba1 != ba2 and ba1 != bb1 - np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0)) + @unittest.skipIf( + Device.DEFAULT == "CPU" or Device.DEFAULT == "TORCH", "questionable tests" + ) + def test_permuted_assignment(self): + a = Tensor(np.arange(N * N, dtype=np.float32)).reshape(N, N) + b = Tensor(np.arange(N * N, dtype=np.float32)).reshape(N, N) + a.realize() + b.realize() + ba1 = a.lazydata.realized + bb1 = b.lazydata.realized + a = a.permute(1, 0) + a += b + a.realize() + ba2 = a.lazydata.realized + assert ba1 != ba2 and ba1 != bb1 + np.testing.assert_allclose( + a.numpy(), + np.arange(N * N).reshape((N, N)) + + np.arange(N * N).reshape((N, N)).transpose(1, 0), + ) - def test_post_permuted_assignment(self): - a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N) - b = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N) - a.realize() - b.realize() - #GlobalCounters.cache = [] - ba1 = a.lazydata.realized # noqa: F841 - bb1 = b.lazydata.realized # noqa: F841 - a.assign(a.permute(1,0) + b) # this should not work! - a.realize() - ba2 = a.lazydata.realized # noqa: F841 - # NOTE: don't test that it's assigned - #assert ba1 == ba2 and ba1 != bb1 - np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0)) + def test_post_permuted_assignment(self): + a = Tensor(np.arange(N * N, dtype=np.float32)).reshape(N, N) + b = Tensor(np.arange(N * N, dtype=np.float32)).reshape(N, N) + a.realize() + b.realize() + # GlobalCounters.cache = [] + ba1 = a.lazydata.realized # noqa: F841 + bb1 = b.lazydata.realized # noqa: F841 + a.assign(a.permute(1, 0) + b) # this should not work! + a.realize() + ba2 = a.lazydata.realized # noqa: F841 + # NOTE: don't test that it's assigned + # assert ba1 == ba2 and ba1 != bb1 + np.testing.assert_allclose( + a.numpy(), + np.arange(N * N).reshape((N, N)) + + np.arange(N * N).reshape((N, N)).transpose(1, 0), + ) - # TODO: is there a way to sneak in a permute such that it returns the wrong answer? + # TODO: is there a way to sneak in a permute such that it returns the wrong answer? + + def test_cast_assignment(self): + a = Tensor(np.arange(N * N, dtype=np.float32)).reshape(N, N) + a.realize() + oba1 = a.lazydata.output_buffer + a.assign(a.cast(dtypes.int32).realize()) + a.realize() + oba2 = a.lazydata.output_buffer + assert oba1 is None and oba2 is None + np.testing.assert_allclose( + a.numpy(), np.arange(N * N, dtype=np.int32).reshape((N, N)) + ) - def test_cast_assignment(self): - a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N) - a.realize() - oba1 = a.lazydata.output_buffer - a.assign(a.cast(dtypes.int32).realize()) - a.realize() - oba2 = a.lazydata.output_buffer - assert oba1 is None and oba2 is None - np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N))) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/test/test_conv.py b/test/test_conv.py index 6adde7af7..48b1ecda9 100644 --- a/test/test_conv.py +++ b/test/test_conv.py @@ -2,143 +2,160 @@ import unittest import numpy as np from tinygrad.tensor import Tensor, Device + class TestConv(unittest.TestCase): - def test_simple(self): - x = Tensor.ones(1,12,128,256).contiguous().realize() - w = Tensor.ones(32,12,3,3).contiguous().realize() - ret = x.conv2d(w, stride=(2,2), padding=(1,1)).numpy() - # it's not 108 around the padding - assert (ret[:, :, 1:-1, 1:-1] == 108).all() - assert ret[0,0,0,0] == 48 - assert ret[0,0,0,1] == 72 + def test_simple(self): + x = Tensor.ones(1, 12, 128, 256).contiguous().realize() + w = Tensor.ones(32, 12, 3, 3).contiguous().realize() + ret = x.conv2d(w, stride=(2, 2), padding=(1, 1)).numpy() + # it's not 108 around the padding + assert (ret[:, :, 1:-1, 1:-1] == 108).all() + assert ret[0, 0, 0, 0] == 48 + assert ret[0, 0, 0, 1] == 72 - def test_simple_rand(self): - x = Tensor.rand(1,12,128,256) - w = Tensor.rand(32,12,3,3) - x.conv2d(w, stride=(2,2), padding=(1,1)).numpy() + def test_simple_rand(self): + x = Tensor.rand(1, 12, 128, 256) + w = Tensor.rand(32, 12, 3, 3) + x.conv2d(w, stride=(2, 2), padding=(1, 1)).numpy() - def test_many_simple(self): - x = Tensor(np.arange(8*2*8).reshape(1,8,2,8).astype(np.float32)) - #w = Tensor(np.arange(8*8*1*1).reshape(8,8,1,1).astype(np.float32)) - w = Tensor.eye(8).reshape((8,8,1,1)) - ret = x.conv2d(w, stride=(1,2), padding=(0,0)).numpy() - print(ret) + def test_many_simple(self): + x = Tensor(np.arange(8 * 2 * 8).reshape(1, 8, 2, 8).astype(np.float32)) + # w = Tensor(np.arange(8*8*1*1).reshape(8,8,1,1).astype(np.float32)) + w = Tensor.eye(8).reshape((8, 8, 1, 1)) + ret = x.conv2d(w, stride=(1, 2), padding=(0, 0)).numpy() + print(ret) - def test_lazycache(self): - Tensor.no_grad = True - x = Tensor.rand(1, 32) - y = Tensor.rand(32) - out = x + y.reshape((1,32,1)).reshape((1,32)) + y.reshape((1,32,1)).reshape((1,32)) - out.numpy() - Tensor.no_grad = False + def test_lazycache(self): + Tensor.no_grad = True + x = Tensor.rand(1, 32) + y = Tensor.rand(32) + out = ( + x + + y.reshape((1, 32, 1)).reshape((1, 32)) + + y.reshape((1, 32, 1)).reshape((1, 32)) + ) + out.numpy() + Tensor.no_grad = False - def test_simple_biased(self): - C = 8 - x = Tensor.rand(1,C,5,5) - w = Tensor.eye(C).reshape((C,C,1,1)) - b = Tensor(np.arange(C).astype(np.float32)) - ret = Tensor.conv2d(x,w,b).relu().conv2d(w,b) + def test_simple_biased(self): + C = 8 + x = Tensor.rand(1, C, 5, 5) + w = Tensor.eye(C).reshape((C, C, 1, 1)) + b = Tensor(np.arange(C).astype(np.float32)) + ret = Tensor.conv2d(x, w, b).relu().conv2d(w, b) - print(ret.numpy()) + print(ret.numpy()) - def test_two_binops_no_rerun(self): - Tensor.no_grad = True - x = Tensor.randn(1,12,128,256) - w = Tensor.randn(32,12,3,3) - out = x.conv2d(w, stride=(2,2), padding=(1,1)) - r1, r2 = out.relu(), (out-1) - np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0)) - np.testing.assert_allclose(r2.numpy(), out.numpy() - 1) - Tensor.no_grad = False + def test_two_binops_no_rerun(self): + Tensor.no_grad = True + x = Tensor.randn(1, 12, 128, 256) + w = Tensor.randn(32, 12, 3, 3) + out = x.conv2d(w, stride=(2, 2), padding=(1, 1)) + r1, r2 = out.relu(), (out - 1) + np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0)) + np.testing.assert_allclose(r2.numpy(), out.numpy() - 1) + Tensor.no_grad = False - def test_two_overlapping_binops_no_rerun(self): - Tensor.no_grad = True - x = Tensor.randn(1,12,128,256) - w = Tensor.randn(32,12,3,3) - out = x.conv2d(w, stride=(2,2), padding=(1,1)) - r1, r2 = out.relu(), out.elu() - np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0)) - np.testing.assert_allclose(r2.numpy(), np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), atol=1e-5) - Tensor.no_grad = False + def test_two_overlapping_binops_no_rerun(self): + Tensor.no_grad = True + x = Tensor.randn(1, 12, 128, 256) + w = Tensor.randn(32, 12, 3, 3) + out = x.conv2d(w, stride=(2, 2), padding=(1, 1)) + r1, r2 = out.relu(), out.elu() + np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0)) + np.testing.assert_allclose( + r2.numpy(), + np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), + atol=1e-5, + ) + Tensor.no_grad = False - @unittest.skipIf(Device.DEFAULT != "TORCH", "Takes too long to compile for Compiled backends") - def test_two_overlapping_binops_no_rerun_wino(self): - Tensor.no_grad = True - old_wino = Tensor.wino - Tensor.wino = True - x = Tensor.randn(1,4,16,16) - w = Tensor.randn(6,4,3,3) - out = x.conv2d(w, padding=(1,1)) - r1, r2 = out.relu(), out.elu() - np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0)) - np.testing.assert_allclose(r2.numpy(), np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), atol=1e-5) - Tensor.wino = old_wino - Tensor.no_grad = False + @unittest.skipIf( + Device.DEFAULT != "TORCH", "Takes too long to compile for Compiled backends" + ) + def test_two_overlapping_binops_no_rerun_wino(self): + Tensor.no_grad = True + old_wino = Tensor.wino + Tensor.wino = True + x = Tensor.randn(1, 4, 16, 16) + w = Tensor.randn(6, 4, 3, 3) + out = x.conv2d(w, padding=(1, 1)) + r1, r2 = out.relu(), out.elu() + np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0)) + np.testing.assert_allclose( + r2.numpy(), + np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), + atol=1e-5, + ) + Tensor.wino = old_wino + Tensor.no_grad = False - def test_first_three(self): - Tensor.no_grad = True - x = Tensor.rand(1,12,128,256) + def test_first_three(self): + Tensor.no_grad = True + x = Tensor.rand(1, 12, 128, 256) - w = Tensor.rand(32,12,3,3) - x = x.conv2d(w, stride=(2,2), padding=(1,1)).elu() + w = Tensor.rand(32, 12, 3, 3) + x = x.conv2d(w, stride=(2, 2), padding=(1, 1)).elu() - w = Tensor.rand(32,1,3,3) - x = x.conv2d(w, padding=(1,1), groups=32).elu() + w = Tensor.rand(32, 1, 3, 3) + x = x.conv2d(w, padding=(1, 1), groups=32).elu() - w = Tensor.rand(16,32,1,1) - x = x.conv2d(w).elu() + w = Tensor.rand(16, 32, 1, 1) + x = x.conv2d(w).elu() - x = x.numpy() - print(x.shape) - Tensor.no_grad = False + x = x.numpy() + print(x.shape) + Tensor.no_grad = False - def test_elu(self): - Tensor.no_grad = True - x = Tensor.rand(1,12,128,256) + def test_elu(self): + Tensor.no_grad = True + x = Tensor.rand(1, 12, 128, 256) - w = Tensor.rand(32,12,3,3) - x = x.conv2d(w, stride=(2,2), padding=(1,1)) + w = Tensor.rand(32, 12, 3, 3) + x = x.conv2d(w, stride=(2, 2), padding=(1, 1)) - x = x.elu() + x = x.elu() - w = Tensor.rand(32,1,3,3) - x = x.conv2d(w, padding=(1,1), groups=32) - x.numpy() - Tensor.no_grad = False + w = Tensor.rand(32, 1, 3, 3) + x = x.conv2d(w, padding=(1, 1), groups=32) + x.numpy() + Tensor.no_grad = False - def test_reduce_relu(self): - Tensor.no_grad = True - x = Tensor.rand(1,12,128,256) - x = x.sum(keepdim=True).relu() - x.numpy() - Tensor.no_grad = False + def test_reduce_relu(self): + Tensor.no_grad = True + x = Tensor.rand(1, 12, 128, 256) + x = x.sum(keepdim=True).relu() + x.numpy() + Tensor.no_grad = False - def test_bias(self): - Tensor.no_grad = True - from tinygrad.nn import Conv2d - x = Tensor.rand(1,12,128,256) - c = Conv2d(12, 32, 3) - x = c(x).relu() - w = Tensor.uniform(32, 1, 3, 3) - x = x.conv2d(w, groups=32) - x.numpy() - Tensor.no_grad = False + def test_bias(self): + Tensor.no_grad = True + from tinygrad.nn import Conv2d - def test_multiadd(self): - w = Tensor.rand(32) - x = Tensor.rand(32).relu() - (w+x).numpy() + x = Tensor.rand(1, 12, 128, 256) + c = Conv2d(12, 32, 3) + x = c(x).relu() + w = Tensor.uniform(32, 1, 3, 3) + x = x.conv2d(w, groups=32) + x.numpy() + Tensor.no_grad = False - def test_reorder(self): - x = Tensor.rand(1,12,128,256) - w = Tensor.rand(12,12,3,3) - x = x.conv2d(w, padding=(1,1)) - print(x.shape) - x = x.reshape((1, 12, 256, 128)) - x += 1 - x += 1 - x = x.reshape((1, 12, 128, 256)) - x.numpy() + def test_multiadd(self): + w = Tensor.rand(32) + x = Tensor.rand(32).relu() + (w + x).numpy() -if __name__ == '__main__': - unittest.main() \ No newline at end of file + def test_reorder(self): + x = Tensor.rand(1, 12, 128, 256) + w = Tensor.rand(12, 12, 3, 3) + x = x.conv2d(w, padding=(1, 1)) + print(x.shape) + x = x.reshape((1, 12, 256, 128)) + x += 1 + x += 1 + x = x.reshape((1, 12, 128, 256)) + x.numpy() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_conv_shapetracker.py b/test/test_conv_shapetracker.py index 9cf7ff24a..3564a3b0a 100644 --- a/test/test_conv_shapetracker.py +++ b/test/test_conv_shapetracker.py @@ -4,20 +4,28 @@ from tinygrad.tensor import Tensor from tinygrad.ops import LoadOps from tinygrad.nn import Conv2d + class TestConvShapetracker(unittest.TestCase): - def test_conv_3x3_one_view(self): - conv = Conv2d(16, 32, (3, 3)) - seen = set() + def test_conv_3x3_one_view(self): + conv = Conv2d(16, 32, (3, 3)) + seen = set() - # first run to init the weights, they are saved in seen - conv(Tensor.empty(1, 16, 10, 10)).lazydata.schedule(seen) - # run it again to get the kernels - sched = [si for si in conv(Tensor.empty(1, 16, 10, 10)).lazydata.schedule(seen) if si.ast.op not in LoadOps] - assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}" - print(sched[0]) - for arg in [sched[0].out, *sched[0].inputs]: - print(arg.st) - assert len(arg.st.views) == 1 + # first run to init the weights, they are saved in seen + conv(Tensor.empty(1, 16, 10, 10)).lazydata.schedule(seen) + # run it again to get the kernels + sched = [ + si + for si in conv(Tensor.empty(1, 16, 10, 10)).lazydata.schedule(seen) + if si.ast.op not in LoadOps + ] + assert ( + len(sched) == 1 + ), f"conv should only have one kernel, getting {len(sched)}" + print(sched[0]) + for arg in [sched[0].out, *sched[0].inputs]: + print(arg.st) + assert len(arg.st.views) == 1 -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_copy_speed.py b/test/test_copy_speed.py index bd2715357..f9f8e882f 100644 --- a/test/test_copy_speed.py +++ b/test/test_copy_speed.py @@ -5,64 +5,75 @@ from tinygrad.helpers import Timing, CI, OSX import multiprocessing.shared_memory as shared_memory N = 4096 if CI else 16384 + + class TestCopySpeed(unittest.TestCase): - @classmethod - def setUpClass(cls): Device[Device.DEFAULT].synchronize() - - @unittest.skipIf(OSX, "no shm on OSX") - def testCopySHMtoDefault(self): - s = shared_memory.SharedMemory(name="test_X", create=True, size=N*N*4) - s.close() - if CI: - t = Tensor.empty(N, N, device="disk:/dev/shm/test_X").realize() - else: - t = Tensor.empty(N, N, device="disk:shm:test_X").realize() - for _ in range(3): - with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): - with Timing("queue: "): - t.to(Device.DEFAULT).realize() - Device[Device.DEFAULT].synchronize() - s.unlink() - - def testCopyCPUtoDefault(self): - t = Tensor.rand(N, N, device="cpu").realize() - print(f"buffer: {t.nbytes()*1e-9:.2f} GB") - for _ in range(3): - with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): - with Timing("queue: "): - t.to(Device.DEFAULT).realize() + @classmethod + def setUpClass(cls): Device[Device.DEFAULT].synchronize() - def testCopyCPUtoDefaultFresh(self): - print("fresh copy") - for _ in range(3): - t = Tensor.rand(N, N, device="cpu").realize() - with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): # noqa: F821 - with Timing("queue: "): - t.to(Device.DEFAULT).realize() - Device[Device.DEFAULT].synchronize() - del t + @unittest.skipIf(OSX, "no shm on OSX") + def testCopySHMtoDefault(self): + s = shared_memory.SharedMemory(name="test_X", create=True, size=N * N * 4) + s.close() + if CI: + t = Tensor.empty(N, N, device="disk:/dev/shm/test_X").realize() + else: + t = Tensor.empty(N, N, device="disk:shm:test_X").realize() + for _ in range(3): + with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): + with Timing("queue: "): + t.to(Device.DEFAULT).realize() + Device[Device.DEFAULT].synchronize() + s.unlink() - def testCopyDefaulttoCPU(self): - t = Tensor.rand(N, N).realize() - print(f"buffer: {t.nbytes()*1e-9:.2f} GB") - for _ in range(3): - with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): - t.to('cpu').realize() + def testCopyCPUtoDefault(self): + t = Tensor.rand(N, N, device="cpu").realize() + print(f"buffer: {t.nbytes()*1e-9:.2f} GB") + for _ in range(3): + with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): + with Timing("queue: "): + t.to(Device.DEFAULT).realize() + Device[Device.DEFAULT].synchronize() - @unittest.skipIf(CI, "CI doesn't have 6 GPUs") - @unittest.skipIf(Device.DEFAULT != "GPU", "only test this on GPU") - def testCopyCPUto6GPUs(self): - from tinygrad.runtime.ops_gpu import CLDevice - if len(CLDevice.device_ids) != 6: raise unittest.SkipTest("computer doesn't have 6 GPUs") - t = Tensor.rand(N, N, device="cpu").realize() - print(f"buffer: {t.nbytes()*1e-9:.2f} GB") - for _ in range(3): - with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s ({t.nbytes()*6/ns:.2f} GB/s total)"): - with Timing("queue: "): - for g in range(6): - t.to(f"gpu:{g}").realize() - Device["gpu"].synchronize() + def testCopyCPUtoDefaultFresh(self): + print("fresh copy") + for _ in range(3): + t = Tensor.rand(N, N, device="cpu").realize() + with Timing( + "sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s" + ): # noqa: F821 + with Timing("queue: "): + t.to(Device.DEFAULT).realize() + Device[Device.DEFAULT].synchronize() + del t -if __name__ == '__main__': - unittest.main() + def testCopyDefaulttoCPU(self): + t = Tensor.rand(N, N).realize() + print(f"buffer: {t.nbytes()*1e-9:.2f} GB") + for _ in range(3): + with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): + t.to("cpu").realize() + + @unittest.skipIf(CI, "CI doesn't have 6 GPUs") + @unittest.skipIf(Device.DEFAULT != "GPU", "only test this on GPU") + def testCopyCPUto6GPUs(self): + from tinygrad.runtime.ops_gpu import CLDevice + + if len(CLDevice.device_ids) != 6: + raise unittest.SkipTest("computer doesn't have 6 GPUs") + t = Tensor.rand(N, N, device="cpu").realize() + print(f"buffer: {t.nbytes()*1e-9:.2f} GB") + for _ in range(3): + with Timing( + "sync: ", + on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s ({t.nbytes()*6/ns:.2f} GB/s total)", + ): + with Timing("queue: "): + for g in range(6): + t.to(f"gpu:{g}").realize() + Device["gpu"].synchronize() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_custom_function.py b/test/test_custom_function.py index a844bf2f2..f478a8e30 100644 --- a/test/test_custom_function.py +++ b/test/test_custom_function.py @@ -12,16 +12,27 @@ from tinygrad.lazy import Buffer, create_lazybuffer from tinygrad.device import CompiledASTRunner, Device from tinygrad.shape.shapetracker import ShapeTracker + # we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer -def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer): - assert a.dtype == b.dtype and a.dtype == dtypes.float32, "gpu function only supports float32" - CompiledASTRunner(None, "atan2_gpu", """ +def atan2_gpu(ret: Buffer, a: Buffer, b: Buffer): + assert ( + a.dtype == b.dtype and a.dtype == dtypes.float32 + ), "gpu function only supports float32" + CompiledASTRunner( + None, + "atan2_gpu", + """ __kernel void atan2_gpu(global float *c, global float *a, global float *b) { int idx = get_global_id(0); c[idx] = atan2(a[idx], b[idx]); - }""", global_size=[ret.size]).build(Device[ret.device].compiler, Device[ret.device].runtime).exec([ret, a, b]) + }""", + global_size=[ret.size], + ).build(Device[ret.device].compiler, Device[ret.device].runtime).exec([ret, a, b]) + + +def atan2_cpu(ret: Buffer, a: Buffer, b: Buffer): + ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements="C").data) -def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data) # *** second, we write the ATan2 mlop *** # NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative @@ -31,71 +42,114 @@ from tinygrad.ops import LazyOp, LoadOps, BinaryOps from tinygrad.lazy import LazyBuffer from tinygrad.tensor import Function + class ATan2(Function): - def forward(self, a:LazyBuffer, b:LazyBuffer) -> LazyBuffer: - assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch" - self.a, self.b = a, b - ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device]) - return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), LoadOps, ast, max(a.dtype, b.dtype)) - def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: - denom = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b)) - return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \ - grad_output.e(BinaryOps.MUL, self.a.const(0).e(BinaryOps.SUB, self.a).e(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None + def forward(self, a: LazyBuffer, b: LazyBuffer) -> LazyBuffer: + assert ( + prod(a.shape) == prod(b.shape) and a.device == b.device + ), "shape or device mismatch" + self.a, self.b = a, b + ast = LazyOp( + LoadOps.CUSTOM, + (a.contiguous(), b.contiguous()), + {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device], + ) + return create_lazybuffer( + a.device, + ShapeTracker.from_shape(a.shape), + LoadOps, + ast, + max(a.dtype, b.dtype), + ) + + def backward( + self, grad_output: LazyBuffer + ) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: + denom = (self.a.e(BinaryOps.MUL, self.a)).e( + BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b) + ) + return ( + grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) + if self.needs_input_grad[0] + else None, + grad_output.e( + BinaryOps.MUL, + self.a.const(0).e(BinaryOps.SUB, self.a).e(BinaryOps.DIV, denom), + ) + if self.needs_input_grad[1] + else None, + ) + # *** third, we use our lovely new mlop in some tests *** from tinygrad.tensor import Tensor -@unittest.skipUnless(Device.DEFAULT in ["CPU", "GPU"], "atan2 is only implemented for CPU and GPU") + +@unittest.skipUnless( + Device.DEFAULT in ["CPU", "GPU"], "atan2 is only implemented for CPU and GPU" +) class TestCustomFunction(unittest.TestCase): - def test_atan2_forward(self): - # create some random Tensors, permute them just because we can - a = Tensor.randn(4,4,requires_grad=True).permute(1,0) - b = Tensor.randn(4,4,requires_grad=True).permute(1,0) + def test_atan2_forward(self): + # create some random Tensors, permute them just because we can + a = Tensor.randn(4, 4, requires_grad=True).permute(1, 0) + b = Tensor.randn(4, 4, requires_grad=True).permute(1, 0) - # run the forward pass. note: up until the .numpy(), it's all lazy - c = ATan2.apply(a, b) - print(c.numpy()) + # run the forward pass. note: up until the .numpy(), it's all lazy + c = ATan2.apply(a, b) + print(c.numpy()) - # check the forward pass (in numpy) - np.testing.assert_allclose(c.numpy(), np.arctan2(a.numpy(), b.numpy()), atol=1e-5) + # check the forward pass (in numpy) + np.testing.assert_allclose( + c.numpy(), np.arctan2(a.numpy(), b.numpy()), atol=1e-5 + ) - # fun fact, this never actually calls forward, so it works in all the backends - def test_atan2_backward(self): - # have to go forward before we can go backward - a = Tensor.randn(4,4,requires_grad=True).permute(1,0) - b = Tensor.randn(4,4,requires_grad=True).permute(1,0) - c = ATan2.apply(a, b) + # fun fact, this never actually calls forward, so it works in all the backends + def test_atan2_backward(self): + # have to go forward before we can go backward + a = Tensor.randn(4, 4, requires_grad=True).permute(1, 0) + b = Tensor.randn(4, 4, requires_grad=True).permute(1, 0) + c = ATan2.apply(a, b) - # run the backward pass - c.mean().backward() - assert a.grad is not None and b.grad is not None, "tinygrad didn't compute gradients" - print(a.grad.numpy()) - print(b.grad.numpy()) + # run the backward pass + c.mean().backward() + assert ( + a.grad is not None and b.grad is not None + ), "tinygrad didn't compute gradients" + print(a.grad.numpy()) + print(b.grad.numpy()) - # check the backward pass (in torch) - import torch - ta, tb = torch.tensor(a.numpy(), requires_grad=True), torch.tensor(b.numpy(), requires_grad=True) - tc = torch.atan2(ta, tb) - tc.mean().backward() - assert ta.grad is not None and tb.grad is not None, "torch didn't compute gradients" - np.testing.assert_allclose(a.grad.numpy(), ta.grad.numpy(), atol=1e-5) - np.testing.assert_allclose(b.grad.numpy(), tb.grad.numpy(), atol=1e-5) + # check the backward pass (in torch) + import torch - @unittest.skipIf(Device.DEFAULT in ["CPU"], "atan2_cpu not jittable") - def test_atan2_jit(self): - # custom ops even work in the JIT! - from tinygrad.jit import TinyJit + ta, tb = torch.tensor(a.numpy(), requires_grad=True), torch.tensor( + b.numpy(), requires_grad=True + ) + tc = torch.atan2(ta, tb) + tc.mean().backward() + assert ( + ta.grad is not None and tb.grad is not None + ), "torch didn't compute gradients" + np.testing.assert_allclose(a.grad.numpy(), ta.grad.numpy(), atol=1e-5) + np.testing.assert_allclose(b.grad.numpy(), tb.grad.numpy(), atol=1e-5) - @TinyJit - def jitted_atan2(a:Tensor, b:Tensor) -> Tensor: - return ATan2.apply(a, b).realize() + @unittest.skipIf(Device.DEFAULT in ["CPU"], "atan2_cpu not jittable") + def test_atan2_jit(self): + # custom ops even work in the JIT! + from tinygrad.jit import TinyJit + + @TinyJit + def jitted_atan2(a: Tensor, b: Tensor) -> Tensor: + return ATan2.apply(a, b).realize() + + for _ in range(5): + a = Tensor.randn(4, 4, requires_grad=True).permute(1, 0) + b = Tensor.randn(4, 4, requires_grad=True).permute(1, 0) + c = jitted_atan2(a, b) + np.testing.assert_allclose( + c.numpy(), np.arctan2(a.numpy(), b.numpy()), atol=1e-5 + ) - for _ in range(5): - a = Tensor.randn(4,4,requires_grad=True).permute(1,0) - b = Tensor.randn(4,4,requires_grad=True).permute(1,0) - c = jitted_atan2(a, b) - np.testing.assert_allclose(c.numpy(), np.arctan2(a.numpy(), b.numpy()), atol=1e-5) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/test/test_dtype.py b/test/test_dtype.py index 9b233a6dc..22d709c21 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -1,193 +1,442 @@ import unittest import numpy as np -from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp +from tinygrad.helpers import ( + CI, + DTYPES_DICT, + getenv, + DType, + DEBUG, + ImageDType, + PtrDType, + OSX, + temp, +) from tinygrad import Device from tinygrad.tensor import Tensor, dtypes from typing import Any, List + def is_dtype_supported(dtype: DType): - # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!) - # for LLVM, it segfaults because it can't link to the casting function - if dtype == dtypes.half: return not (CI and Device.DEFAULT in ["GPU", "LLVM"]) and Device.DEFAULT != "WEBGPU" and getenv("CUDACPU") != 1 - if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType - if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and (not OSX and Device.DEFAULT == "GPU") - if dtype in [dtypes.int8, dtypes.uint8]: return Device.DEFAULT not in ["WEBGPU"] - if dtype in [dtypes.int16, dtypes.uint16]: return Device.DEFAULT not in ["WEBGPU", "TORCH"] - if dtype == dtypes.uint32: return Device.DEFAULT not in ["TORCH"] - if dtype in [dtypes.int64, dtypes.uint64]: return Device.DEFAULT not in ["WEBGPU", "TORCH"] - if dtype == dtypes.bool: - # host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable - if Device.DEFAULT == "WEBGPU": return False - return True + # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!) + # for LLVM, it segfaults because it can't link to the casting function + if dtype == dtypes.half: + return ( + not (CI and Device.DEFAULT in ["GPU", "LLVM"]) + and Device.DEFAULT != "WEBGPU" + and getenv("CUDACPU") != 1 + ) + if dtype == dtypes.bfloat16: + return ( + False # numpy doesn't support bf16, tested separately in TestBFloat16DType + ) + if dtype == dtypes.float64: + return Device.DEFAULT not in ["WEBGPU", "METAL"] and ( + not OSX and Device.DEFAULT == "GPU" + ) + if dtype in [dtypes.int8, dtypes.uint8]: + return Device.DEFAULT not in ["WEBGPU"] + if dtype in [dtypes.int16, dtypes.uint16]: + return Device.DEFAULT not in ["WEBGPU", "TORCH"] + if dtype == dtypes.uint32: + return Device.DEFAULT not in ["TORCH"] + if dtype in [dtypes.int64, dtypes.uint64]: + return Device.DEFAULT not in ["WEBGPU", "TORCH"] + if dtype == dtypes.bool: + # host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable + if Device.DEFAULT == "WEBGPU": + return False + return True -def get_available_cast_dtypes(dtype: DType) -> List[DType]: return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")] # dont cast internal dtypes -def _test_to_np(a:Tensor, np_dtype, target): - if DEBUG >= 2: print(a) - na = a.numpy() - if DEBUG >= 2: print(na, na.dtype, a.lazydata.realized) - try: - assert na.dtype == np_dtype - np.testing.assert_allclose(na, target) - except AssertionError as e: - raise AssertionError(f"\ntensor {a.numpy()} does not match target {target} with np_dtype {np_dtype}") from e +def get_available_cast_dtypes(dtype: DType) -> List[DType]: + return [ + v + for k, v in DTYPES_DICT.items() + if v != dtype and is_dtype_supported(v) and not k.startswith("_") + ] # dont cast internal dtypes -def _assert_eq(tensor:Tensor, target_dtype:DType, target): - if DEBUG >= 2: print(tensor.numpy()) - try: - assert tensor.dtype == target_dtype - np.testing.assert_allclose(tensor.numpy(), target) - except AssertionError as e: - raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e -def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target) -def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, a.numpy().astype(target_dtype.np).tolist()) -def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target) +def _test_to_np(a: Tensor, np_dtype, target): + if DEBUG >= 2: + print(a) + na = a.numpy() + if DEBUG >= 2: + print(na, na.dtype, a.lazydata.realized) + try: + assert na.dtype == np_dtype + np.testing.assert_allclose(na, target) + except AssertionError as e: + raise AssertionError( + f"\ntensor {a.numpy()} does not match target {target} with np_dtype {np_dtype}" + ) from e + + +def _assert_eq(tensor: Tensor, target_dtype: DType, target): + if DEBUG >= 2: + print(tensor.numpy()) + try: + assert tensor.dtype == target_dtype + np.testing.assert_allclose(tensor.numpy(), target) + except AssertionError as e: + raise AssertionError( + f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}" + ) from e + + +def _test_op(fxn, target_dtype: DType, target): + _assert_eq(fxn(), target_dtype, target) + + +def _test_cast(a: Tensor, target_dtype: DType): + _test_op( + lambda: a.cast(target_dtype), + target_dtype, + a.numpy().astype(target_dtype.np).tolist(), + ) + + +def _test_bitcast(a: Tensor, target_dtype: DType, target): + _test_op(lambda: a.bitcast(target_dtype), target_dtype, target) + class TestDType(unittest.TestCase): - DTYPE: Any = None - DATA: Any = None - @classmethod - def setUpClass(cls): - if not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported") - cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist() - def setUp(self): - if self.DTYPE is None: raise unittest.SkipTest("base class") + DTYPE: Any = None + DATA: Any = None - def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), self.DTYPE.np, np.array(self.DATA, dtype=self.DTYPE.np)) + @classmethod + def setUpClass(cls): + if not is_dtype_supported(cls.DTYPE): + raise unittest.SkipTest("dtype not supported") + cls.DATA = ( + np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() + if dtypes.is_int(cls.DTYPE) + else np.random.choice([True, False], size=10).tolist() + if cls.DTYPE == dtypes.bool + else np.random.uniform(0, 1, size=10).tolist() + ) - def test_casts_to(self): list(map( - lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE), - get_available_cast_dtypes(self.DTYPE) - )) - def test_casts_from(self): list(map( - lambda dtype: _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype), - get_available_cast_dtypes(self.DTYPE) - )) + def setUp(self): + if self.DTYPE is None: + raise unittest.SkipTest("base class") - def test_same_size_ops(self): - def get_target_dtype(dtype): - if any([dtypes.is_float(dtype), dtypes.is_float(self.DTYPE)]): return max([dtype, self.DTYPE], key=lambda x: x.priority) - return dtype if dtypes.is_unsigned(dtype) else self.DTYPE - list(map( - lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype, target_dtype=get_target_dtype(dtype)) if dtype.itemsize == self.DTYPE.itemsize else None, - get_available_cast_dtypes(self.DTYPE) - )) - def test_upcast_ops(self): list(map( - lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None, - get_available_cast_dtypes(self.DTYPE) - )) - def test_upcast_to_ops(self): - list(map( - lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None, - get_available_cast_dtypes(self.DTYPE) - )) + def test_to_np(self): + _test_to_np( + Tensor(self.DATA, dtype=self.DTYPE), + self.DTYPE.np, + np.array(self.DATA, dtype=self.DTYPE.np), + ) + + def test_casts_to(self): + list( + map( + lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE), + get_available_cast_dtypes(self.DTYPE), + ) + ) + + def test_casts_from(self): + list( + map( + lambda dtype: _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype), + get_available_cast_dtypes(self.DTYPE), + ) + ) + + def test_same_size_ops(self): + def get_target_dtype(dtype): + if any([dtypes.is_float(dtype), dtypes.is_float(self.DTYPE)]): + return max([dtype, self.DTYPE], key=lambda x: x.priority) + return dtype if dtypes.is_unsigned(dtype) else self.DTYPE + + list( + map( + lambda dtype: _test_ops( + a_dtype=self.DTYPE, + b_dtype=dtype, + target_dtype=get_target_dtype(dtype), + ) + if dtype.itemsize == self.DTYPE.itemsize + else None, + get_available_cast_dtypes(self.DTYPE), + ) + ) + + def test_upcast_ops(self): + list( + map( + lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) + if dtype.itemsize > self.DTYPE.itemsize + else None, + get_available_cast_dtypes(self.DTYPE), + ) + ) + + def test_upcast_to_ops(self): + list( + map( + lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) + if dtype.itemsize < self.DTYPE.itemsize + else None, + get_available_cast_dtypes(self.DTYPE), + ) + ) + + +def _test_ops(a_dtype: DType, b_dtype: DType, target_dtype=None): + if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype): + return + if a_dtype == dtypes.bool or b_dtype == dtypes.bool: + return + target_dtype = target_dtype or ( + max([a_dtype, b_dtype], key=lambda x: x.priority) + if a_dtype.priority != b_dtype.priority + else max([a_dtype, b_dtype], key=lambda x: x.itemsize) + ) + _assert_eq( + Tensor([1, 2, 3, 4], dtype=a_dtype) + Tensor([1, 2, 3, 4], dtype=b_dtype), + target_dtype, + [2, 4, 6, 8], + ) + _assert_eq( + Tensor([1, 2, 3, 4], dtype=a_dtype) * Tensor([1, 2, 3, 4], dtype=b_dtype), + target_dtype, + [1, 4, 9, 16], + ) + _assert_eq( + Tensor([[1, 2], [3, 4]], dtype=a_dtype) @ Tensor.eye(2, dtype=b_dtype), + target_dtype, + [[1, 2], [3, 4]], + ) + _assert_eq( + Tensor([1, 1, 1, 1], dtype=a_dtype) + Tensor.ones((4, 4), dtype=b_dtype), + target_dtype, + 2 * Tensor.ones(4, 4).numpy(), + ) -def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None): - if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype): return - if a_dtype == dtypes.bool or b_dtype == dtypes.bool: return - target_dtype = target_dtype or (max([a_dtype, b_dtype], key=lambda x: x.priority) if a_dtype.priority != b_dtype.priority else max([a_dtype, b_dtype], key=lambda x: x.itemsize)) - _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8]) - _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16]) - _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]]) - _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy()) class TestBFloat16DType(unittest.TestCase): - def setUp(self): - if not is_dtype_supported(dtypes.bfloat16): raise unittest.SkipTest("bfloat16 not supported") - def test_bf16_to_float(self): - with self.assertRaises(AssertionError): - _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000]) + def setUp(self): + if not is_dtype_supported(dtypes.bfloat16): + raise unittest.SkipTest("bfloat16 not supported") - def test_float_to_bf16(self): - with self.assertRaises(AssertionError): - _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16, [100000]) + def test_bf16_to_float(self): + with self.assertRaises(AssertionError): + _test_cast( + Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000] + ) - # torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16) + def test_float_to_bf16(self): + with self.assertRaises(AssertionError): + _test_cast( + Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16, [100000] + ) - def test_bf16(self): - t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16) - t.realize() - back = t.cast(dtypes.float32) - assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20) + # torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16) - def test_bf16_disk_write_read(self): - t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32) - t.to(f"disk:{temp('f32')}").realize() + def test_bf16(self): + t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16) + t.realize() + back = t.cast(dtypes.float32) + assert tuple(back.numpy().tolist()) == (9984.0, -1, -1000, -9984, 20) - # hack to "cast" f32 -> bf16 - dat = open(temp('f32'), "rb").read() - adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)]) - with open(temp('bf16'), "wb") as f: f.write(adat) + def test_bf16_disk_write_read(self): + t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32) + t.to(f"disk:{temp('f32')}").realize() - t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm().realize() - back = t.cast(dtypes.float32) - assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20) + # hack to "cast" f32 -> bf16 + dat = open(temp("f32"), "rb").read() + adat = b"".join([dat[i + 2 : i + 4] for i in range(0, len(dat), 4)]) + with open(temp("bf16"), "wb") as f: + f.write(adat) -class TestHalfDtype(TestDType): DTYPE = dtypes.half + t = ( + Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}") + .llvm() + .realize() + ) + back = t.cast(dtypes.float32) + assert tuple(back.numpy().tolist()) == (9984.0, -1, -1000, -9984, 20) -class TestFloatDType(TestDType): DTYPE = dtypes.float -class TestDoubleDtype(TestDType): DTYPE = dtypes.double +class TestHalfDtype(TestDType): + DTYPE = dtypes.half + + +class TestFloatDType(TestDType): + DTYPE = dtypes.float + + +class TestDoubleDtype(TestDType): + DTYPE = dtypes.double + class TestInt8Dtype(TestDType): - DTYPE = dtypes.int8 - @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently") - def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252]) + DTYPE = dtypes.int8 + + @unittest.skipIf( + getenv("CUDA", 0) == 1 or getenv("PTX", 0) == 1, + "cuda saturation works differently", + ) + def test_int8_to_uint8_negative(self): + _test_op( + lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), + dtypes.uint8, + [255, 254, 253, 252], + ) + class TestUint8Dtype(TestDType): - DTYPE = dtypes.uint8 - @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently") - def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4]) + DTYPE = dtypes.uint8 -@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH") + @unittest.skipIf( + getenv("CUDA", 0) == 1 or getenv("PTX", 0) == 1, + "cuda saturation works differently", + ) + def test_uint8_to_int8_overflow(self): + _test_op( + lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), + dtypes.int8, + [-1, -2, -3, -4], + ) + + +@unittest.skipIf( + Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH" +) class TestBitCast(unittest.TestCase): - def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432]) - @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch") - def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432]) - def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0]) + def test_float32_bitcast_to_int32(self): + _test_bitcast( + Tensor([1, 2, 3, 4], dtype=dtypes.float32), + dtypes.int32, + [1065353216, 1073741824, 1077936128, 1082130432], + ) - # NOTE: these are the same as normal casts - def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252]) - def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4]) - @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch") - def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612]) - @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch") - def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4]) + @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch") + def test_float32_bitcast_to_uint32(self): + _test_bitcast( + Tensor([1, 2, 3, 4], dtype=dtypes.float32), + dtypes.uint32, + [1065353216, 1073741824, 1077936128, 1082130432], + ) - def test_shape_change_bitcast(self): - with self.assertRaises(AssertionError): - _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000]) + def test_int32_bitcast_to_float32(self): + _test_bitcast( + Tensor( + [1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32 + ), + dtypes.float32, + [1.0, 2.0, 3.0, 4.0], + ) -class TestInt16Dtype(TestDType): DTYPE = dtypes.int16 -class TestUint16Dtype(TestDType): DTYPE = dtypes.uint16 + # NOTE: these are the same as normal casts + def test_int8_bitcast_to_uint8(self): + _test_bitcast( + Tensor([-1, -2, -3, -4], dtype=dtypes.int8), + dtypes.uint8, + [255, 254, 253, 252], + ) -class TestInt32Dtype(TestDType): DTYPE = dtypes.int32 -class TestUint32Dtype(TestDType): DTYPE = dtypes.uint32 + def test_uint8_bitcast_to_int8(self): + _test_bitcast( + Tensor([255, 254, 253, 252], dtype=dtypes.uint8), + dtypes.int8, + [-1, -2, -3, -4], + ) -class TestInt64Dtype(TestDType): DTYPE = dtypes.int64 -class TestUint64Dtype(TestDType): DTYPE = dtypes.uint64 + @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch") + def test_int64_bitcast_to_uint64(self): + _test_bitcast( + Tensor([-1, -2, -3, -4], dtype=dtypes.int64), + dtypes.uint64, + [ + 18446744073709551615, + 18446744073709551614, + 18446744073709551613, + 18446744073709551612, + ], + ) + + @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch") + def test_uint64_bitcast_to_int64(self): + _test_bitcast( + Tensor( + [ + 18446744073709551615, + 18446744073709551614, + 18446744073709551613, + 18446744073709551612, + ], + dtype=dtypes.uint64, + ), + dtypes.int64, + [-1, -2, -3, -4], + ) + + def test_shape_change_bitcast(self): + with self.assertRaises(AssertionError): + _test_bitcast( + Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000] + ) + + +class TestInt16Dtype(TestDType): + DTYPE = dtypes.int16 + + +class TestUint16Dtype(TestDType): + DTYPE = dtypes.uint16 + + +class TestInt32Dtype(TestDType): + DTYPE = dtypes.int32 + + +class TestUint32Dtype(TestDType): + DTYPE = dtypes.uint32 + + +class TestInt64Dtype(TestDType): + DTYPE = dtypes.int64 + + +class TestUint64Dtype(TestDType): + DTYPE = dtypes.uint64 + + +class TestBoolDtype(TestDType): + DTYPE = dtypes.bool -class TestBoolDtype(TestDType): DTYPE = dtypes.bool class TestEqStrDType(unittest.TestCase): - def test_image_ne(self): - if ImageDType is None: raise unittest.SkipTest("no ImageDType support") - assert dtypes.float == dtypes.float32, "float doesn't match?" - assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match" - assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match" - assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches" - assert isinstance(dtypes.imageh((1,2,4)), ImageDType) - def test_ptr_ne(self): - if PtrDType is None: raise unittest.SkipTest("no PtrDType support") - # TODO: is this the wrong behavior? - assert PtrDType(dtypes.float32) == dtypes.float32 - #assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32) - #assert PtrDType(dtypes.float32) != dtypes.float32 - def test_strs(self): - if PtrDType is None: raise unittest.SkipTest("no PtrDType support") - self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))") - self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float") + def test_image_ne(self): + if ImageDType is None: + raise unittest.SkipTest("no ImageDType support") + assert dtypes.float == dtypes.float32, "float doesn't match?" + assert dtypes.imagef((1, 2, 4)) != dtypes.imageh( + (1, 2, 4) + ), "different image dtype doesn't match" + assert dtypes.imageh((1, 2, 4)) != dtypes.imageh( + (1, 4, 2) + ), "different shape doesn't match" + assert dtypes.imageh((1, 2, 4)) == dtypes.imageh( + (1, 2, 4) + ), "same shape matches" + assert isinstance(dtypes.imageh((1, 2, 4)), ImageDType) -if __name__ == '__main__': - unittest.main() + def test_ptr_ne(self): + if PtrDType is None: + raise unittest.SkipTest("no PtrDType support") + # TODO: is this the wrong behavior? + assert PtrDType(dtypes.float32) == dtypes.float32 + # assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32) + # assert PtrDType(dtypes.float32) != dtypes.float32 + + def test_strs(self): + if PtrDType is None: + raise unittest.SkipTest("no PtrDType support") + self.assertEqual(str(dtypes.imagef((1, 2, 4))), "dtypes.imagef((1, 2, 4))") + self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_gc.py b/test/test_gc.py index 49773dd98..7fabdc89d 100644 --- a/test/test_gc.py +++ b/test/test_gc.py @@ -4,34 +4,36 @@ import unittest import numpy as np from tinygrad.tensor import Tensor + def tensors_allocated(): - return sum([isinstance(x, Tensor) for x in gc.get_objects()]) + return sum([isinstance(x, Tensor) for x in gc.get_objects()]) + class TestGC(unittest.TestCase): + def test_gc(self): + a = Tensor.zeros(4, 4, requires_grad=True) + b = Tensor.zeros(4, 4, requires_grad=True) + (a * b).mean().backward() + assert tensors_allocated() > 0 + del a, b + assert tensors_allocated() == 0 - def test_gc(self): - a = Tensor.zeros(4, 4, requires_grad=True) - b = Tensor.zeros(4, 4, requires_grad=True) - (a*b).mean().backward() - assert(tensors_allocated() > 0) - del a,b - assert(tensors_allocated() == 0) + def test_gc_complex(self): + a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) + b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) + assert tensors_allocated() == 2 + (a * b).mean().backward() + assert tensors_allocated() == 4 + del b + assert tensors_allocated() == 2 + b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) + print(tensors_allocated()) + (a * b).mean().backward() + print(tensors_allocated()) + assert tensors_allocated() == 4 + del b + assert tensors_allocated() == 2 - def test_gc_complex(self): - a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) - b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) - assert(tensors_allocated() == 2) - (a*b).mean().backward() - assert(tensors_allocated() == 4) - del b - assert(tensors_allocated() == 2) - b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) - print(tensors_allocated()) - (a*b).mean().backward() - print(tensors_allocated()) - assert(tensors_allocated() == 4) - del b - assert(tensors_allocated() == 2) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/test/test_hip_rdna3.py b/test/test_hip_rdna3.py index dba226473..681631aae 100644 --- a/test/test_hip_rdna3.py +++ b/test/test_hip_rdna3.py @@ -5,32 +5,36 @@ from tinygrad.helpers import dtypes from examples.beautiful_mnist import Model as MNIST from examples.hlb_cifar10 import SpeedyResNet -@unittest.skipIf(Device.DEFAULT != "HIP", reason="testing HIP->rdna3 compilation needs HIP=1") + +@unittest.skipIf( + Device.DEFAULT != "HIP", reason="testing HIP->rdna3 compilation needs HIP=1" +) class TestHIPCompilationRDNA(unittest.TestCase): - def test_compile_hip_mnist(self): - model = MNIST() + def test_compile_hip_mnist(self): + model = MNIST() - input = Tensor.rand(512,1,28,28) - output = model(input) - output.numpy() + input = Tensor.rand(512, 1, 28, 28) + output = model(input) + output.numpy() - def test_compile_hip_speedyresnet(self): - W = Tensor.rand(12,3,2,2) - model = SpeedyResNet(W) + def test_compile_hip_speedyresnet(self): + W = Tensor.rand(12, 3, 2, 2) + model = SpeedyResNet(W) - input = Tensor.rand(512, 3, 32, 32) - output = model(input) - output.numpy() + input = Tensor.rand(512, 3, 32, 32) + output = model(input) + output.numpy() - def test_compile_hip_speedyresnet_hf(self): - Tensor.default_type = dtypes.float16 + def test_compile_hip_speedyresnet_hf(self): + Tensor.default_type = dtypes.float16 - W = Tensor.rand(12,3,2,2) - model = SpeedyResNet(W) + W = Tensor.rand(12, 3, 2, 2) + model = SpeedyResNet(W) + + input = Tensor.rand(512, 3, 32, 32) + output = model(input) + output.numpy() - input = Tensor.rand(512, 3, 32, 32) - output = model(input) - output.numpy() if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index 84227f5b2..7fe83c416 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -3,41 +3,43 @@ import numpy as np from tinygrad import Device, dtypes, Tensor from tinygrad.helpers import ImageDType + @unittest.skipIf(Device.DEFAULT != "GPU", "only images on GPU") class TestImageDType(unittest.TestCase): - def test_image_and_back(self): - data = Tensor.randn(9*27*4).realize() - tst = data.numpy() - it = data.cast(dtypes.imagef((9,27,4))).realize() - assert isinstance(it.lazydata.realized.dtype, ImageDType) - np.testing.assert_equal(tst, it.numpy()) + def test_image_and_back(self): + data = Tensor.randn(9 * 27 * 4).realize() + tst = data.numpy() + it = data.cast(dtypes.imagef((9, 27, 4))).realize() + assert isinstance(it.lazydata.realized.dtype, ImageDType) + np.testing.assert_equal(tst, it.numpy()) - def test_image_and_back_wrong_shape(self): - data = Tensor.randn(9*27*4).realize() - tst = data.numpy() - it = data.cast(dtypes.imagef((9,12,4))).realize() - assert not isinstance(it.lazydata.realized.dtype, ImageDType) - np.testing.assert_equal(tst, it.numpy()) + def test_image_and_back_wrong_shape(self): + data = Tensor.randn(9 * 27 * 4).realize() + tst = data.numpy() + it = data.cast(dtypes.imagef((9, 12, 4))).realize() + assert not isinstance(it.lazydata.realized.dtype, ImageDType) + np.testing.assert_equal(tst, it.numpy()) - def test_shrink_load_float(self): - it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize() - imgv = it.numpy() - np.testing.assert_equal(imgv[0:2], it[0:2].numpy()) + def test_shrink_load_float(self): + it = Tensor.randn(4).cast(dtypes.imagef((1, 1, 4))).realize() + imgv = it.numpy() + np.testing.assert_equal(imgv[0:2], it[0:2].numpy()) - def test_mul_stays_image(self): - it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize() - out = (it*2).realize() - assert isinstance(out.lazydata.realized.dtype, ImageDType) + def test_mul_stays_image(self): + it = Tensor.randn(4).cast(dtypes.imagef((1, 1, 4))).realize() + out = (it * 2).realize() + assert isinstance(out.lazydata.realized.dtype, ImageDType) - def test_shrink_max(self): - it = Tensor.randn(8).cast(dtypes.imagef((1,2,4))).realize() - imgv = it.numpy() - np.testing.assert_equal(np.maximum(imgv[0:3], 0), it[0:3].relu().numpy()) + def test_shrink_max(self): + it = Tensor.randn(8).cast(dtypes.imagef((1, 2, 4))).realize() + imgv = it.numpy() + np.testing.assert_equal(np.maximum(imgv[0:3], 0), it[0:3].relu().numpy()) - def test_shrink_to_float(self): - it = Tensor.randn(4, 4).cast(dtypes.imagef((1,4,4))).realize() - imgv = it.numpy() - np.testing.assert_equal(np.maximum(imgv[:, 0], 0), it[:, 0].relu().realize()) + def test_shrink_to_float(self): + it = Tensor.randn(4, 4).cast(dtypes.imagef((1, 4, 4))).realize() + imgv = it.numpy() + np.testing.assert_equal(np.maximum(imgv[:, 0], 0), it[:, 0].relu().realize()) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_jit.py b/test/test_jit.py index 042c53d23..f2193770c 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -6,252 +6,306 @@ from test.helpers import assert_jit_cache_len from tinygrad.tensor import Tensor from tinygrad.jit import TinyJit + class TestJit(unittest.TestCase): - def test_simple_jit(self): - @TinyJit - def add(a, b): return (a+b).realize() - for _ in range(5): - a = Tensor.randn(10, 10) - b = Tensor.randn(10, 10) - c = add(a, b) - np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5) - assert_jit_cache_len(add, 1) + def test_simple_jit(self): + @TinyJit + def add(a, b): + return (a + b).realize() - def test_jit_multiple_outputs(self): - @TinyJit - def f(a, b): return (a+b).realize(), (a-b).realize(), (a*b).realize() - for _ in range(5): - a = Tensor.randn(10, 10) - b = Tensor.randn(10, 10) - c, d, e = f(a, b) - np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5) - np.testing.assert_allclose(d.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5) - np.testing.assert_allclose(e.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5) - assert_jit_cache_len(f, 3) + for _ in range(5): + a = Tensor.randn(10, 10) + b = Tensor.randn(10, 10) + c = add(a, b) + np.testing.assert_allclose( + c.numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5 + ) + assert_jit_cache_len(add, 1) + + def test_jit_multiple_outputs(self): + @TinyJit + def f(a, b): + return (a + b).realize(), (a - b).realize(), (a * b).realize() + + for _ in range(5): + a = Tensor.randn(10, 10) + b = Tensor.randn(10, 10) + c, d, e = f(a, b) + np.testing.assert_allclose( + c.numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5 + ) + np.testing.assert_allclose( + d.numpy(), a.numpy() - b.numpy(), atol=1e-4, rtol=1e-5 + ) + np.testing.assert_allclose( + e.numpy(), a.numpy() * b.numpy(), atol=1e-4, rtol=1e-5 + ) + assert_jit_cache_len(f, 3) + + def test_nothing_jitted(self): + @TinyJit + def add(a, b): + return a + b + + with self.assertRaises(AssertionError): + for _ in range(5): + a = Tensor.randn(10, 10) + b = Tensor.randn(10, 10) + add(a, b) + + def test_jit_shape_mismatch(self): + @TinyJit + def add(a, b): + return (a + b).realize() + + for _ in range(5): + a = Tensor.randn(10, 10) + b = Tensor.randn(10, 10) + add(a, b) + bad = Tensor.randn(20, 20) + with self.assertRaises(AssertionError): + add(a, bad) + + def test_jit_shape_views_mismatch(self): + @TinyJit + def add(a): + return (a + 1).realize() + + with self.assertRaises(AssertionError): + for i in range(1, 5): + # a has an offset that the kernel doesn't know about + a = Tensor.randn(10, 10).realize()[:, i : i + 2] + add(a) + + def test_jit_duplicate_fail(self): + # the jit doesn't support duplicate arguments + @TinyJit + def add(a, b): + return (a + b).realize() - def test_nothing_jitted(self): - @TinyJit - def add(a, b): return a+b - with self.assertRaises(AssertionError): - for _ in range(5): a = Tensor.randn(10, 10) - b = Tensor.randn(10, 10) - add(a, b) + with self.assertRaises(AssertionError): + add(a, a) - def test_jit_shape_mismatch(self): - @TinyJit - def add(a, b): return (a+b).realize() - for _ in range(5): - a = Tensor.randn(10, 10) - b = Tensor.randn(10, 10) - add(a, b) - bad = Tensor.randn(20, 20) - with self.assertRaises(AssertionError): - add(a, bad) + def test_kwargs_jit(self): + @TinyJit + def add_kwargs(first, second): + return (first + second).realize() - def test_jit_shape_views_mismatch(self): - @TinyJit - def add(a): return (a+1).realize() - with self.assertRaises(AssertionError): - for i in range(1,5): - # a has an offset that the kernel doesn't know about - a = Tensor.randn(10, 10).realize()[:, i:i+2] - add(a) + for _ in range(5): + a = Tensor.randn(10, 10) + b = Tensor.randn(10, 10) + c = add_kwargs(first=a, second=b) + np.testing.assert_allclose( + c.numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5 + ) + assert_jit_cache_len(add_kwargs, 1) - def test_jit_duplicate_fail(self): - # the jit doesn't support duplicate arguments - @TinyJit - def add(a, b): return (a+b).realize() - a = Tensor.randn(10, 10) - with self.assertRaises(AssertionError): - add(a, a) + def test_array_jit(self): + @TinyJit + def add_array(a, arr): + return (a + arr[0]).realize() - def test_kwargs_jit(self): - @TinyJit - def add_kwargs(first, second): return (first+second).realize() - for _ in range(5): - a = Tensor.randn(10, 10) - b = Tensor.randn(10, 10) - c = add_kwargs(first=a, second=b) - np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5) - assert_jit_cache_len(add_kwargs, 1) + for i in range(5): + a = Tensor.randn(10, 10) + b = Tensor.randn(10, 10) + a.realize(), b.realize() + c = add_array(a, [b]) + if i >= 2: + # should fail once jitted since jit can't handle arrays + np.testing.assert_allclose( + np.any(np.not_equal(c.numpy(), a.numpy() + b.numpy())), + True, + atol=1e-4, + rtol=1e-5, + ) + else: + np.testing.assert_allclose( + c.numpy(), a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5 + ) + assert_jit_cache_len(add_array, 1) - def test_array_jit(self): - @TinyJit - def add_array(a, arr): return (a+arr[0]).realize() - for i in range(5): - a = Tensor.randn(10, 10) - b = Tensor.randn(10, 10) - a.realize(), b.realize() - c = add_array(a, [b]) - if i >= 2: - # should fail once jitted since jit can't handle arrays - np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5) - else: - np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5) - assert_jit_cache_len(add_array, 1) + def test_method_jit(self): + class Fun: + def __init__(self): + self.a = Tensor.randn(10, 10) - def test_method_jit(self): - class Fun: - def __init__(self): - self.a = Tensor.randn(10, 10) - @TinyJit - def __call__(self, b:Tensor) -> Tensor: - return (self.a+b).realize() - fun = Fun() - for _ in range(5): - b = Tensor.randn(10, 10) - c = fun(b) - np.testing.assert_allclose(c.numpy(), fun.a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5) - assert_jit_cache_len(fun.__call__.func.__self__, 1) + @TinyJit + def __call__(self, b: Tensor) -> Tensor: + return (self.a + b).realize() - def test_jit_size1_input(self): - @TinyJit - def f(a, b): return (a+b).realize() - a = Tensor([1, 2, 3]) - for i in range(5): - np.testing.assert_allclose(f(a, Tensor([i])).numpy(), (a+i).numpy(), atol=1e-4, rtol=1e-5) - assert_jit_cache_len(f, 1) + fun = Fun() + for _ in range(5): + b = Tensor.randn(10, 10) + c = fun(b) + np.testing.assert_allclose( + c.numpy(), fun.a.numpy() + b.numpy(), atol=1e-4, rtol=1e-5 + ) + assert_jit_cache_len(fun.__call__.func.__self__, 1) - def test_jit_output_non_tensor_fail(self): - @TinyJit - def f(a, b, i): return (a+b).realize(), i - output1, output2 = [], [] - expect1, expect2 = [], [] - for i in range(5): - a = Tensor.randn(10, 10) - b = Tensor.randn(10, 10) - o1, o2 = f(a, b, i) - output1.append(o1.numpy().copy()) - output2.append(o2) - expect1.append(a.numpy().copy()+b.numpy().copy()) - expect2.append(i) - np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5) - # the jit only works with Tensor outputs - assert output2 != expect2 - assert_jit_cache_len(f, 1) + def test_jit_size1_input(self): + @TinyJit + def f(a, b): + return (a + b).realize() - def test_jit_random_regen(self): - def f(a, b): - rn = Tensor.randn(*a.shape) - return ((a+b)*rn).realize() - a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed - b = Tensor.randn(10, 10).realize() + a = Tensor([1, 2, 3]) + for i in range(5): + np.testing.assert_allclose( + f(a, Tensor([i])).numpy(), (a + i).numpy(), atol=1e-4, rtol=1e-5 + ) + assert_jit_cache_len(f, 1) - Tensor._seed = 1234 - jf = TinyJit(f) - res = set() - for _ in range(5): - o1 = jf(a, b) - res.add(o1.numpy()[0][0]) - assert len(res) == 5, "All values should be different, rand works in jit." + def test_jit_output_non_tensor_fail(self): + @TinyJit + def f(a, b, i): + return (a + b).realize(), i - Tensor._seed = 1234 - jf2 = TinyJit(f) - res2 = set() - for _ in range(5): - o1 = jf2(a, b) - res2.add(o1.numpy()[0][0]) - assert len(res2) == 5, "All values should be different, rand works in jit." - assert res == res2, "Jit rand is not reproducible with the same seed" + output1, output2 = [], [] + expect1, expect2 = [], [] + for i in range(5): + a = Tensor.randn(10, 10) + b = Tensor.randn(10, 10) + o1, o2 = f(a, b, i) + output1.append(o1.numpy().copy()) + output2.append(o2) + expect1.append(a.numpy().copy() + b.numpy().copy()) + expect2.append(i) + np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5) + # the jit only works with Tensor outputs + assert output2 != expect2 + assert_jit_cache_len(f, 1) - Tensor._seed = 3421 - jf3 = TinyJit(f) - res3 = set() - for _ in range(5): - o1 = jf3(a, b) - res3.add(o1.numpy()[0][0]) - assert len(res3) == 5, "All values should be different, rand works in jit." - assert res3 != res2, "Jit rand is diff with diff seeds" + def test_jit_random_regen(self): + def f(a, b): + rn = Tensor.randn(*a.shape) + return ((a + b) * rn).realize() - def test_jit_realization_and_sampling(self): - w = Tensor.eye(5) + a = Tensor.randn( + 10, 10 + ).realize() # realize these before resetting the random seed + b = Tensor.randn(10, 10).realize() - @TinyJit - def foo (x): return w.dot(x).realize() + Tensor._seed = 1234 + jf = TinyJit(f) + res = set() + for _ in range(5): + o1 = jf(a, b) + res.add(o1.numpy()[0][0]) + assert len(res) == 5, "All values should be different, rand works in jit." - arg = [ - Tensor([1,2,3,4,5]), - Tensor([1,3,3,4,6]), - Tensor([1,2,5,4,7]), - Tensor([0,2,3,1,0]), - ] + Tensor._seed = 1234 + jf2 = TinyJit(f) + res2 = set() + for _ in range(5): + o1 = jf2(a, b) + res2.add(o1.numpy()[0][0]) + assert len(res2) == 5, "All values should be different, rand works in jit." + assert res == res2, "Jit rand is not reproducible with the same seed" - Y = [foo(e).numpy() for e in arg] + Tensor._seed = 3421 + jf3 = TinyJit(f) + res3 = set() + for _ in range(5): + o1 = jf3(a, b) + res3.add(o1.numpy()[0][0]) + assert len(res3) == 5, "All values should be different, rand works in jit." + assert res3 != res2, "Jit rand is diff with diff seeds" - foo(Tensor([7,7,7,7,7])) - want = [[1., 2., 3., 4., 5.], - [1., 3., 3., 4., 6.], - [1., 2., 5., 4., 7.], - [0., 2., 3., 1., 0.]] - np.testing.assert_allclose(want, Y) + def test_jit_realization_and_sampling(self): + w = Tensor.eye(5) - def test_jitted_read_assign(self): - class Cache: - def __init__(self): - self.good_cache = Tensor.zeros(1) - self.bad_cache = Tensor.zeros(1) - self.good_jitted = TinyJit(self.good) - self.bad_jitted = TinyJit(self.bad) + @TinyJit + def foo(x): + return w.dot(x).realize() - def good(self, y, cache_v=None): - if cache_v is not None: - self.good_cache.assign(cache_v+1-1).realize() - return (self.good_cache + y).realize() # need + y to provide inputs to JIT + arg = [ + Tensor([1, 2, 3, 4, 5]), + Tensor([1, 3, 3, 4, 6]), + Tensor([1, 2, 5, 4, 7]), + Tensor([0, 2, 3, 1, 0]), + ] - def bad(self, y, cache_v=None): - if cache_v is not None: - self.bad_cache.assign(cache_v).realize() - return (self.bad_cache + y).realize() + Y = [foo(e).numpy() for e in arg] - cache = Cache() - np.testing.assert_equal([0], cache.good_cache.numpy()) - np.testing.assert_equal([0], cache.bad_cache.numpy()) + foo(Tensor([7, 7, 7, 7, 7])) + want = [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [1.0, 3.0, 3.0, 4.0, 6.0], + [1.0, 2.0, 5.0, 4.0, 7.0], + [0.0, 2.0, 3.0, 1.0, 0.0], + ] + np.testing.assert_allclose(want, Y) - zero = Tensor([0]) - one = Tensor([1]) - two = Tensor([2]) + def test_jitted_read_assign(self): + class Cache: + def __init__(self): + self.good_cache = Tensor.zeros(1) + self.bad_cache = Tensor.zeros(1) + self.good_jitted = TinyJit(self.good) + self.bad_jitted = TinyJit(self.bad) - # save [1] in the caches - cache.good(zero, one) - cache.bad(zero, one) - np.testing.assert_equal([1], cache.good_cache.numpy()) - np.testing.assert_equal([1], cache.bad_cache.numpy()) + def good(self, y, cache_v=None): + if cache_v is not None: + self.good_cache.assign(cache_v + 1 - 1).realize() + return ( + self.good_cache + y + ).realize() # need + y to provide inputs to JIT - for i in range(5): - cache.good_jitted(zero) - cache.bad_jitted(zero) + def bad(self, y, cache_v=None): + if cache_v is not None: + self.bad_cache.assign(cache_v).realize() + return (self.bad_cache + y).realize() - # verify the jitted calls read 1 from the cache - np.testing.assert_equal([1], cache.good_jitted(zero).numpy()) - np.testing.assert_equal([1], cache.bad_jitted(zero).numpy()) + cache = Cache() + np.testing.assert_equal([0], cache.good_cache.numpy()) + np.testing.assert_equal([0], cache.bad_cache.numpy()) - # save [2] in the caches - cache.good(zero, two) - cache.bad(zero, two) - np.testing.assert_equal([2], cache.good_cache) - np.testing.assert_equal([2], cache.bad_cache) + zero = Tensor([0]) + one = Tensor([1]) + two = Tensor([2]) - # verify the jitted calls read 2 from the cache - np.testing.assert_equal([2], cache.good_jitted(zero).numpy()) - # but the bad_jitted doesn't! - np.testing.assert_equal([1], cache.bad_jitted(zero).numpy()) + # save [1] in the caches + cache.good(zero, one) + cache.bad(zero, one) + np.testing.assert_equal([1], cache.good_cache.numpy()) + np.testing.assert_equal([1], cache.bad_cache.numpy()) - assert_jit_cache_len(cache.good_jitted, 1) - assert_jit_cache_len(cache.bad_jitted, 1) + for i in range(5): + cache.good_jitted(zero) + cache.bad_jitted(zero) - def test_jit_buffer_behavior(self): - @TinyJit - def foo(x) -> Tensor: return x.sum().realize() + # verify the jitted calls read 1 from the cache + np.testing.assert_equal([1], cache.good_jitted(zero).numpy()) + np.testing.assert_equal([1], cache.bad_jitted(zero).numpy()) - result_1 = foo(Tensor([1] * 2)) - result_2 = foo(Tensor([2] * 2)) - result_3 = foo(Tensor([3] * 2)) + # save [2] in the caches + cache.good(zero, two) + cache.bad(zero, two) + np.testing.assert_equal([2], cache.good_cache) + np.testing.assert_equal([2], cache.bad_cache) - # expect the buffer to share underlying buffer - np.testing.assert_allclose(result_1.numpy(), [2], atol=1e-4, rtol=1e-5) - np.testing.assert_allclose(result_2.numpy(), [6], atol=1e-4, rtol=1e-5) - np.testing.assert_allclose(result_3.numpy(), [6], atol=1e-4, rtol=1e-5) + # verify the jitted calls read 2 from the cache + np.testing.assert_equal([2], cache.good_jitted(zero).numpy()) + # but the bad_jitted doesn't! + np.testing.assert_equal([1], cache.bad_jitted(zero).numpy()) -if __name__ == '__main__': - unittest.main() + assert_jit_cache_len(cache.good_jitted, 1) + assert_jit_cache_len(cache.bad_jitted, 1) + + def test_jit_buffer_behavior(self): + @TinyJit + def foo(x) -> Tensor: + return x.sum().realize() + + result_1 = foo(Tensor([1] * 2)) + result_2 = foo(Tensor([2] * 2)) + result_3 = foo(Tensor([3] * 2)) + + # expect the buffer to share underlying buffer + np.testing.assert_allclose(result_1.numpy(), [2], atol=1e-4, rtol=1e-5) + np.testing.assert_allclose(result_2.numpy(), [6], atol=1e-4, rtol=1e-5) + np.testing.assert_allclose(result_3.numpy(), [6], atol=1e-4, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_kernel_cache.py b/test/test_kernel_cache.py index 96a38cf8f..53e75d0d4 100644 --- a/test/test_kernel_cache.py +++ b/test/test_kernel_cache.py @@ -6,49 +6,54 @@ from tinygrad.tensor import Tensor from tinygrad import Device from tinygrad.helpers import diskcache + def generate_random_string(length=16): - alphabet = string.ascii_letters + string.digits - return ''.join(secrets.choice(alphabet) for _ in range(length)) + alphabet = string.ascii_letters + string.digits + return "".join(secrets.choice(alphabet) for _ in range(length)) + compile_call_count = 0 + @diskcache -def helper_test_compile(prg:str) -> bytes: - global compile_call_count - compile_call_count += 1 - return prg.encode() +def helper_test_compile(prg: str) -> bytes: + global compile_call_count + compile_call_count += 1 + return prg.encode() + class TestKernelCache(unittest.TestCase): - def test_compile_cache(self): - prg1 = generate_random_string(64) + "a" - prg2 = generate_random_string(64) + "b" - cold_compile_res = helper_test_compile(prg1) - warm_compile_res = helper_test_compile(prg1) - assert cold_compile_res == warm_compile_res == prg1.encode() - assert compile_call_count == 1 + def test_compile_cache(self): + prg1 = generate_random_string(64) + "a" + prg2 = generate_random_string(64) + "b" + cold_compile_res = helper_test_compile(prg1) + warm_compile_res = helper_test_compile(prg1) + assert cold_compile_res == warm_compile_res == prg1.encode() + assert compile_call_count == 1 - prg2_res = helper_test_compile(prg2) - assert prg2_res == prg2.encode() - assert compile_call_count == 2 + prg2_res = helper_test_compile(prg2) + assert prg2_res == prg2.encode() + assert compile_call_count == 2 - def test_kernel_cache_in_action(self): - if Device.DEFAULT not in ["CLANG"]: - self.skipTest("No custom kernel cache is implemented") + def test_kernel_cache_in_action(self): + if Device.DEFAULT not in ["CLANG"]: + self.skipTest("No custom kernel cache is implemented") - a = Tensor.rand(4,4) - b = Tensor.rand(4,4) - x = a + b - x.realize() + a = Tensor.rand(4, 4) + b = Tensor.rand(4, 4) + x = a + b + x.realize() - orig_compile_func = Device['CLANG'].compiler - Device['CLANG'].compiler = None # making it not callable + orig_compile_func = Device["CLANG"].compiler + Device["CLANG"].compiler = None # making it not callable - a1 = Tensor.rand(4,4) - b1 = Tensor.rand(4,4) - x1 = a1 + b1 - x1.realize() # Same kernel should be from cache. + a1 = Tensor.rand(4, 4) + b1 = Tensor.rand(4, 4) + x1 = a1 + b1 + x1.realize() # Same kernel should be from cache. + + Device["CLANG"].compiler = orig_compile_func - Device['CLANG'].compiler = orig_compile_func if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/test/test_lazybuffer.py b/test/test_lazybuffer.py index 1cfefb008..c8cfc91b0 100644 --- a/test/test_lazybuffer.py +++ b/test/test_lazybuffer.py @@ -6,68 +6,72 @@ from tinygrad import Device from tinygrad.tensor import Tensor from tinygrad.jit import CacheCollector + class TestLazyBuffer(unittest.TestCase): - @unittest.skip("it doesn't work like this anymore") - def test_fromcpu_buffer_sharing(self): - a = np.arange(8) - assert LazyBuffer.fromCPU(a).realized._buf is a + @unittest.skip("it doesn't work like this anymore") + def test_fromcpu_buffer_sharing(self): + a = np.arange(8) + assert LazyBuffer.fromCPU(a).realized._buf is a - def test_fromcpu_shape_tracker(self): - def helper(a: np.ndarray): - print(a.shape, a.strides, a.flags.c_contiguous) - b = LazyBuffer.fromCPU(a) - #assert b.st.contiguous == a.flags.c_contiguous - assert b.st.shape == a.shape - np.testing.assert_equal(a, Tensor(b).numpy()) + def test_fromcpu_shape_tracker(self): + def helper(a: np.ndarray): + print(a.shape, a.strides, a.flags.c_contiguous) + b = LazyBuffer.fromCPU(a) + # assert b.st.contiguous == a.flags.c_contiguous + assert b.st.shape == a.shape + np.testing.assert_equal(a, Tensor(b).numpy()) - for ndims in range(1, 4): - a = np.random.randn(*(4,)*ndims).astype(np.float32) - for stride in [-2, 1, 2]: - for start in [0, 1]: - helper(a[(slice(start, None, stride),)*ndims]) + for ndims in range(1, 4): + a = np.random.randn(*(4,) * ndims).astype(np.float32) + for stride in [-2, 1, 2]: + for start in [0, 1]: + helper(a[(slice(start, None, stride),) * ndims]) - def test_shuffle_pad_ops_cmpeq(self): - y = Tensor([1]).cat(Tensor([1]) == 0).numpy() - z = Tensor([1, 0]).numpy() - np.testing.assert_allclose(y, z) + def test_shuffle_pad_ops_cmpeq(self): + y = Tensor([1]).cat(Tensor([1]) == 0).numpy() + z = Tensor([1, 0]).numpy() + np.testing.assert_allclose(y, z) - def test_shuffle_pad_ops_div(self): - y = Tensor([1]).cat(Tensor([1]).div(Tensor([2.0]))).numpy() - z = Tensor([1, 0.5]).numpy() - np.testing.assert_allclose(y, z) + def test_shuffle_pad_ops_div(self): + y = Tensor([1]).cat(Tensor([1]).div(Tensor([2.0]))).numpy() + z = Tensor([1, 0.5]).numpy() + np.testing.assert_allclose(y, z) - def test_shuffle_pad_ops_log(self): - y = Tensor([1]).cat(Tensor([1]).log()).numpy() - z = Tensor([1, 0]).numpy() - np.testing.assert_allclose(y, z) + def test_shuffle_pad_ops_log(self): + y = Tensor([1]).cat(Tensor([1]).log()).numpy() + z = Tensor([1, 0]).numpy() + np.testing.assert_allclose(y, z) - def test_shuffle_pad_ops_exp(self): - y = Tensor([1]).cat(Tensor([1]).exp()).numpy() - z = Tensor([1, np.e]).numpy() - np.testing.assert_allclose(y, z) + def test_shuffle_pad_ops_exp(self): + y = Tensor([1]).cat(Tensor([1]).exp()).numpy() + z = Tensor([1, np.e]).numpy() + np.testing.assert_allclose(y, z) - @unittest.skipUnless(Device.DEFAULT in ["METAL", "CUDA", "GPU"], "Only GPU backends supports cache") - def test_children_count(self): - a = Tensor.ones(8,8,8) - d1 = a.sum((0)) - d2 = a.sum((0)).reshape(32,2) # noqa: F841 - assert len(d1.lazydata.op.src[0].children) == 1 - in1 = d1.reshape(16,4) - d3 = in1.reshape(8,8) - assert len(d3.lazydata.op.src[0].children) == 2 + @unittest.skipUnless( + Device.DEFAULT in ["METAL", "CUDA", "GPU"], "Only GPU backends supports cache" + ) + def test_children_count(self): + a = Tensor.ones(8, 8, 8) + d1 = a.sum((0)) + d2 = a.sum((0)).reshape(32, 2) # noqa: F841 + assert len(d1.lazydata.op.src[0].children) == 1 + in1 = d1.reshape(16, 4) + d3 = in1.reshape(8, 8) + assert len(d3.lazydata.op.src[0].children) == 2 + + CacheCollector.start() + l = Tensor.ones(8, 8) + r = Tensor.ones(8, 8) + dd = d1 + l + dd.realize() + de = d3 + r + de.realize() + cache = CacheCollector.finish() + assert len(cache) == 3 + assert cache[0].prg.name.startswith("r_") # Reduce should not merged 2 times. + assert cache[1].prg.name.startswith("E_") + assert cache[2].prg.name.startswith("E_") - CacheCollector.start() - l = Tensor.ones(8,8) - r = Tensor.ones(8,8) - dd = d1 + l - dd.realize() - de = d3 + r - de.realize() - cache = CacheCollector.finish() - assert len(cache) == 3 - assert cache[0].prg.name.startswith("r_") # Reduce should not merged 2 times. - assert cache[1].prg.name.startswith("E_") - assert cache[2].prg.name.startswith("E_") if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/test/test_lazyop.py b/test/test_lazyop.py index 6ab1e96ec..6e29149c3 100644 --- a/test/test_lazyop.py +++ b/test/test_lazyop.py @@ -3,7 +3,16 @@ from tinygrad.tensor import Tensor # stuff needed to unpack a kernel # ruff: noqa: F401 -from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer +from tinygrad.ops import ( + LazyOp, + TernaryOps, + BinaryOps, + UnaryOps, + ReduceOps, + BufferOps, + MemBuffer, + ConstBuffer, +) from tinygrad.lazy import LazyBuffer from tinygrad.helpers import dtypes from tinygrad.shape.shapetracker import ShapeTracker @@ -11,23 +20,27 @@ from tinygrad.shape.view import View from tinygrad.shape.symbolic import Variable import numpy as np import time -inf, nan = float('inf'), float('nan') + +inf, nan = float("inf"), float("nan") + class TestLazyOp(unittest.TestCase): - def test_lazyop_str(self): - t = Tensor.rand(10) + Tensor.rand(10) - s = t.lazydata.schedule() - ast = s[-1].ast - ast_remade = eval(str(ast)) - self.assertEqual(ast, ast_remade) + def test_lazyop_str(self): + t = Tensor.rand(10) + Tensor.rand(10) + s = t.lazydata.schedule() + ast = s[-1].ast + ast_remade = eval(str(ast)) + self.assertEqual(ast, ast_remade) - def test_selfreferential_speed(self): - st = time.monotonic() - for i in range(25): - p = LazyBuffer.fromCPU(np.array([1])) - for _ in range(i): p = p.e(BinaryOps.ADD, p) - # sanity check if caching works this should be way faster - assert time.monotonic() -st < 0.5, f"{i}" + def test_selfreferential_speed(self): + st = time.monotonic() + for i in range(25): + p = LazyBuffer.fromCPU(np.array([1])) + for _ in range(i): + p = p.e(BinaryOps.ADD, p) + # sanity check if caching works this should be way faster + assert time.monotonic() - st < 0.5, f"{i}" -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index dc0fe49a1..33b59e69a 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -12,538 +12,844 @@ from tinygrad.jit import CacheCollector from tinygrad.realize import run_schedule from tinygrad.helpers import dtypes, prod + class TestLinearizer(unittest.TestCase): - def test_arg_dedup(self): - if not isinstance(Device[Device.DEFAULT], Compiled): - self.skipTest("Only Compiled supports cache") - a, b = Tensor.randn(4), Tensor.randn(4) - np_a, np_b = a.numpy(), b.numpy() - CacheCollector.start() - c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))).realize() - rawbufs = CacheCollector.finish()[0].rawbufs - assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.realized, b.lazydata.realized} - np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:]) - np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4) + def test_arg_dedup(self): + if not isinstance(Device[Device.DEFAULT], Compiled): + self.skipTest("Only Compiled supports cache") + a, b = Tensor.randn(4), Tensor.randn(4) + np_a, np_b = a.numpy(), b.numpy() + CacheCollector.start() + c = ( + (a.shrink(((0, 2),)) - a.shrink(((2, 4),))) + - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))) + ).realize() + rawbufs = CacheCollector.finish()[0].rawbufs + assert len(rawbufs) == 3 and set(rawbufs[1:]) == { + a.lazydata.realized, + b.lazydata.realized, + } + np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:]) + np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4) - def test_load_dedup(self): - # for different leaves in the AST, the same loads may occur. + def test_load_dedup(self): + # for different leaves in the AST, the same loads may occur. - if not isinstance(Device[Device.DEFAULT], Compiled): - self.skipTest("Only Compiled uses linearizer") + if not isinstance(Device[Device.DEFAULT], Compiled): + self.skipTest("Only Compiled uses linearizer") - a = Tensor.randn(4).realize() - # these are of size 3 to avoid float4 coalesce - r = a[:-1] + a[1:] + a = Tensor.randn(4).realize() + # these are of size 3 to avoid float4 coalesce + r = a[:-1] + a[1:] - k = Linearizer(r.lazydata.schedule()[-1].ast) - k.upcast() - k.linearize() - num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD]) - assert num_loads <= 4, "more load uops than needed" - assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?" + k = Linearizer(r.lazydata.schedule()[-1].ast) + k.upcast() + k.linearize() + num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD]) + assert num_loads <= 4, "more load uops than needed" + assert ( + num_loads >= 4 + ), "unexpected number of uops, maybe this test needs updating?" - def test_upcast_cse(self): - # when upcasting, within a subtree, there may be common expressions. + def test_upcast_cse(self): + # when upcasting, within a subtree, there may be common expressions. - if not isinstance(Device[Device.DEFAULT], Compiled): - self.skipTest("Only Compiled uses linearizer") + if not isinstance(Device[Device.DEFAULT], Compiled): + self.skipTest("Only Compiled uses linearizer") - a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize() - r = a.expand([2]) + b.expand([2]) + a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize() + r = a.expand([2]) + b.expand([2]) - k = Linearizer(r.lazydata.schedule()[-1].ast) - k.upcast() - k.linearize() - num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU]) - assert num_ops <= 1, "more alu uops than needed" + k = Linearizer(r.lazydata.schedule()[-1].ast) + k.upcast() + k.linearize() + num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU]) + assert num_ops <= 1, "more alu uops than needed" - def test_zero_fold(self): - if not isinstance(Device[Device.DEFAULT], Compiled): - self.skipTest("Only Compiled uses linearizer") + def test_zero_fold(self): + if not isinstance(Device[Device.DEFAULT], Compiled): + self.skipTest("Only Compiled uses linearizer") - a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize() - r = Tensor.stack([a, b]) + a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize() + r = Tensor.stack([a, b]) - k = Linearizer(r.lazydata.schedule()[-1].ast) - k.upcast() - k.linearize() - num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU]) - assert num_ops == 0, "more alu uops than needed" + k = Linearizer(r.lazydata.schedule()[-1].ast) + k.upcast() + k.linearize() + num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU]) + assert num_ops == 0, "more alu uops than needed" - @unittest.skip("constant folding not supported yet") - def test_constant_fold(self): - if not isinstance(Device[Device.DEFAULT], Compiled): - self.skipTest("Only Compiled uses linearizer") + @unittest.skip("constant folding not supported yet") + def test_constant_fold(self): + if not isinstance(Device[Device.DEFAULT], Compiled): + self.skipTest("Only Compiled uses linearizer") - a, b = Tensor(2), Tensor(3) - r = a * b + a, b = Tensor(2), Tensor(3) + r = a * b - k = Linearizer(r.lazydata.schedule()[-1][0]) - k.linearize() - num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]]) - assert num_ops <= 0, "more load or alu uops than needed" + k = Linearizer(r.lazydata.schedule()[-1][0]) + k.linearize() + num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]]) + assert num_ops <= 0, "more load or alu uops than needed" - def test_tensor_cores(self): - if not isinstance(Device[Device.DEFAULT], Compiled): - self.skipTest("Only Compiled uses linearizer") - if Device.DEFAULT not in tensor_cores: - self.skipTest("No tensor cores for device") + def test_tensor_cores(self): + if not isinstance(Device[Device.DEFAULT], Compiled): + self.skipTest("Only Compiled uses linearizer") + if Device.DEFAULT not in tensor_cores: + self.skipTest("No tensor cores for device") - for tc in tensor_cores[Device.DEFAULT]: - if tc.arch is not None and tc.arch != os.uname().machine: continue - a, b = Tensor.rand(tc.dims[0], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[1], dtype=tc.dtype_in) - np_a, np_b = a.numpy(), b.numpy() - if tc.dtype_out != tc.dtype_in: - r = (a.reshape(tc.dims[0], 1, tc.dims[2]) * b.permute(1,0).reshape(1, tc.dims[1], tc.dims[2])).cast(tc.dtype_out).sum(axis=2) - else: - r = a @ b - realized_ast, _ = helper_realized_ast(r) - k = Linearizer(realized_ast) - k.apply_tensor_cores(1) - k.linearize() - assert len([uop for uop in k.uops if uop.uop == UOps.WMMA]) == 1, "tensor core not triggered" - np_c = np_a @ np_b - np.testing.assert_allclose(np_c, r.numpy(), atol=5e-3, rtol=1e-4) + for tc in tensor_cores[Device.DEFAULT]: + if tc.arch is not None and tc.arch != os.uname().machine: + continue + a, b = Tensor.rand(tc.dims[0], tc.dims[2], dtype=tc.dtype_in), Tensor.rand( + tc.dims[2], tc.dims[1], dtype=tc.dtype_in + ) + np_a, np_b = a.numpy(), b.numpy() + if tc.dtype_out != tc.dtype_in: + r = ( + ( + a.reshape(tc.dims[0], 1, tc.dims[2]) + * b.permute(1, 0).reshape(1, tc.dims[1], tc.dims[2]) + ) + .cast(tc.dtype_out) + .sum(axis=2) + ) + else: + r = a @ b + realized_ast, _ = helper_realized_ast(r) + k = Linearizer(realized_ast) + k.apply_tensor_cores(1) + k.linearize() + assert ( + len([uop for uop in k.uops if uop.uop == UOps.WMMA]) == 1 + ), "tensor core not triggered" + np_c = np_a @ np_b + np.testing.assert_allclose(np_c, r.numpy(), atol=5e-3, rtol=1e-4) - def test_limit_dims_to_max_5d_global(self): - t = Tensor.rand(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1 - sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps] - assert len(sched) == 1 - lin = Linearizer(sched[0].ast) - assert lin.full_shape[:lin.global_dims] == (5, 6, 7, 8, 9) - lin.limit_dims_to_max(global_max=[16, 16, 16], local_max=[16, 16, 16]) + def test_limit_dims_to_max_5d_global(self): + t = Tensor.rand(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1 + sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps] + assert len(sched) == 1 + lin = Linearizer(sched[0].ast) + assert lin.full_shape[: lin.global_dims] == (5, 6, 7, 8, 9) + lin.limit_dims_to_max(global_max=[16, 16, 16], local_max=[16, 16, 16]) - def test_sum_collapse(self): - t = Tensor.ones(256,256).sum() - sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps] - assert len(sched) == 1 - lin = Linearizer(sched[0].ast) - assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse" + def test_sum_collapse(self): + t = Tensor.ones(256, 256).sum() + sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps] + assert len(sched) == 1 + lin = Linearizer(sched[0].ast) + assert not any( + u.uop == UOps.LOOP for u in lin.linearize().uops + ), "found loop in sum collapse" - def test_simplify_uop(self): - def helper_test_simplify(uop, dtype, vin, arg=None): - ast = LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=42, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)))) - ast = LazyOp(BufferOps.STORE, (ast,), MemBuffer(0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)))) - lin = Linearizer(ast=ast) # this is a dummy ast + def test_simplify_uop(self): + def helper_test_simplify(uop, dtype, vin, arg=None): + ast = LazyOp( + op=BufferOps.CONST, + src=(), + arg=ConstBuffer( + val=42, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(), + strides=(), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ) + ast = LazyOp( + BufferOps.STORE, + (ast,), + MemBuffer( + 0, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(), + strides=(), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ) + lin = Linearizer(ast=ast) # this is a dummy ast - lin.uops = [] - return lin.uop(uop, dtype, vin, arg, cachable=False) + lin.uops = [] + return lin.uop(uop, dtype, vin, arg, cachable=False) - c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0) - assert helper_test_simplify(UOps.ALU, dtypes.bool, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c0), arg=TernaryOps.WHERE) == c0 + c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0) + assert ( + helper_test_simplify( + UOps.ALU, + dtypes.bool, + vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c0), + arg=TernaryOps.WHERE, + ) + == c0 + ) - c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0) - c1 = UOp(UOps.CONST, dtypes.float, vin=(), arg=1.0) - assert helper_test_simplify(UOps.ALU, dtypes.bool, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1), arg=TernaryOps.WHERE).uop == UOps.ALU + c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0) + c1 = UOp(UOps.CONST, dtypes.float, vin=(), arg=1.0) + assert ( + helper_test_simplify( + UOps.ALU, + dtypes.bool, + vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1), + arg=TernaryOps.WHERE, + ).uop + == UOps.ALU + ) + + +def helper_realized_ast(r: Tensor): + s = r.lazydata.schedule() + run_schedule(s[:-1]) # run all kernels except the last one + # now all input LazyBuffers buffers in s[-1] should be realized + output_buffer = Buffer( + s[-1].out.device, + prod((s if isinstance(s, int) else s.max for s in s[-1].out.shape)), + s[-1].out.dtype, + **s[-1].out._device_extra_args() + ) # allocate an output buffer + return s[-1].ast, [output_buffer] + [l.realized for l in s[-1].inputs] -def helper_realized_ast(r:Tensor): - s = r.lazydata.schedule() - run_schedule(s[:-1]) # run all kernels except the last one - # now all input LazyBuffers buffers in s[-1] should be realized - output_buffer = Buffer(s[-1].out.device, prod((s if isinstance(s, int) else s.max for s in s[-1].out.shape)), s[-1].out.dtype, **s[-1].out._device_extra_args()) # allocate an output buffer - return s[-1].ast, [output_buffer] + [l.realized for l in s[-1].inputs] class TestFloat4(unittest.TestCase): - def setUp(self): - if not isinstance(Device[Device.DEFAULT], Compiled) or not Device[Device.DEFAULT].linearizer_opts.supports_float4: - self.skipTest("Device does not support float4") + def setUp(self): + if ( + not isinstance(Device[Device.DEFAULT], Compiled) + or not Device[Device.DEFAULT].linearizer_opts.supports_float4 + ): + self.skipTest("Device does not support float4") - @staticmethod - def count_float4(k): - return (len([uop for uop in k.uops if uop.uop == UOps.LOAD and uop.dtype == dtypes.float.vec(4)]), - len([uop for uop in k.uops if uop.uop == UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes.float.vec(4)])) + @staticmethod + def count_float4(k): + return ( + len( + [ + uop + for uop in k.uops + if uop.uop == UOps.LOAD and uop.dtype == dtypes.float.vec(4) + ] + ), + len( + [ + uop + for uop in k.uops + if uop.uop == UOps.STORE + and len(uop.vin) == 3 + and uop.vin[2].dtype == dtypes.float.vec(4) + ] + ), + ) - # TODO: express opts below as auto opts + # TODO: express opts below as auto opts - def test_float4_basic(self): - a = Tensor.rand(2, 8).realize() - b = Tensor.rand(2, 8).realize() - c = a + b + def test_float4_basic(self): + a = Tensor.rand(2, 8).realize() + b = Tensor.rand(2, 8).realize() + c = a + b - s = c.lazydata.schedule()[0] - k = Linearizer(s.ast) - k.hand_coded_optimizations() - k.linearize() + s = c.lazydata.schedule()[0] + k = Linearizer(s.ast) + k.hand_coded_optimizations() + k.linearize() - assert TestFloat4.count_float4(k) == (2, 1) + assert TestFloat4.count_float4(k) == (2, 1) - def test_float4_multidim(self): - a = Tensor.rand(2, 8).realize() - b = Tensor.rand(2, 8).realize() - c = a + b + def test_float4_multidim(self): + a = Tensor.rand(2, 8).realize() + b = Tensor.rand(2, 8).realize() + c = a + b - s = c.lazydata.schedule()[0] - k = Linearizer(s.ast) - k.shift_to(0, 4) # float4 dimension - k.shift_to(0, 2, insert_before=k.shape_len-1) - k.upcast() - k.upcast() - k.local_dims += 1 - k.linearize() + s = c.lazydata.schedule()[0] + k = Linearizer(s.ast) + k.shift_to(0, 4) # float4 dimension + k.shift_to(0, 2, insert_before=k.shape_len - 1) + k.upcast() + k.upcast() + k.local_dims += 1 + k.linearize() - assert TestFloat4.count_float4(k) == (4, 2) + assert TestFloat4.count_float4(k) == (4, 2) - def test_float4_unaligned_load(self): - a = Tensor.rand(9).realize().shrink(((1, 9),)) - b = Tensor.rand(9).realize().shrink(((1, 9),)) - c = a + b + def test_float4_unaligned_load(self): + a = Tensor.rand(9).realize().shrink(((1, 9),)) + b = Tensor.rand(9).realize().shrink(((1, 9),)) + c = a + b - s = c.lazydata.schedule()[0] - k = Linearizer(s.ast) - k.hand_coded_optimizations() # implicit trigger float4 dim - k.linearize() + s = c.lazydata.schedule()[0] + k = Linearizer(s.ast) + k.hand_coded_optimizations() # implicit trigger float4 dim + k.linearize() - assert TestFloat4.count_float4(k) == (0, 1) + assert TestFloat4.count_float4(k) == (0, 1) - def test_float4_multidim_unaligned_load(self): - a = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),)) - b = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),)) - c = a + b + def test_float4_multidim_unaligned_load(self): + a = ( + Tensor.rand(2, 9) + .realize() + .shrink( + ( + (0, 2), + (1, 9), + ) + ) + ) + b = ( + Tensor.rand(2, 9) + .realize() + .shrink( + ( + (0, 2), + (1, 9), + ) + ) + ) + c = a + b - s = c.lazydata.schedule()[0] - k = Linearizer(s.ast) - k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim - k.upcast() - k.shift_to(len(k.full_unupcasted_shape)-1, 2, insert_before=k.shape_len-1) - k.upcast() - k.local_dims += 1 - k.linearize() + s = c.lazydata.schedule()[0] + k = Linearizer(s.ast) + k.shift_to(len(k.full_unupcasted_shape) - 1, 4) # manual trigger float4 dim + k.upcast() + k.shift_to(len(k.full_unupcasted_shape) - 1, 2, insert_before=k.shape_len - 1) + k.upcast() + k.local_dims += 1 + k.linearize() - assert TestFloat4.count_float4(k) == (0, 2) + assert TestFloat4.count_float4(k) == (0, 2) - def test_float4_sometimes_unaligned(self): - a = Tensor.rand(1, 1, 8).realize() - b = Tensor.rand(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5))) - c = a.conv2d(b) - # only the first and last conv dot products are aligned in a, and b is never aligned, so no - # float4 should be emitted (the reduce axis of size 4 is the float4 axis here) + def test_float4_sometimes_unaligned(self): + a = Tensor.rand(1, 1, 8).realize() + b = Tensor.rand(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5))) + c = a.conv2d(b) + # only the first and last conv dot products are aligned in a, and b is never aligned, so no + # float4 should be emitted (the reduce axis of size 4 is the float4 axis here) - s = c.lazydata.schedule()[0] - k = Linearizer(s.ast) - k.upcast() - k.linearize() + s = c.lazydata.schedule()[0] + k = Linearizer(s.ast) + k.upcast() + k.linearize() - assert TestFloat4.count_float4(k) == (0, 0) + assert TestFloat4.count_float4(k) == (0, 0) - def test_float4_multidim_sometimes_unaligned(self): - a = Tensor.rand(1, 1, 7).realize() - b = Tensor.rand(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5))) - c = a.conv2d(b) - # the first conv dot product is aligned in a. If we upcast the output and reduce - # dimension, then we could do float4 for only that one set of loads, but we currently - # don't. + def test_float4_multidim_sometimes_unaligned(self): + a = Tensor.rand(1, 1, 7).realize() + b = Tensor.rand(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5))) + c = a.conv2d(b) + # the first conv dot product is aligned in a. If we upcast the output and reduce + # dimension, then we could do float4 for only that one set of loads, but we currently + # don't. - s = c.lazydata.schedule()[0] - k = Linearizer(s.ast) - k.upcast() - k.upcast() - k.linearize() + s = c.lazydata.schedule()[0] + k = Linearizer(s.ast) + k.upcast() + k.upcast() + k.linearize() - assert TestFloat4.count_float4(k) == (0, 1) + assert TestFloat4.count_float4(k) == (0, 1) - def test_float4_noncontiguous(self): - a = Tensor.rand(4, 2).realize() - b = Tensor.rand(4, 2).realize() - c = a + b + def test_float4_noncontiguous(self): + a = Tensor.rand(4, 2).realize() + b = Tensor.rand(4, 2).realize() + c = a + b - # we will upcast the top axis of sz 4. they should not be coalesced into float4, - # since the top axis is not contiguous. + # we will upcast the top axis of sz 4. they should not be coalesced into float4, + # since the top axis is not contiguous. - s = c.lazydata.schedule()[0] - k = Linearizer(s.ast) - k.shift_to(0, 4, top=True) # top axes are float4 axes - k.upcast() - k.linearize() + s = c.lazydata.schedule()[0] + k = Linearizer(s.ast) + k.shift_to(0, 4, top=True) # top axes are float4 axes + k.upcast() + k.linearize() - assert TestFloat4.count_float4(k) == (0, 0) + assert TestFloat4.count_float4(k) == (0, 0) - def test_float4_expand(self): - a = Tensor.rand(9).realize().shrink(((1, 9),)) - b = Tensor.rand(2).realize().reshape((2, 1)).expand((2,4)).reshape((8,)) - c = a + b + def test_float4_expand(self): + a = Tensor.rand(9).realize().shrink(((1, 9),)) + b = Tensor.rand(2).realize().reshape((2, 1)).expand((2, 4)).reshape((8,)) + c = a + b - # we will upcast the top axis of sz 4. they should not be coalesced into float4, - # since the top axis is not contiguous. + # we will upcast the top axis of sz 4. they should not be coalesced into float4, + # since the top axis is not contiguous. - s = c.lazydata.schedule()[0] - k = Linearizer(s.ast) - k.shift_to(0, 4) # float4 axis - k.upcast() - k.linearize() + s = c.lazydata.schedule()[0] + k = Linearizer(s.ast) + k.shift_to(0, 4) # float4 axis + k.upcast() + k.linearize() - assert TestFloat4.count_float4(k) == (0, 1) + assert TestFloat4.count_float4(k) == (0, 1) - def test_float4_heterogeneous(self): - a = Tensor.rand(8).realize() - b = Tensor.rand(9).realize().shrink(((1, 9),)) - c = a + b + def test_float4_heterogeneous(self): + a = Tensor.rand(8).realize() + b = Tensor.rand(9).realize().shrink(((1, 9),)) + c = a + b - # should float4 b but not a + # should float4 b but not a - s = c.lazydata.schedule()[0] - k = Linearizer(s.ast) - k.shift_to(0, 4) # float4 axis - k.upcast() - k.linearize() + s = c.lazydata.schedule()[0] + k = Linearizer(s.ast) + k.shift_to(0, 4) # float4 axis + k.upcast() + k.linearize() + + assert TestFloat4.count_float4(k) == (1, 1) - assert TestFloat4.count_float4(k) == (1, 1) class TestHandCodedOpts(unittest.TestCase): - def setUp(self): - if not isinstance(Device[Device.DEFAULT], Compiled): - self.skipTest("Device does not use linearizer") + def setUp(self): + if not isinstance(Device[Device.DEFAULT], Compiled): + self.skipTest("Device does not use linearizer") - def test_masked_upcast(self): - layer_1 = Tensor.cat(*[Tensor.rand(5) for _ in range(4)]) - layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20)) + def test_masked_upcast(self): + layer_1 = Tensor.cat(*[Tensor.rand(5) for _ in range(4)]) + layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20)) - s = layer_2.lazydata.schedule()[-1] - k = Linearizer(s.ast) + s = layer_2.lazydata.schedule()[-1] + k = Linearizer(s.ast) + k.hand_coded_optimizations() + assert len(k.bufs) == 6 # make sure all ops are done in one kernel + # masked upcast should upcast masked axis of size 7 + # masked upcast should not upcast large (20) last axis + # float4/other hcopt shouldn't upcast last axis, since we already have 7 upcast, and the last axis is not very contiguous + assert k.upcasted == 1 and k.full_shape[-1] == 7 + + def test_masked_upcast_wino(self): + monster = Tensor.stack( + [Tensor.stack([Tensor.rand(16) for _ in range(6)]) for _ in range(6)] + ) + + s = monster.lazydata.schedule()[-1] + k = Linearizer(s.ast) + k.hand_coded_optimizations() + assert len(k.bufs) == 37 # make sure all ops are done in one kernel + # should upcast the two Tensor.stacks + assert ( + k.upcasted >= 2 + and k.full_shape[k.shape_len - k.upcasted : k.shape_len].count(6) == 2 + ) + + def test_masked_upcast_wino_full(self): + old_wino = Tensor.wino + Tensor.wino = True + x, w = ( + Tensor.rand(1, 4, 9, 9, requires_grad=True).realize(), + Tensor.rand(4, 4, 3, 3, requires_grad=True).realize(), + ) + out = Tensor.conv2d(x, w, padding=1) + upcasts = [] + # collect upcasts of tile transform kernels + for i, si in enumerate(out.lazydata.schedule()): + k = Linearizer(si.ast) + k.hand_coded_optimizations() + if k.reduceop is not None: + continue # not a tile transform kernel (there is a gemm reduce kernel) + if len(k.bufs) < 100: + continue # not a tile transform kernel (there's a permute kernel at the end) + upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted : k.shape_len])) + assert len(upcasts) == 3 # 3 transformation matrices + assert upcasts.count((6, 6)) == 2 and upcasts.count((4, 4)) == 1 + + out.mean().backward() + for si in x.grad.lazydata.schedule() + w.grad.lazydata.schedule(): + k = Linearizer(si.ast) + k.hand_coded_optimizations() + k.linearize() + if len(k.bufs) < 20: + continue # not a tile transform kernel + # heuristic number to make sure that at least some upcasts but not too many upcasts are being done + assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted : k.shape_len]) <= 49 + + Tensor.wino = old_wino + + def test_masked_upcast_many(self): + layer_1 = Tensor.cat(Tensor.rand(3, 4), Tensor.rand(4, 4)) + layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 7, 4)) + layer_3 = Tensor.cat(layer_2.unsqueeze(0), Tensor.rand(6, 7, 7, 4)) + + s = layer_3.lazydata.schedule()[-1] + k = Linearizer(s.ast) + k.hand_coded_optimizations() + assert len(k.bufs) == 5 # make sure all ops are done in one kernel + # check that we don't do too many upcasts + assert prod(k.full_shape[k.shape_len - k.upcasted : k.shape_len]) <= 49 + + +def helper_linearizer_opt(r: Tensor, opts=[], apply_tc=False): + wanna_output = None + realized_ast, real_bufs = helper_realized_ast(r) + + def check_opt(opts, create_k, to_prg): + k = create_k() + if apply_tc: + k.apply_tensor_cores(1, opts) + else: + for opt in opts: + k.apply_opt(opt) + prg = to_prg(k) + real_bufs[0].copyin( + np.zeros((real_bufs[0].size,), dtype=real_bufs[0].dtype.np).data + ) # Zero to check that all values are filled + prg.exec(real_bufs) + np.testing.assert_allclose( + wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4 + ) + + # Get baseline, which is not optimized at all. + k = Linearizer(realized_ast) + prg = Device[Device.DEFAULT].to_program(k) + prg.exec(real_bufs) + wanna_output = real_bufs[0].toCPU().copy() + + # Check correctness of handcoded optimiztions. + k = Linearizer(realized_ast) k.hand_coded_optimizations() - assert len(k.bufs) == 6 # make sure all ops are done in one kernel - # masked upcast should upcast masked axis of size 7 - # masked upcast should not upcast large (20) last axis - # float4/other hcopt shouldn't upcast last axis, since we already have 7 upcast, and the last axis is not very contiguous - assert k.upcasted == 1 and k.full_shape[-1] == 7 - - def test_masked_upcast_wino(self): - monster = Tensor.stack([Tensor.stack([Tensor.rand(16) for _ in range(6)]) for _ in range(6)]) - - s = monster.lazydata.schedule()[-1] - k = Linearizer(s.ast) - k.hand_coded_optimizations() - assert len(k.bufs) == 37 # make sure all ops are done in one kernel - # should upcast the two Tensor.stacks - assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2 - - def test_masked_upcast_wino_full(self): - old_wino = Tensor.wino - Tensor.wino = True - x,w = Tensor.rand(1,4,9,9, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize() - out = Tensor.conv2d(x,w, padding=1) - upcasts = [] - # collect upcasts of tile transform kernels - for i, si in enumerate(out.lazydata.schedule()): - k = Linearizer(si.ast) - k.hand_coded_optimizations() - if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel) - if len(k.bufs) < 100: continue # not a tile transform kernel (there's a permute kernel at the end) - upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len])) - assert len(upcasts) == 3 # 3 transformation matrices - assert upcasts.count((6, 6)) == 2 and upcasts.count((4, 4)) == 1 - - out.mean().backward() - for si in x.grad.lazydata.schedule() + w.grad.lazydata.schedule(): - k = Linearizer(si.ast) - k.hand_coded_optimizations() - k.linearize() - if len(k.bufs) < 20: continue # not a tile transform kernel - # heuristic number to make sure that at least some upcasts but not too many upcasts are being done - assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted:k.shape_len]) <= 49 - - Tensor.wino = old_wino - - def test_masked_upcast_many(self): - layer_1 = Tensor.cat(Tensor.rand(3, 4), Tensor.rand(4, 4)) - layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 7, 4)) - layer_3 = Tensor.cat(layer_2.unsqueeze(0), Tensor.rand(6, 7, 7, 4)) - - s = layer_3.lazydata.schedule()[-1] - k = Linearizer(s.ast) - k.hand_coded_optimizations() - assert len(k.bufs) == 5 # make sure all ops are done in one kernel - # check that we don't do too many upcasts - assert prod(k.full_shape[k.shape_len-k.upcasted:k.shape_len]) <= 49 - -def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False): - wanna_output = None - realized_ast, real_bufs = helper_realized_ast(r) - - def check_opt(opts, create_k, to_prg): - k = create_k() - if apply_tc: - k.apply_tensor_cores(1, opts) - else: - for opt in opts: - k.apply_opt(opt) - prg = to_prg(k) - real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled + prg = Device[Device.DEFAULT].to_program(k) + real_bufs[0].copyin( + np.zeros((real_bufs[0].size,), dtype=real_bufs[0].dtype.np).data + ) # Zero to check that all values are filled prg.exec(real_bufs) np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4) + for x in opts: # Check custom transformations if any. + check_opt( + x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_program + ) - # Get baseline, which is not optimized at all. - k = Linearizer(realized_ast) - prg = Device[Device.DEFAULT].to_program(k) - prg.exec(real_bufs) - wanna_output = real_bufs[0].toCPU().copy() - - # Check correctness of handcoded optimiztions. - k = Linearizer(realized_ast) - k.hand_coded_optimizations() - prg = Device[Device.DEFAULT].to_program(k) - real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled - prg.exec(real_bufs) - np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4) - for x in opts: # Check custom transformations if any. - check_opt(x, lambda: Linearizer(realized_ast), Device[Device.DEFAULT].to_program) class TestLinearizerOpts(unittest.TestCase): - def test_local_and_grouped_reduce(self): - if not isinstance(Device[Device.DEFAULT], Compiled) or not Device[Device.DEFAULT].linearizer_opts.has_local or not Device[Device.DEFAULT].linearizer_opts.has_shared: - self.skipTest("Only Compiled uses linearizer with locals and shared") + def test_local_and_grouped_reduce(self): + if ( + not isinstance(Device[Device.DEFAULT], Compiled) + or not Device[Device.DEFAULT].linearizer_opts.has_local + or not Device[Device.DEFAULT].linearizer_opts.has_shared + ): + self.skipTest("Only Compiled uses linearizer with locals and shared") - N = 128 - Tensor.manual_seed(1882) - a = Tensor.rand(4, 4, N, N) - b = Tensor.rand(4, 4, N) - r = (b.sqrt() + ((a+1).sum(axis=3).exp())) - helper_linearizer_opt(r, [ - [Opt(OptOps.LOCAL, 0, 2)], - [Opt(OptOps.LOCAL, 0, 8)], - [Opt(OptOps.LOCAL, 0, 16)], # Checking how it works with locals - [Opt(OptOps.GROUPTOP, 0, 2)], - [Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.GROUPTOP, 0, 64)], # Checking how it works with grouped reduce - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2)], - [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.GROUPTOP, 0, 16)], - [Opt(OptOps.LOCAL, 0, 32), Opt(OptOps.GROUPTOP, 0, 2)], - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 64)], # Checking how it works with locals + grouped reduce - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)], # Checking how it works with locals + grouped reduce + upcasts - ]) + N = 128 + Tensor.manual_seed(1882) + a = Tensor.rand(4, 4, N, N) + b = Tensor.rand(4, 4, N) + r = b.sqrt() + ((a + 1).sum(axis=3).exp()) + helper_linearizer_opt( + r, + [ + [Opt(OptOps.LOCAL, 0, 2)], + [Opt(OptOps.LOCAL, 0, 8)], + [Opt(OptOps.LOCAL, 0, 16)], # Checking how it works with locals + [Opt(OptOps.GROUPTOP, 0, 2)], + [Opt(OptOps.GROUPTOP, 0, 32)], + [ + Opt(OptOps.GROUPTOP, 0, 64) + ], # Checking how it works with grouped reduce + [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2)], + [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.GROUPTOP, 0, 16)], + [Opt(OptOps.LOCAL, 0, 32), Opt(OptOps.GROUPTOP, 0, 2)], + [ + Opt(OptOps.LOCAL, 0, 2), + Opt(OptOps.GROUPTOP, 0, 64), + ], # Checking how it works with locals + grouped reduce + [ + Opt(OptOps.LOCAL, 0, 2), + Opt(OptOps.GROUPTOP, 0, 2), + Opt(OptOps.UPCAST, 0, 8), + Opt(OptOps.UNROLL, 1, 4), + ], # Checking how it works with locals + grouped reduce + upcasts + ], + ) - def test_upcasts(self): - if not isinstance(Device[Device.DEFAULT], Compiled): - self.skipTest("Only Compiled uses linearizer") + def test_upcasts(self): + if not isinstance(Device[Device.DEFAULT], Compiled): + self.skipTest("Only Compiled uses linearizer") - N = 16 - Tensor.manual_seed(1772) - a = Tensor.rand(N, N) - b = Tensor.rand(N, N) - r = (a+b).sqrt() * ((a+1).exp()) - helper_linearizer_opt(r, [ - [Opt(OptOps.UPCAST, 0, 2)], - [Opt(OptOps.UPCAST, 0, 4)], - [Opt(OptOps.UPCAST, 0, 8)], # Checking how it works with upcasts - ]) + N = 16 + Tensor.manual_seed(1772) + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + r = (a + b).sqrt() * ((a + 1).exp()) + helper_linearizer_opt( + r, + [ + [Opt(OptOps.UPCAST, 0, 2)], + [Opt(OptOps.UPCAST, 0, 4)], + [Opt(OptOps.UPCAST, 0, 8)], # Checking how it works with upcasts + ], + ) - def test_full_upcast(self): - if not isinstance(Device[Device.DEFAULT], Compiled): - self.skipTest("Only Compiled uses linearizer") + def test_full_upcast(self): + if not isinstance(Device[Device.DEFAULT], Compiled): + self.skipTest("Only Compiled uses linearizer") - Tensor.manual_seed(1772) - a = Tensor.rand(4) - b = Tensor.rand(4) - r = (a+b).sqrt() * ((a+1).exp()) - helper_linearizer_opt(r, [ - [Opt(OptOps.UPCAST, 0, 4)], # Checking how it works with upcasts - ]) + Tensor.manual_seed(1772) + a = Tensor.rand(4) + b = Tensor.rand(4) + r = (a + b).sqrt() * ((a + 1).exp()) + helper_linearizer_opt( + r, + [ + [Opt(OptOps.UPCAST, 0, 4)], # Checking how it works with upcasts + ], + ) - def test_matmul(self): - if not isinstance(Device[Device.DEFAULT], Compiled) or not Device[Device.DEFAULT].linearizer_opts.has_local or not Device[Device.DEFAULT].linearizer_opts.has_shared: - self.skipTest("Only Compiled uses linearizer with locals and shared") + def test_matmul(self): + if ( + not isinstance(Device[Device.DEFAULT], Compiled) + or not Device[Device.DEFAULT].linearizer_opts.has_local + or not Device[Device.DEFAULT].linearizer_opts.has_shared + ): + self.skipTest("Only Compiled uses linearizer with locals and shared") - N = 128 - Tensor.manual_seed(1552) - a = Tensor.rand(N, N) - b = Tensor.rand(N, N) - r = a@b - helper_linearizer_opt(r, [ - [Opt(OptOps.UPCAST, 0, 2)], - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # Checking how it works with upcasts - [Opt(OptOps.LOCAL, 0, 2)], - [Opt(OptOps.LOCAL, 1, 32)], - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)], - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 32)], - [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.LOCAL, 1, 8)], # Checking how it works with locals - [Opt(OptOps.GROUPTOP, 0, 2)], - [Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.UNROLL, 0, 4)], # Checking how it works with grouped_reduce - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 4)], # Checking how it works with local+grouped_reduce - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)], # Checking all together - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 8)], # Full global upcast + local - ]) + N = 128 + Tensor.manual_seed(1552) + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + r = a @ b + helper_linearizer_opt( + r, + [ + [Opt(OptOps.UPCAST, 0, 2)], + [ + Opt(OptOps.UPCAST, 0, 4), + Opt(OptOps.UPCAST, 1, 4), + ], # Checking how it works with upcasts + [Opt(OptOps.LOCAL, 0, 2)], + [Opt(OptOps.LOCAL, 1, 32)], + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)], + [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 32)], + [ + Opt(OptOps.LOCAL, 0, 16), + Opt(OptOps.LOCAL, 1, 8), + ], # Checking how it works with locals + [Opt(OptOps.GROUPTOP, 0, 2)], + [Opt(OptOps.GROUPTOP, 0, 32)], + [ + Opt(OptOps.GROUPTOP, 0, 32), + Opt(OptOps.UNROLL, 0, 4), + ], # Checking how it works with grouped_reduce + [ + Opt(OptOps.LOCAL, 0, 2), + Opt(OptOps.LOCAL, 1, 2), + Opt(OptOps.GROUPTOP, 0, 32), + ], + [Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 32)], + [ + Opt(OptOps.LOCAL, 0, 4), + Opt(OptOps.LOCAL, 0, 8), + Opt(OptOps.GROUPTOP, 0, 4), + ], # Checking how it works with local+grouped_reduce + [ + Opt(OptOps.LOCAL, 0, 4), + Opt(OptOps.LOCAL, 0, 4), + Opt(OptOps.GROUPTOP, 0, 8), + Opt(OptOps.UNROLL, 0, 4), + Opt(OptOps.UPCAST, 0, 4), + Opt(OptOps.UPCAST, 1, 2), + ], # Checking all together + [ + Opt(OptOps.LOCAL, 0, 4), + Opt(OptOps.LOCAL, 0, 4), + Opt(OptOps.GROUPTOP, 0, 8), + Opt(OptOps.UNROLL, 0, 4), + Opt(OptOps.UPCAST, 0, 8), + ], # Full global upcast + local + ], + ) - def test_double_reduce(self): - if not isinstance(Device[Device.DEFAULT], Compiled) or not Device[Device.DEFAULT].linearizer_opts.has_local or not Device[Device.DEFAULT].linearizer_opts.has_shared: - self.skipTest("Only Compiled uses linearizer with locals and shared") + def test_double_reduce(self): + if ( + not isinstance(Device[Device.DEFAULT], Compiled) + or not Device[Device.DEFAULT].linearizer_opts.has_local + or not Device[Device.DEFAULT].linearizer_opts.has_shared + ): + self.skipTest("Only Compiled uses linearizer with locals and shared") - N = 128 - Tensor.manual_seed(1552) - a = Tensor.rand(8, N, 8, N) - r = a.sum(axis=(1,3)) - helper_linearizer_opt(r, [ - # openCL / GPU=1 is 256 max threads - [Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 1, 32)], # Checking how it works with 1 grouped_reduce. - [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], - [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2)], - [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 64)], # Checking how it works with 2 grouped_reduces. - [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 0, 4)], - [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 2, 4)], # Checking how it works with 2 grouped_reduces + upcasts. - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4)], - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 1, 4)], # Checking how it works with 2 grouped_reduces + upcasts + locals. - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2)], - [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)], # Checking how it works with 2 grouped_reduces + upcasts + locals. - [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 0, 2)], # No globals - ]) + N = 128 + Tensor.manual_seed(1552) + a = Tensor.rand(8, N, 8, N) + r = a.sum(axis=(1, 3)) + helper_linearizer_opt( + r, + [ + # openCL / GPU=1 is 256 max threads + [Opt(OptOps.GROUPTOP, 0, 2)], + [Opt(OptOps.GROUPTOP, 0, 32)], + [Opt(OptOps.GROUPTOP, 1, 2)], + [ + Opt(OptOps.GROUPTOP, 1, 32) + ], # Checking how it works with 1 grouped_reduce. + [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], + [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2)], + [ + Opt(OptOps.GROUPTOP, 0, 4), + Opt(OptOps.GROUPTOP, 1, 64), + ], # Checking how it works with 2 grouped_reduces. + [ + Opt(OptOps.GROUPTOP, 0, 16), + Opt(OptOps.GROUPTOP, 1, 2), + Opt(OptOps.UNROLL, 0, 4), + ], + [ + Opt(OptOps.GROUPTOP, 0, 2), + Opt(OptOps.GROUPTOP, 1, 32), + Opt(OptOps.UNROLL, 2, 4), + ], # Checking how it works with 2 grouped_reduces + upcasts. + [ + Opt(OptOps.LOCAL, 0, 4), + Opt(OptOps.LOCAL, 1, 4), + Opt(OptOps.GROUPTOP, 0, 4), + Opt(OptOps.GROUPTOP, 1, 4), + ], + [ + Opt(OptOps.LOCAL, 0, 4), + Opt(OptOps.LOCAL, 1, 4), + Opt(OptOps.GROUPTOP, 0, 2), + Opt(OptOps.GROUPTOP, 1, 32), + Opt(OptOps.UNROLL, 1, 4), + ], # Checking how it works with 2 grouped_reduces + upcasts + locals. + [ + Opt(OptOps.LOCAL, 0, 2), + Opt(OptOps.LOCAL, 1, 2), + Opt(OptOps.GROUPTOP, 0, 8), + Opt(OptOps.GROUPTOP, 1, 4), + Opt(OptOps.UPCAST, 0, 2), + ], + [ + Opt(OptOps.LOCAL, 0, 2), + Opt(OptOps.LOCAL, 1, 2), + Opt(OptOps.GROUPTOP, 0, 8), + Opt(OptOps.GROUPTOP, 1, 4), + Opt(OptOps.UPCAST, 0, 2), + Opt(OptOps.UNROLL, 0, 4), + Opt(OptOps.UNROLL, 1, 4), + ], # Checking how it works with 2 grouped_reduces + upcasts + locals. + [ + Opt(OptOps.LOCAL, 0, 4), + Opt(OptOps.LOCAL, 1, 4), + Opt(OptOps.GROUPTOP, 0, 4), + Opt(OptOps.GROUPTOP, 1, 4), + Opt(OptOps.UPCAST, 0, 2), + Opt(OptOps.UPCAST, 0, 2), + ], # No globals + ], + ) - def test_tensor_core_opts(self): - if not isinstance(Device[Device.DEFAULT], Compiled) or not Device[Device.DEFAULT].linearizer_opts.has_local: - self.skipTest("Only Compiled uses linearizer with locals") - if Device.DEFAULT not in tensor_cores: - self.skipTest("No tensor cores for device") + def test_tensor_core_opts(self): + if ( + not isinstance(Device[Device.DEFAULT], Compiled) + or not Device[Device.DEFAULT].linearizer_opts.has_local + ): + self.skipTest("Only Compiled uses linearizer with locals") + if Device.DEFAULT not in tensor_cores: + self.skipTest("No tensor cores for device") - N = 128 - Tensor.manual_seed(1552) - a = Tensor.rand(N, N) - b = Tensor.rand(N, N) - r = a@b - helper_linearizer_opt(r, [ - [Opt(OptOps.UPCAST, 0, 4)], - [Opt(OptOps.UPCAST, 1, 4)], - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts - [Opt(OptOps.UNROLL, 0, 2)], # check last unroll - [Opt(OptOps.LASTLOCAL, 0, 4)], # check last local - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of last unroll and last local - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)], - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)], - [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LASTLOCAL, 0, 2)], - # [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC) - ], apply_tc=True) + N = 128 + Tensor.manual_seed(1552) + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + r = a @ b + helper_linearizer_opt( + r, + [ + [Opt(OptOps.UPCAST, 0, 4)], + [Opt(OptOps.UPCAST, 1, 4)], + [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts + [Opt(OptOps.UNROLL, 0, 2)], # check last unroll + [Opt(OptOps.LASTLOCAL, 0, 4)], # check last local + [ + Opt(OptOps.UPCAST, 0, 4), + Opt(OptOps.UNROLL, 0, 2), + ], # check combo of last unroll and last local + [ + Opt(OptOps.UPCAST, 0, 4), + Opt(OptOps.UPCAST, 1, 4), + Opt(OptOps.UNROLL, 0, 2), + ], + [ + Opt(OptOps.UPCAST, 0, 4), + Opt(OptOps.UPCAST, 1, 4), + Opt(OptOps.UNROLL, 0, 4), + ], + [ + Opt(OptOps.UPCAST, 0, 4), + Opt(OptOps.UPCAST, 1, 4), + Opt(OptOps.UNROLL, 0, 4), + Opt(OptOps.LASTLOCAL, 0, 2), + ], + # [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC) + ], + apply_tc=True, + ) - def test_padto_matmul(self): - if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer") - if Device.DEFAULT == "CUDA": self.skipTest("super slow on CUDA") - N = 17 * 17 - Tensor.manual_seed(289) - a = Tensor.rand(N, N) - b = Tensor.rand(N, N) - helper_linearizer_opt(a@b, [ - [Opt(OptOps.PADTO, 0, 32)], - [Opt(OptOps.PADTO, 1, 32)], - [Opt(OptOps.PADTO, 2, 32)], - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)], - # can optimize further post PADTO - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UNROLL, 0, 4)], - ]) + def test_padto_matmul(self): + if not isinstance(Device[Device.DEFAULT], Compiled): + self.skipTest("Only Compiled uses linearizer") + if Device.DEFAULT == "CUDA": + self.skipTest("super slow on CUDA") + N = 17 * 17 + Tensor.manual_seed(289) + a = Tensor.rand(N, N) + b = Tensor.rand(N, N) + helper_linearizer_opt( + a @ b, + [ + [Opt(OptOps.PADTO, 0, 32)], + [Opt(OptOps.PADTO, 1, 32)], + [Opt(OptOps.PADTO, 2, 32)], + [ + Opt(OptOps.PADTO, 0, 32), + Opt(OptOps.PADTO, 1, 32), + Opt(OptOps.PADTO, 2, 32), + ], + # can optimize further post PADTO + [ + Opt(OptOps.PADTO, 0, 32), + Opt(OptOps.PADTO, 1, 32), + Opt(OptOps.PADTO, 2, 32), + Opt(OptOps.UPCAST, 0, 2), + Opt(OptOps.UNROLL, 0, 4), + ], + ], + ) - def test_padto_max(self): - # pad uses invalid value 0, so max is not allowed - if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer") - N = 17 * 17 - a = -Tensor.ones(N, N) - with self.assertRaises(AssertionError): - helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],]) + def test_padto_max(self): + # pad uses invalid value 0, so max is not allowed + if not isinstance(Device[Device.DEFAULT], Compiled): + self.skipTest("Only Compiled uses linearizer") + N = 17 * 17 + a = -Tensor.ones(N, N) + with self.assertRaises(AssertionError): + helper_linearizer_opt( + a.max(), + [ + [Opt(OptOps.PADTO, 0, 32)], + ], + ) - def test_padto_where(self): - # pad uses invalid value 0, so kernel with max is not allowed - if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer") - N = 17 * 17 - a = (Tensor.rand(N, N).max(axis=0) > 1).where(1, 0) - with self.assertRaises(AssertionError): - helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],]) + def test_padto_where(self): + # pad uses invalid value 0, so kernel with max is not allowed + if not isinstance(Device[Device.DEFAULT], Compiled): + self.skipTest("Only Compiled uses linearizer") + N = 17 * 17 + a = (Tensor.rand(N, N).max(axis=0) > 1).where(1, 0) + with self.assertRaises(AssertionError): + helper_linearizer_opt( + a.max(), + [ + [Opt(OptOps.PADTO, 0, 32)], + ], + ) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 0ae18a7e5..023f3e4dd 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -6,87 +6,876 @@ from tinygrad.helpers import OSX, CI from test.external.fuzz_linearizer import run_linearizer # stuff needed to unpack a kernel -from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, get_lazyop_info +from tinygrad.ops import ( + LazyOp, + BinaryOps, + UnaryOps, + ReduceOps, + BufferOps, + MemBuffer, + ConstBuffer, + get_lazyop_info, +) from tinygrad.helpers import dtypes from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View -inf, nan = float('inf'), float('nan') + +inf, nan = float("inf"), float("nan") + def helper_test_lin(lin: Linearizer, opts, failed_platforms): - for opt in opts: - try: - lin.apply_opt(opt) - except AssertionError: - # it's considered fixed if we invalidated the opts - assert Device.DEFAULT not in failed_platforms - if Device.DEFAULT not in failed_platforms: - assert run_linearizer(lin) == "PASS" - else: - assert run_linearizer(lin) != "PASS" + for opt in opts: + try: + lin.apply_opt(opt) + except AssertionError: + # it's considered fixed if we invalidated the opts + assert Device.DEFAULT not in failed_platforms + if Device.DEFAULT not in failed_platforms: + assert run_linearizer(lin) == "PASS" + else: + assert run_linearizer(lin) != "PASS" + def helper_add_store(op): - info = get_lazyop_info(op) - return LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, info.dtype, ShapeTracker.from_shape(info.shape))) + info = get_lazyop_info(op) + return LazyOp( + BufferOps.STORE, + (op,), + MemBuffer(0, info.dtype, ShapeTracker.from_shape(info.shape)), + ) -@unittest.skipIf(CI and Device.DEFAULT=="CUDA", "failed on CUDA CI") + +@unittest.skipIf(CI and Device.DEFAULT == "CUDA", "failed on CUDA CI") class TestLinearizerFailures(unittest.TestCase): - def test_failure_1(self): - ast = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 16, 16), strides=(16, 1, 0), offset=0, mask=None, contiguous=False),)))),), arg=(32, 16, 1)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(16, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None) - ast = helper_add_store(ast) - helper_test_lin(Linearizer(ast), [], failed_platforms=["CLANG"]) + def test_failure_1(self): + ast = LazyOp( + op=BinaryOps.ADD, + src=( + LazyOp( + op=BinaryOps.ADD, + src=( + LazyOp( + op=ReduceOps.SUM, + src=( + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=1, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(32, 16, 16), + strides=(16, 1, 0), + offset=0, + mask=None, + contiguous=False, + ), + ) + ), + ), + ), + ), + arg=(32, 16, 1), + ), + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=2, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(32, 16, 1), + strides=(0, 1, 0), + offset=0, + mask=None, + contiguous=False, + ), + ) + ), + ), + ), + ), + arg=None, + ), + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=1, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(32, 16, 1), + strides=(16, 1, 0), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ), + ), + arg=None, + ) + ast = helper_add_store(ast) + helper_test_lin(Linearizer(ast), [], failed_platforms=["CLANG"]) - def test_failure_2(self): - ast = LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))))),), arg=(32, 2, 37, 9, 1, 1)) - opts = [Opt(op=OptOps.LOCAL, axis=0, amt=32)] - ast = helper_add_store(ast) - helper_test_lin(Linearizer(ast), opts, failed_platforms=["CPU", "TORCH"]) + def test_failure_2(self): + ast = LazyOp( + op=ReduceOps.MAX, + src=( + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=1, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(32, 2, 111, 27), + strides=(6160, 3080, 28, 1), + offset=0, + mask=((0, 32), (0, 2), (0, 110), (0, 27)), + contiguous=False, + ), + View( + shape=(32, 2, 37, 9, 2, 2), + strides=(5994, 2997, 81, 3, 27, 1), + offset=0, + mask=None, + contiguous=False, + ), + ) + ), + ), + ), + ), + arg=(32, 2, 37, 9, 1, 1), + ) + opts = [Opt(op=OptOps.LOCAL, axis=0, amt=32)] + ast = helper_add_store(ast) + helper_test_lin(Linearizer(ast), opts, failed_platforms=["CPU", "TORCH"]) - @unittest.skipIf(CI and Device.DEFAULT=="METAL", "behaves differently on METAL CI") - def test_failure_3(self): - ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)))),), arg=(32, 8, 16, 1)) - opts = [Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=32)] - # METAL: AssertionError: Error Domain=AGXMetalG13X Code=3 "Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)" UserInfo={NSLocalizedDescription=Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)} - ast = helper_add_store(ast) - helper_test_lin(Linearizer(ast), opts, failed_platforms=["METAL", "GPU", "CUDA"]) + @unittest.skipIf( + CI and Device.DEFAULT == "METAL", "behaves differently on METAL CI" + ) + def test_failure_3(self): + ast = LazyOp( + op=ReduceOps.SUM, + src=( + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=1, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(32, 8, 16, 16), + strides=(2048, 256, 16, 1), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ), + ), + arg=(32, 8, 16, 1), + ) + opts = [ + Opt(op=OptOps.GROUP, axis=0, amt=4), + Opt(op=OptOps.UPCAST, axis=0, amt=4), + Opt(op=OptOps.UPCAST, axis=0, amt=2), + Opt(op=OptOps.UNROLL, axis=1, amt=0), + Opt(op=OptOps.UPCAST, axis=0, amt=4), + Opt(op=OptOps.LOCAL, axis=0, amt=2), + Opt(op=OptOps.LOCAL, axis=0, amt=2), + Opt(op=OptOps.UPCAST, axis=1, amt=0), + Opt(op=OptOps.LOCAL, axis=0, amt=32), + ] + # METAL: AssertionError: Error Domain=AGXMetalG13X Code=3 "Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)" UserInfo={NSLocalizedDescription=Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)} + ast = helper_add_store(ast) + helper_test_lin( + Linearizer(ast), opts, failed_platforms=["METAL", "GPU", "CUDA"] + ) - def test_failure_4(self): - ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 4, 1, 12, 2, 29), strides=(0, 0, 0, 2, 0, 216, 1, 8), offset=0, mask=((0, 1), (0, 1), (0, 1), (0, 4), (0, 1), (0, 11), (0, 2), (0, 27)), contiguous=False), View(shape=(1, 1, 1, 4, 22, 84), strides=(0, 0, 0, 696, 58, 1), offset=0, mask=((0, 1), (0, 1), (0, 1), (0, 4), (0, 12), (0, 58)), contiguous=False), View(shape=(1, 1, 1, 4, 2, 11, 3, 28), strides=(0, 0, 0, 1848, 924, 84, 28, 1), offset=0, mask=None, contiguous=True))))),), arg=(1, 1, 1, 4, 1, 11, 1, 28)) - opts = [Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)] - # related to OptOps.NOLOCALS - # IndexError: list index out of range - ast = helper_add_store(ast) - helper_test_lin(Linearizer(ast), opts, failed_platforms=["METAL"]) + def test_failure_4(self): + ast = LazyOp( + op=ReduceOps.SUM, + src=( + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=1, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(1, 1, 1, 4, 1, 12, 2, 29), + strides=(0, 0, 0, 2, 0, 216, 1, 8), + offset=0, + mask=( + (0, 1), + (0, 1), + (0, 1), + (0, 4), + (0, 1), + (0, 11), + (0, 2), + (0, 27), + ), + contiguous=False, + ), + View( + shape=(1, 1, 1, 4, 22, 84), + strides=(0, 0, 0, 696, 58, 1), + offset=0, + mask=( + (0, 1), + (0, 1), + (0, 1), + (0, 4), + (0, 12), + (0, 58), + ), + contiguous=False, + ), + View( + shape=(1, 1, 1, 4, 2, 11, 3, 28), + strides=(0, 0, 0, 1848, 924, 84, 28, 1), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ), + ), + arg=(1, 1, 1, 4, 1, 11, 1, 28), + ) + opts = [ + Opt(op=OptOps.LOCAL, axis=2, amt=4), + Opt(op=OptOps.UPCAST, axis=0, amt=2), + Opt(op=OptOps.UPCAST, axis=0, amt=0), + Opt(op=OptOps.LOCAL, axis=2, amt=2), + Opt(op=OptOps.UPCAST, axis=3, amt=0), + Opt(op=OptOps.UPCAST, axis=2, amt=0), + Opt(op=OptOps.UNROLL, axis=0, amt=0), + Opt(op=OptOps.UPCAST, axis=1, amt=0), + Opt(op=OptOps.NOLOCALS, axis=None, amt=None), + ] + # related to OptOps.NOLOCALS + # IndexError: list index out of range + ast = helper_add_store(ast) + helper_test_lin(Linearizer(ast), opts, failed_platforms=["METAL"]) - def test_failure_5(self): - ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.1464405059814453, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.1464405059814453, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None),), arg=(1, 1, 1, 1, 1, 1, 1, 1)) - opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)] - # EXEC_ERROR, it has no global_size - ast = helper_add_store(ast) - helper_test_lin(Linearizer(ast), opts, failed_platforms=[]) + def test_failure_5(self): + ast = LazyOp( + op=ReduceOps.SUM, + src=( + LazyOp( + op=BinaryOps.ADD, + src=( + LazyOp( + op=BinaryOps.MUL, + src=( + LazyOp( + op=BinaryOps.ADD, + src=( + LazyOp( + op=BufferOps.CONST, + src=(), + arg=ConstBuffer( + val=0.1464405059814453, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=( + 2, + 1, + 4, + 1, + 3, + 1, + 4, + 1, + ), + strides=( + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ), + offset=0, + mask=None, + contiguous=False, + ), + ) + ), + ), + ), + LazyOp( + op=BufferOps.CONST, + src=(), + arg=ConstBuffer( + val=1.0, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=( + 2, + 1, + 4, + 1, + 3, + 1, + 4, + 1, + ), + strides=( + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ), + offset=0, + mask=None, + contiguous=False, + ), + ) + ), + ), + ), + ), + arg=None, + ), + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=1, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(2, 1, 4, 1, 3, 1, 4, 1), + strides=(0, 0, 0, 0, 0, 0, 0, 0), + offset=0, + mask=None, + contiguous=False, + ), + ) + ), + ), + ), + ), + arg=None, + ), + LazyOp( + op=BinaryOps.MUL, + src=( + LazyOp( + op=BinaryOps.ADD, + src=( + LazyOp( + op=BufferOps.CONST, + src=(), + arg=ConstBuffer( + val=0.1464405059814453, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=( + 2, + 1, + 4, + 1, + 3, + 1, + 4, + 1, + ), + strides=( + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ), + offset=0, + mask=None, + contiguous=False, + ), + ) + ), + ), + ), + LazyOp( + op=BufferOps.CONST, + src=(), + arg=ConstBuffer( + val=1.0, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=( + 2, + 1, + 4, + 1, + 3, + 1, + 4, + 1, + ), + strides=( + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ), + offset=0, + mask=None, + contiguous=False, + ), + ) + ), + ), + ), + ), + arg=None, + ), + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=1, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(2, 1, 4, 1, 3, 1, 4, 1), + strides=(0, 0, 0, 0, 0, 0, 0, 0), + offset=0, + mask=None, + contiguous=False, + ), + ) + ), + ), + ), + ), + arg=None, + ), + ), + arg=None, + ), + ), + arg=(1, 1, 1, 1, 1, 1, 1, 1), + ) + opts = [ + Opt(op=OptOps.UNROLL, axis=0, amt=4), + Opt(op=OptOps.UNROLL, axis=0, amt=0), + ] + # EXEC_ERROR, it has no global_size + ast = helper_add_store(ast) + helper_test_lin(Linearizer(ast), opts, failed_platforms=[]) - def test_failure_6(self): - ast = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1.0, dtype=dtypes.int32, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))))),), arg=(10, 1)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=10.0, dtype=dtypes.int32, st=ShapeTracker(views=(View(shape=(10, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None) - opts = [Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0)] - # COMPILE FAILED, KeyError: UOps.CONST - ast = helper_add_store(ast) - helper_test_lin(Linearizer(ast), opts, failed_platforms=[]) + def test_failure_6(self): + ast = LazyOp( + op=BinaryOps.ADD, + src=( + LazyOp( + op=ReduceOps.SUM, + src=( + LazyOp( + op=BufferOps.CONST, + src=(), + arg=ConstBuffer( + val=-1.0, + dtype=dtypes.int32, + st=ShapeTracker( + views=( + View( + shape=(11, 19), + strides=(0, 0), + offset=0, + mask=((0, 11), (9, 19)), + contiguous=False, + ), + View( + shape=(10, 10), + strides=(1, 20), + offset=0, + mask=None, + contiguous=False, + ), + ) + ), + ), + ), + ), + arg=(10, 1), + ), + LazyOp( + op=BufferOps.CONST, + src=(), + arg=ConstBuffer( + val=10.0, + dtype=dtypes.int32, + st=ShapeTracker( + views=( + View( + shape=(10, 1), + strides=(0, 0), + offset=0, + mask=None, + contiguous=False, + ), + ) + ), + ), + ), + ), + arg=None, + ) + opts = [ + Opt(op=OptOps.UPCAST, axis=0, amt=2), + Opt(op=OptOps.UPCAST, axis=0, amt=0), + ] + # COMPILE FAILED, KeyError: UOps.CONST + ast = helper_add_store(ast) + helper_test_lin(Linearizer(ast), opts, failed_platforms=[]) - @unittest.skipIf(Device.DEFAULT=="LLVM", "Segmentation fault") - def test_failure_7(self): - ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))))),), arg=(512, 32, 1, 34, 1, 34)) - opts = [Opt(op=OptOps.UPCAST, axis=0, amt=4)] - # test/test_linearizer_failures.py Fatal Python error: Segmentation fault - ast = helper_add_store(ast) - helper_test_lin(Linearizer(ast), opts, failed_platforms=["LLVM"]) + @unittest.skipIf(Device.DEFAULT == "LLVM", "Segmentation fault") + def test_failure_7(self): + ast = LazyOp( + op=ReduceOps.SUM, + src=( + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=1, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(512, 32, 6, 8, 4, 6, 8, 4), + strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), + offset=0, + mask=( + (0, 512), + (0, 32), + (0, 6), + (0, 8), + (0, 1), + (0, 6), + (0, 8), + (0, 1), + ), + contiguous=False, + ), + View( + shape=(512, 32, 6, 35, 6, 35), + strides=(1179648, 36864, 6144, 192, 32, 1), + offset=0, + mask=( + (0, 512), + (0, 32), + (0, 6), + (0, 32), + (0, 6), + (0, 32), + ), + contiguous=False, + ), + View( + shape=(512, 32, 238, 238), + strides=(1411200, 44100, 210, 1), + offset=0, + mask=((0, 512), (0, 32), (0, 210), (0, 210)), + contiguous=False, + ), + View( + shape=(512, 32, 7, 34, 7, 34), + strides=(1812608, 56644, 8092, 238, 34, 1), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ), + ), + arg=(512, 32, 1, 34, 1, 34), + ) + opts = [Opt(op=OptOps.UPCAST, axis=0, amt=4)] + # test/test_linearizer_failures.py Fatal Python error: Segmentation fault + ast = helper_add_store(ast) + helper_test_lin(Linearizer(ast), opts, failed_platforms=["LLVM"]) - @unittest.skipIf((Device.DEFAULT=="LLVM" and not OSX) or (Device.DEFAULT == "GPU" and CI), "Segmentation fault on ubuntu, GPU requires cl_khr_fp16") - def test_failure_8(self): - ast = LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.DIV, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),))))), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),))))), arg=None)), arg=None),), arg=(1, 1, 1)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.000244140625, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1e-06, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None)), arg=None),), arg=None) - opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)] - # fatal error: bracket nesting level exceeded maximum of 256 - # note: use -fbracket-depth=N to increase maximum nesting level - ast = helper_add_store(ast) - helper_test_lin(Linearizer(ast), opts, failed_platforms=["CLANG", "METAL"]) + @unittest.skipIf( + (Device.DEFAULT == "LLVM" and not OSX) or (Device.DEFAULT == "GPU" and CI), + "Segmentation fault on ubuntu, GPU requires cl_khr_fp16", + ) + def test_failure_8(self): + ast = LazyOp( + op=UnaryOps.SQRT, + src=( + LazyOp( + op=BinaryOps.DIV, + src=( + LazyOp( + op=BufferOps.CONST, + src=(), + arg=ConstBuffer( + val=1.0, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(1, 1, 1), + strides=(0, 0, 0), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ), + LazyOp( + op=BinaryOps.ADD, + src=( + LazyOp( + op=BinaryOps.MUL, + src=( + LazyOp( + op=ReduceOps.SUM, + src=( + LazyOp( + op=BinaryOps.MUL, + src=( + LazyOp( + op=BinaryOps.ADD, + src=( + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=1, + dtype=dtypes.half, + st=ShapeTracker( + views=( + View( + shape=( + 1, + 1, + 4096, + ), + strides=( + 0, + 0, + 1, + ), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ), + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=2, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=( + 1, + 1, + 4096, + ), + strides=( + 0, + 0, + 1, + ), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ), + ), + arg=None, + ), + LazyOp( + op=BinaryOps.ADD, + src=( + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=1, + dtype=dtypes.half, + st=ShapeTracker( + views=( + View( + shape=( + 1, + 1, + 4096, + ), + strides=( + 0, + 0, + 1, + ), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ), + LazyOp( + op=BufferOps.LOAD, + src=(), + arg=MemBuffer( + idx=2, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=( + 1, + 1, + 4096, + ), + strides=( + 0, + 0, + 1, + ), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ), + ), + arg=None, + ), + ), + arg=None, + ), + ), + arg=(1, 1, 1), + ), + LazyOp( + op=BufferOps.CONST, + src=(), + arg=ConstBuffer( + val=0.000244140625, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(1, 1, 1), + strides=(0, 0, 0), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ), + ), + arg=None, + ), + LazyOp( + op=BufferOps.CONST, + src=(), + arg=ConstBuffer( + val=1e-06, + dtype=dtypes.float, + st=ShapeTracker( + views=( + View( + shape=(1, 1, 1), + strides=(0, 0, 0), + offset=0, + mask=None, + contiguous=True, + ), + ) + ), + ), + ), + ), + arg=None, + ), + ), + arg=None, + ), + ), + arg=None, + ) + opts = [ + Opt(op=OptOps.UNROLL, axis=0, amt=4), + Opt(op=OptOps.UNROLL, axis=0, amt=4), + Opt(op=OptOps.UNROLL, axis=0, amt=4), + Opt(op=OptOps.UNROLL, axis=0, amt=4), + ] + # fatal error: bracket nesting level exceeded maximum of 256 + # note: use -fbracket-depth=N to increase maximum nesting level + ast = helper_add_store(ast) + helper_test_lin(Linearizer(ast), opts, failed_platforms=["CLANG", "METAL"]) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_net_speed.py b/test/test_net_speed.py index bd513e28d..b2399a745 100644 --- a/test/test_net_speed.py +++ b/test/test_net_speed.py @@ -6,84 +6,94 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import Profiling import pytest -pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu, pytest.mark.exclude_clang] +pytestmark = [ + pytest.mark.exclude_cuda, + pytest.mark.exclude_gpu, + pytest.mark.exclude_clang, +] + class TestConvSpeed(unittest.TestCase): + def test_mnist(self): + # https://keras.io/examples/vision/mnist_convnet/ + conv = 3 + inter_chan, out_chan = 32, 64 - def test_mnist(self): - # https://keras.io/examples/vision/mnist_convnet/ - conv = 3 - inter_chan, out_chan = 32, 64 + # ****** torch baseline ******* - # ****** torch baseline ******* + torch.backends.mkldnn.enabled = False - torch.backends.mkldnn.enabled = False + conv = 3 + inter_chan, out_chan = 32, 64 + c1 = torch.randn(inter_chan, 1, conv, conv, requires_grad=True) + c2 = torch.randn(out_chan, inter_chan, conv, conv, requires_grad=True) + l1 = torch.randn(out_chan * 5 * 5, 10, requires_grad=True) - conv = 3 - inter_chan, out_chan = 32, 64 - c1 = torch.randn(inter_chan,1,conv,conv, requires_grad=True) - c2 = torch.randn(out_chan,inter_chan,conv,conv, requires_grad=True) - l1 = torch.randn(out_chan*5*5, 10, requires_grad=True) + c2d = torch.nn.functional.conv2d + mp = torch.nn.MaxPool2d((2, 2)) + lsm = torch.nn.LogSoftmax(dim=1) - c2d = torch.nn.functional.conv2d - mp = torch.nn.MaxPool2d((2,2)) - lsm = torch.nn.LogSoftmax(dim=1) + cnt = 5 + fpt, bpt = 0.0, 0.0 + for i in range(cnt): + et0 = time.time() + x = torch.randn(128, 1, 28, 28, requires_grad=True) + x = mp(c2d(x, c1).relu()) + x = mp(c2d(x, c2).relu()) + x = x.reshape(x.shape[0], -1) + out = lsm(x.matmul(l1)) + out = out.mean() + et1 = time.time() + out.backward() + et2 = time.time() + fpt += et1 - et0 + bpt += et2 - et1 - cnt = 5 - fpt, bpt = 0.0, 0.0 - for i in range(cnt): - et0 = time.time() - x = torch.randn(128, 1, 28, 28, requires_grad=True) - x = mp(c2d(x,c1).relu()) - x = mp(c2d(x,c2).relu()) - x = x.reshape(x.shape[0], -1) - out = lsm(x.matmul(l1)) - out = out.mean() - et1 = time.time() - out.backward() - et2 = time.time() - fpt += (et1-et0) - bpt += (et2-et1) + fpt_baseline = fpt * 1000 / cnt + bpt_baseline = bpt * 1000 / cnt + print("torch forward pass: %.3f ms" % fpt_baseline) + print("torch backward pass: %.3f ms" % bpt_baseline) - fpt_baseline = (fpt*1000/cnt) - bpt_baseline = (bpt*1000/cnt) - print("torch forward pass: %.3f ms" % fpt_baseline) - print("torch backward pass: %.3f ms" % bpt_baseline) + # ****** tinygrad compare ******* - # ****** tinygrad compare ******* + c1 = Tensor(c1.detach().numpy(), requires_grad=True) + c2 = Tensor(c2.detach().numpy(), requires_grad=True) + l1 = Tensor(l1.detach().numpy(), requires_grad=True) - c1 = Tensor(c1.detach().numpy(), requires_grad=True) - c2 = Tensor(c2.detach().numpy(), requires_grad=True) - l1 = Tensor(l1.detach().numpy(), requires_grad=True) + cnt = 5 + fpt, bpt = 0.0, 0.0 + for i in range(1 + cnt): + et0 = time.time() + x = Tensor.randn(128, 1, 28, 28) + x = x.conv2d(c1).relu().avg_pool2d() + x = x.conv2d(c2).relu().max_pool2d() + x = x.reshape(shape=(x.shape[0], -1)) + out = x.dot(l1).log_softmax() + out = out.mean() + out.realize() + et1 = time.time() + out.backward() + [x.grad.realize() for x in [c1, c2, l1]] + et2 = time.time() + if i == 0: + pr = Profiling(sort="time", frac=0.2) + pr.__enter__() + else: + fpt += et1 - et0 + bpt += et2 - et1 - cnt = 5 - fpt, bpt = 0.0, 0.0 - for i in range(1+cnt): - et0 = time.time() - x = Tensor.randn(128, 1, 28, 28) - x = x.conv2d(c1).relu().avg_pool2d() - x = x.conv2d(c2).relu().max_pool2d() - x = x.reshape(shape=(x.shape[0], -1)) - out = x.dot(l1).log_softmax() - out = out.mean() - out.realize() - et1 = time.time() - out.backward() - [x.grad.realize() for x in [c1, c2, l1]] - et2 = time.time() - if i == 0: - pr = Profiling(sort='time', frac=0.2) - pr.__enter__() - else: - fpt += (et1-et0) - bpt += (et2-et1) - - pr.__exit__() - fpt = (fpt*1000/cnt) - bpt = (bpt*1000/cnt) - print("forward pass: %.3f ms, %.2fx off baseline %.3f ms" % (fpt, fpt/fpt_baseline, fpt_baseline)) - print("backward pass: %.3f ms, %.2fx off baseline %.3f ms" % (bpt, bpt/bpt_baseline, bpt_baseline)) + pr.__exit__() + fpt = fpt * 1000 / cnt + bpt = bpt * 1000 / cnt + print( + "forward pass: %.3f ms, %.2fx off baseline %.3f ms" + % (fpt, fpt / fpt_baseline, fpt_baseline) + ) + print( + "backward pass: %.3f ms, %.2fx off baseline %.3f ms" + % (bpt, bpt / bpt_baseline, bpt_baseline) + ) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/test/test_nn.py b/test/test_nn.py index 5804f84c7..b4e0d40e2 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4,334 +4,440 @@ import numpy as np from tinygrad.helpers import CI from tinygrad.jit import TinyJit from tinygrad.tensor import Tensor, Device -from tinygrad.nn import BatchNorm2d, Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding, InstanceNorm +from tinygrad.nn import ( + BatchNorm2d, + Conv1d, + ConvTranspose1d, + Conv2d, + ConvTranspose2d, + Linear, + GroupNorm, + LayerNorm, + LayerNorm2d, + Embedding, + InstanceNorm, +) import torch import pytest pytestmark = [pytest.mark.exclude_cuda] + class TestNN(unittest.TestCase): - def test_sparse_cat_cross_entropy(self): - input = torch.randn(3, 5) - target = torch.empty(3, dtype=torch.long).random_(5) - loss_fun = torch.nn.CrossEntropyLoss(reduction='mean') - loss = loss_fun(input, target) + def test_sparse_cat_cross_entropy(self): + input = torch.randn(3, 5) + target = torch.empty(3, dtype=torch.long).random_(5) + loss_fun = torch.nn.CrossEntropyLoss(reduction="mean") + loss = loss_fun(input, target) - input_tiny = Tensor(input.detach().numpy()) - target_tiny = Tensor(target.detach().numpy()) - loss_tiny = input_tiny.sparse_categorical_crossentropy(target_tiny) + input_tiny = Tensor(input.detach().numpy()) + target_tiny = Tensor(target.detach().numpy()) + loss_tiny = input_tiny.sparse_categorical_crossentropy(target_tiny) - np.testing.assert_allclose(loss_tiny.numpy(), loss.detach().numpy(), atol=1e-5, rtol=1e-6) + np.testing.assert_allclose( + loss_tiny.numpy(), loss.detach().numpy(), atol=1e-5, rtol=1e-6 + ) - def test_batchnorm2d(self, training=False): - szs = [4, 8, 16, 32] - for sz in szs: - # create in tinygrad - Tensor.training = training - bn = BatchNorm2d(sz, eps=1e-5, track_running_stats=training) - bn.weight = Tensor.randn(sz) - bn.bias = Tensor.randn(sz) - bn.running_mean = Tensor.randn(sz) - bn.running_var = Tensor.randn(sz) - bn.running_var.numpy()[bn.running_var.numpy() < 0] = 0 + def test_batchnorm2d(self, training=False): + szs = [4, 8, 16, 32] + for sz in szs: + # create in tinygrad + Tensor.training = training + bn = BatchNorm2d(sz, eps=1e-5, track_running_stats=training) + bn.weight = Tensor.randn(sz) + bn.bias = Tensor.randn(sz) + bn.running_mean = Tensor.randn(sz) + bn.running_var = Tensor.randn(sz) + bn.running_var.numpy()[bn.running_var.numpy() < 0] = 0 - # create in torch - with torch.no_grad(): - tbn = torch.nn.BatchNorm2d(sz).eval() - tbn.training = training - tbn.weight[:] = torch.tensor(bn.weight.numpy()) - tbn.bias[:] = torch.tensor(bn.bias.numpy()) - tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy()) - tbn.running_var[:] = torch.tensor(bn.running_var.numpy()) + # create in torch + with torch.no_grad(): + tbn = torch.nn.BatchNorm2d(sz).eval() + tbn.training = training + tbn.weight[:] = torch.tensor(bn.weight.numpy()) + tbn.bias[:] = torch.tensor(bn.bias.numpy()) + tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy()) + tbn.running_var[:] = torch.tensor(bn.running_var.numpy()) - np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6) - np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6) + np.testing.assert_allclose( + bn.running_mean.numpy(), + tbn.running_mean.detach().numpy(), + rtol=1e-5, + atol=1e-6, + ) + np.testing.assert_allclose( + bn.running_var.numpy(), + tbn.running_var.detach().numpy(), + rtol=1e-5, + atol=1e-6, + ) - # trial - inn = Tensor.randn(2, sz, 3, 3) + # trial + inn = Tensor.randn(2, sz, 3, 3) - # in tinygrad - outt = bn(inn) + # in tinygrad + outt = bn(inn) - # in torch - toutt = tbn(torch.tensor(inn.numpy())) + # in torch + toutt = tbn(torch.tensor(inn.numpy())) - # close - np.testing.assert_allclose(outt.numpy(), toutt.detach().numpy(), rtol=5e-4, atol=1e-6) + # close + np.testing.assert_allclose( + outt.numpy(), toutt.detach().numpy(), rtol=5e-4, atol=1e-6 + ) - np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6) + np.testing.assert_allclose( + bn.running_mean.numpy(), + tbn.running_mean.detach().numpy(), + rtol=1e-5, + atol=1e-6, + ) - np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6) + np.testing.assert_allclose( + bn.running_var.numpy(), + tbn.running_var.detach().numpy(), + rtol=1e-5, + atol=1e-6, + ) - def test_batchnorm2d_training(self): - self.test_batchnorm2d(True) + def test_batchnorm2d_training(self): + self.test_batchnorm2d(True) - def test_linear(self): - def _test_linear(x): + def test_linear(self): + def _test_linear(x): + # create in tinygrad + model = Linear(in_dim, out_dim) + z = model(x) - # create in tinygrad - model = Linear(in_dim, out_dim) - z = model(x) + # create in torch + with torch.no_grad(): + torch_layer = torch.nn.Linear(in_dim, out_dim).eval() + torch_layer.weight[:] = torch.tensor( + model.weight.numpy(), dtype=torch.float32 + ) + torch_layer.bias[:] = torch.tensor( + model.bias.numpy(), dtype=torch.float32 + ) + torch_x = torch.tensor(x.numpy(), dtype=torch.float32) + torch_z = torch_layer(torch_x) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.Linear(in_dim, out_dim).eval() - torch_layer.weight[:] = torch.tensor(model.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(model.bias.numpy(), dtype=torch.float32) - torch_x = torch.tensor(x.numpy(), dtype=torch.float32) + # test + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5 + ) + + BS, T, in_dim, out_dim = 4, 2, 8, 16 + _test_linear(Tensor.randn(BS, in_dim)) + _test_linear(Tensor.randn(BS, T, in_dim)) # test with more dims + + def test_conv1d(self): + BS, C1, W = 4, 16, 224 // 4 + C2, K, S, P = 64, 7, 2, 1 + + # create in tinygrad + layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P) + + # create in torch + with torch.no_grad(): + torch_layer = torch.nn.Conv1d( + C1, C2, kernel_size=K, stride=S, padding=P + ).eval() + torch_layer.weight[:] = torch.tensor( + layer.weight.numpy(), dtype=torch.float32 + ) + torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + + # test + x = Tensor.uniform(BS, C1, W) + z = layer(x) + torch_x = torch.tensor(x.numpy()) torch_z = torch_layer(torch_x) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5 + ) - # test - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5) + def test_conv2d(self): + BS, C1, H, W = 4, 16, 224 // 4, 224 // 4 + C2, K, S, P = 64, 7, 2, 1 - BS, T, in_dim, out_dim = 4, 2, 8, 16 - _test_linear(Tensor.randn(BS, in_dim)) - _test_linear(Tensor.randn(BS, T, in_dim)) # test with more dims + # create in tinygrad + layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P) - def test_conv1d(self): - BS, C1, W = 4, 16, 224//4 - C2, K, S, P = 64, 7, 2, 1 + # create in torch + with torch.no_grad(): + torch_layer = torch.nn.Conv2d( + C1, C2, kernel_size=K, stride=S, padding=P + ).eval() + torch_layer.weight[:] = torch.tensor( + layer.weight.numpy(), dtype=torch.float32 + ) + torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) - # create in tinygrad - layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P) + # test + x = Tensor.uniform(BS, C1, H, W) + z = layer(x) + torch_x = torch.tensor(x.numpy()) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5 + ) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.Conv1d(C1, C2, kernel_size=K, stride=S, padding=P).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + @unittest.skipIf( + Device.DEFAULT != "TORCH", "Takes too long to compile for Compiled backends" + ) + def test_conv2d_winograd(self): + BS, C1, H, W = 2, 8, 16, 16 + C2, K, S, P = 8, 3, 1, 1 - # test - x = Tensor.uniform(BS, C1, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5) + old_wino = Tensor.wino + Tensor.wino = True - def test_conv2d(self): - BS, C1, H, W = 4, 16, 224//4, 224//4 - C2, K, S, P = 64, 7, 2, 1 + # create in tinygrad + layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P) + layer.weight.requires_grad = True + layer.bias.requires_grad = True - # create in tinygrad - layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P) + # create in torch + torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval() + torch_layer.weight = torch.nn.Parameter( + torch.tensor(layer.weight.numpy(), dtype=torch.float32) + ) + torch_layer.bias = torch.nn.Parameter( + torch.tensor(layer.bias.numpy(), dtype=torch.float32) + ) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + # test + x = Tensor.uniform(BS, C1, H, W, requires_grad=True) + z = layer(x) + torch_x = torch.tensor(x.numpy(), requires_grad=True) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5 + ) - # test - x = Tensor.uniform(BS, C1, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5) + m = z.mean() + m.backward() + gw = layer.weight.grad.realize() + gb = layer.bias.grad.realize() + gx = x.grad.realize() - @unittest.skipIf(Device.DEFAULT != "TORCH", "Takes too long to compile for Compiled backends") - def test_conv2d_winograd(self): - BS, C1, H, W = 2, 8, 16, 16 - C2, K, S, P = 8, 3, 1, 1 + torch_z.mean().backward() + np.testing.assert_allclose( + gw.numpy(), torch_layer.weight.grad.numpy(), atol=5e-4, rtol=1e-5 + ) + np.testing.assert_allclose( + gb.numpy(), torch_layer.bias.grad.numpy(), atol=5e-4, rtol=1e-5 + ) + np.testing.assert_allclose( + gx.numpy(), torch_x.grad.numpy(), atol=5e-4, rtol=1e-5 + ) - old_wino = Tensor.wino - Tensor.wino = True + Tensor.wino = old_wino - # create in tinygrad - layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P) - layer.weight.requires_grad = True - layer.bias.requires_grad = True + @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "runs out of memory in CI") + def test_conv_transpose1d(self): + BS, C1, W = 4, 16, 224 // 4 + C2, K, S, P = 64, 7, 2, 1 - # create in torch - torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval() - torch_layer.weight = torch.nn.Parameter(torch.tensor(layer.weight.numpy(), dtype=torch.float32)) - torch_layer.bias = torch.nn.Parameter(torch.tensor(layer.bias.numpy(), dtype=torch.float32)) + # create in tinygrad + layer = ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P) - # test - x = Tensor.uniform(BS, C1, H, W, requires_grad=True) - z = layer(x) - torch_x = torch.tensor(x.numpy(), requires_grad=True) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5) + # create in torch + with torch.no_grad(): + torch_layer = torch.nn.ConvTranspose1d( + C1, C2, kernel_size=K, stride=S, padding=P + ).eval() + torch_layer.weight[:] = torch.tensor( + layer.weight.numpy(), dtype=torch.float32 + ) + torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) - m = z.mean() - m.backward() - gw = layer.weight.grad.realize() - gb = layer.bias.grad.realize() - gx = x.grad.realize() + # test + x = Tensor.uniform(BS, C1, W) + z = layer(x) + torch_x = torch.tensor(x.numpy()) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5 + ) - torch_z.mean().backward() - np.testing.assert_allclose(gw.numpy(), torch_layer.weight.grad.numpy(), atol=5e-4, rtol=1e-5) - np.testing.assert_allclose(gb.numpy(), torch_layer.bias.grad.numpy(), atol=5e-4, rtol=1e-5) - np.testing.assert_allclose(gx.numpy(), torch_x.grad.numpy(), atol=5e-4, rtol=1e-5) + @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "runs out of memory in CI") + def test_conv_transpose2d(self): + BS, C1, H, W = 4, 16, 224 // 4, 224 // 4 + C2, K, S, P = 64, 7, 2, 1 - Tensor.wino = old_wino + # create in tinygrad + layer = ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P) - @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "runs out of memory in CI") - def test_conv_transpose1d(self): - BS, C1, W = 4, 16, 224//4 - C2, K, S, P = 64, 7, 2, 1 + # create in torch + with torch.no_grad(): + torch_layer = torch.nn.ConvTranspose2d( + C1, C2, kernel_size=K, stride=S, padding=P + ).eval() + torch_layer.weight[:] = torch.tensor( + layer.weight.numpy(), dtype=torch.float32 + ) + torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) - # create in tinygrad - layer = ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P) + # test + x = Tensor.uniform(BS, C1, H, W) + z = layer(x) + torch_x = torch.tensor(x.numpy()) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5 + ) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + def test_groupnorm(self): + BS, H, W, C, G = 20, 10, 10, 6, 3 - # test - x = Tensor.uniform(BS, C1, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5) + # create in tinygrad + layer = GroupNorm(G, C) - @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "runs out of memory in CI") - def test_conv_transpose2d(self): - BS, C1, H, W = 4, 16, 224//4, 224//4 - C2, K, S, P = 64, 7, 2, 1 + # create in torch + with torch.no_grad(): + torch_layer = torch.nn.GroupNorm(G, C).eval() + torch_layer.weight[:] = torch.tensor( + layer.weight.numpy(), dtype=torch.float32 + ) + torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) - # create in tinygrad - layer = ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P) + # test + x = Tensor.randn(BS, C, H, W) + z = layer(x) + torch_x = torch.tensor(x.numpy()) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3 + ) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + def test_layernorm(self): + N, C, H, W = 20, 5, 10, 10 - # test - x = Tensor.uniform(BS, C1, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5) + # create in tinygrad + layer = LayerNorm([H, W]) - def test_groupnorm(self): - BS, H, W, C, G = 20, 10, 10, 6, 3 + # create in torch + with torch.no_grad(): + torch_layer = torch.nn.LayerNorm([H, W]).eval() + torch_layer.weight[:] = torch.tensor( + layer.weight.numpy(), dtype=torch.float32 + ) + torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) - # create in tinygrad - layer = GroupNorm(G, C) + # test + x = Tensor.randn(N, C, H, W) + z = layer(x) + torch_x = torch.tensor(x.numpy()) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3 + ) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.GroupNorm(G, C).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + def test_layernorm_2d(self): + N, C, H, W = 20, 5, 10, 10 - # test - x = Tensor.randn(BS, C, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # create in tinygrad + layer = LayerNorm2d(C) - def test_layernorm(self): - N, C, H, W = 20, 5, 10, 10 + # create in torch + with torch.no_grad(): + torch_layer = torch.nn.LayerNorm([C]).eval() + torch_layer.weight[:] = torch.tensor( + layer.weight.numpy(), dtype=torch.float32 + ) + torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) - # create in tinygrad - layer = LayerNorm([H, W]) + # test + x = Tensor.randn(N, C, H, W) + z = layer(x) + torch_x = torch.tensor(x.numpy()) + torch_z = torch_layer(torch_x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3 + ) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.LayerNorm([H, W]).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + def test_instancenorm_2d(self): + N, C, H, W = 20, 5, 10, 10 - # test - x = Tensor.randn(N, C, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # create in tinygrad + layer = InstanceNorm(C) - def test_layernorm_2d(self): - N, C, H, W = 20, 5, 10, 10 + # create in torch + with torch.no_grad(): + torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval() + torch_layer.weight[:] = torch.tensor( + layer.weight.numpy(), dtype=torch.float32 + ) + torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) - # create in tinygrad - layer = LayerNorm2d(C) + # test + x = Tensor.randn(N, C, H, W) + z = layer(x) + torch_x = torch.tensor(x.numpy()) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3 + ) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.LayerNorm([C]).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + def test_instancenorm_3d(self): + N, C, D, H, W = 20, 5, 3, 10, 10 - # test - x = Tensor.randn(N, C, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # create in tinygrad + layer = InstanceNorm(C) - def test_instancenorm_2d(self): - N, C, H, W = 20, 5, 10, 10 + # create in torch + with torch.no_grad(): + torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval() + torch_layer.weight[:] = torch.tensor( + layer.weight.numpy(), dtype=torch.float32 + ) + torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) - # create in tinygrad - layer = InstanceNorm(C) + # test + x = Tensor.randn(N, C, D, H, W) + z = layer(x) + torch_x = torch.tensor(x.numpy()) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3 + ) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + def test_embedding(self): + B, T, C, VS = 4, 10, 20, 28 - # test - x = Tensor.randn(N, C, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # create in tinygrad + layer = Embedding(VS, C) - def test_instancenorm_3d(self): - N, C, D, H, W = 20, 5, 3, 10, 10 + with torch.no_grad(): + torch_layer = torch.nn.Embedding(VS, C).eval() + torch_layer.weight[:] = torch.tensor( + layer.weight.numpy(), dtype=torch.float32 + ) - # create in tinygrad - layer = InstanceNorm(C) + # test + x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32)) + z = layer(x) + torch_x = torch.tensor(x.numpy().astype(np.int32)) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8 + ) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + # test with jit enabled + @TinyJit + def layer_jit(x): + return layer(x).realize() - # test - x = Tensor.randn(N, C, D, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) - - def test_embedding(self): - B, T, C, VS = 4, 10, 20, 28 - - # create in tinygrad - layer = Embedding(VS, C) - - with torch.no_grad(): - torch_layer = torch.nn.Embedding(VS, C).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - - # test - x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32)) - z = layer(x) - torch_x = torch.tensor(x.numpy().astype(np.int32)) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8) - - # test with jit enabled - @TinyJit - def layer_jit(x): - return layer(x).realize() - - for _ in range(3): - x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32)) - z = layer_jit(x) - torch_x = torch.tensor(x.numpy().astype(np.int32)) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8) + for _ in range(3): + x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32)) + z = layer_jit(x) + torch_x = torch.tensor(x.numpy().astype(np.int32)) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose( + z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8 + ) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/test/test_ops.py b/test/test_ops.py index 179761b15..7652b610c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -8,1296 +8,2790 @@ from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes from tinygrad import Device if CI: - import warnings - warnings.filterwarnings("ignore", message="Non-empty compiler output encountered") + import warnings + + warnings.filterwarnings("ignore", message="Non-empty compiler output encountered") FORWARD_ONLY = getenv("FORWARD_ONLY", 0) PRINT_TENSORS = getenv("PRINT_TENSORS", 0) -def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3, forward_only=False, vals=None, a=-0.5, b=3): - if tinygrad_fxn is None: tinygrad_fxn = torch_fxn - ts, tst = prepare_test_op(a, b, shps, vals, forward_only) - st = time.monotonic() - out = torch_fxn(*ts) - torch_fp = time.monotonic() - st - st = time.monotonic() - ret = tinygrad_fxn(*tst).realize() - tinygrad_fp = time.monotonic() - st - - def compare(s, x,y,atol,rtol): - if PRINT_TENSORS: print(s, x, y) - assert x.shape == y.shape, f"shape mismatch: tinygrad={x.shape} | torch={y.shape}" - try: - np.testing.assert_allclose(x,y, atol=atol, rtol=rtol) - except Exception: - raise Exception(f"{s} failed shape {x.shape}") - - if DEBUG >= 6: - np.set_printoptions(linewidth=200, suppress=True) - print(ret.numpy()) - print(out.detach().numpy()) - compare("forward pass", ret.numpy(), out.detach().numpy(), atol=atol, rtol=rtol) - - torch_fbp, tinygrad_fbp = np.nan, np.nan - if not forward_only and not FORWARD_ONLY: - st = time.monotonic() - (out+1).square().mean().backward() - torch_fbp = time.monotonic() - st +def helper_test_op( + shps, + torch_fxn, + tinygrad_fxn=None, + atol=1e-6, + rtol=1e-3, + grad_atol=1e-4, + grad_rtol=1e-3, + forward_only=False, + vals=None, + a=-0.5, + b=3, +): + if tinygrad_fxn is None: + tinygrad_fxn = torch_fxn + ts, tst = prepare_test_op(a, b, shps, vals, forward_only) st = time.monotonic() - (ret+1).square().mean().backward() - for tt in tst: tt.grad.realize() - tinygrad_fbp = time.monotonic() - st + out = torch_fxn(*ts) + torch_fp = time.monotonic() - st - for i, (t, tt) in enumerate(zip(ts, tst)): - compare(f"backward pass tensor {i}", tt.grad.numpy(), t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol) + st = time.monotonic() + ret = tinygrad_fxn(*tst).realize() + tinygrad_fp = time.monotonic() - st + + def compare(s, x, y, atol, rtol): + if PRINT_TENSORS: + print(s, x, y) + assert ( + x.shape == y.shape + ), f"shape mismatch: tinygrad={x.shape} | torch={y.shape}" + try: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol) + except Exception: + raise Exception(f"{s} failed shape {x.shape}") + + if DEBUG >= 6: + np.set_printoptions(linewidth=200, suppress=True) + print(ret.numpy()) + print(out.detach().numpy()) + compare("forward pass", ret.numpy(), out.detach().numpy(), atol=atol, rtol=rtol) + + torch_fbp, tinygrad_fbp = np.nan, np.nan + if not forward_only and not FORWARD_ONLY: + st = time.monotonic() + (out + 1).square().mean().backward() + torch_fbp = time.monotonic() - st + + st = time.monotonic() + (ret + 1).square().mean().backward() + for tt in tst: + tt.grad.realize() + tinygrad_fbp = time.monotonic() - st + + for i, (t, tt) in enumerate(zip(ts, tst)): + compare( + f"backward pass tensor {i}", + tt.grad.numpy(), + t.grad.detach().numpy(), + atol=grad_atol, + rtol=grad_rtol, + ) + + if not CI: + print( + "\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " + % ( + shps, + torch_fp * 1000, + tinygrad_fp * 1000, + torch_fbp * 1000, + tinygrad_fbp * 1000, + ), + end="", + ) - if not CI: print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="") def prepare_test_op(a, b, shps, vals, forward_only=False): - torch.manual_seed(0) - np.random.seed(0) - if shps is None: ts = [torch.tensor(x, requires_grad=(not forward_only)) for x in vals] - else: ts = [torch.tensor((np.random.random(size=x) + a) * b, requires_grad=(not forward_only), dtype=torch.float32) for x in shps] - tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts] - return ts, tst + torch.manual_seed(0) + np.random.seed(0) + if shps is None: + ts = [torch.tensor(x, requires_grad=(not forward_only)) for x in vals] + else: + ts = [ + torch.tensor( + (np.random.random(size=x) + a) * b, + requires_grad=(not forward_only), + dtype=torch.float32, + ) + for x in shps + ] + tst = [ + Tensor( + x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY) + ) + for x in ts + ] + return ts, tst + class TestOps(unittest.TestCase): - - def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn, expected, exact=False, vals=None, a=-0.5, b=3): - if getenv("CUDACPU"): self.skipTest('helper_test_exception fails in CUDACPU') - ts, tst = prepare_test_op(a, b, shps, vals) - with self.assertRaises(expected) as torch_cm: - torch_fxn(*ts) - with self.assertRaises(expected) as tinygrad_cm: - tinygrad_fxn(*tst) - if exact: self.assertEqual(str(torch_cm.exception), str(tinygrad_cm.exception)) - if not CI: print("\ntesting %40r torch/tinygrad exception: %s / %s" % (shps, torch_cm.exception, tinygrad_cm.exception), end="") - - def test_full_like(self): - a = Tensor([[1,2,3],[4,5,6]]) - b = torch.tensor([[1,2,3],[4,5,6]]) - helper_test_op([], lambda: torch.full_like(b, 4), lambda: Tensor.full_like(a, 4), forward_only=True) - def test_full(self): - helper_test_op([], lambda: torch.full((45,65), 4), lambda: Tensor.full((45,65), 4), forward_only=True) - def test_zeros(self): - helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True) - helper_test_op([], lambda: torch.zeros([45,65]), lambda: Tensor.zeros([45,65]), forward_only=True) - helper_test_op([], lambda: torch.zeros([]), lambda: Tensor.zeros([]), forward_only=True) - def test_zeros_like(self): - a = Tensor([[1,2,3],[4,5,6]]) - b = torch.tensor([[1,2,3],[4,5,6]]) - helper_test_op([], lambda: torch.zeros_like(b), lambda: Tensor.zeros_like(a), forward_only=True) - def test_empty_0(self): - helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True) - def test_ones(self): - helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True) - helper_test_op([], lambda: torch.ones([45,65]), lambda: Tensor.ones([45,65]), forward_only=True) - helper_test_op([], lambda: torch.ones([]), lambda: Tensor.ones([]), forward_only=True) - def test_ones_like(self): - a = Tensor([[1,2,3],[4,5,6]]) - b = torch.tensor([[1,2,3],[4,5,6]]) - helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True) - def test_eye(self): - helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True) - helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True) - - def test_chunk(self): - tor = torch.arange(13).repeat(8, 1).chunk(6, 1) - ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 1) - assert len(tor) == len(ten) - for i in range(len(tor)): - helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) - - tor = torch.arange(13).repeat(8, 1).chunk(6, 0) - ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 0) - assert len(tor) == len(ten) - for i in range(len(tor)): - helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) - - tor = torch.arange(13).repeat(8, 1).chunk(3, -1) - ten = Tensor.arange(13).repeat((8, 1)).chunk(3, -1) - assert len(tor) == len(ten) - for i in range(len(tor)): - helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) - - tor = torch.arange(13).repeat(8, 3, 3).chunk(3, -2) - ten = Tensor.arange(13).repeat((8, 3, 3)).chunk(3, -2) - assert len(tor) == len(ten) - for i in range(len(tor)): - helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) - - def test_arange(self): - helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True) - helper_test_op([], lambda: torch.arange(5, 10, 3), lambda: Tensor.arange(5, 10, 3), forward_only=True) - helper_test_op([], lambda: torch.arange(10, 5, -3), lambda: Tensor.arange(10, 5, -3), forward_only=True) - helper_test_op([], lambda: torch.arange(11, 5, -3), lambda: Tensor.arange(11, 5, -3), forward_only=True) - def test_arange_simple(self): - helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True) - def test_arange_big(self): - helper_test_op([], lambda: torch.arange(256), lambda: Tensor.arange(256), forward_only=True) - - def test_sum_fake(self): - helper_test_op([(256, 1)], lambda x: x.sum(axis=1)) - - def test_sum_collapse(self): - helper_test_op([], lambda: torch.ones(256,256).sum(axis=1), lambda: Tensor.ones(256,256).sum(axis=1), forward_only=True) - - def test_sum_collapse_neg(self): - helper_test_op([], lambda: (-torch.ones(3,3)).sum(axis=1), lambda: (-Tensor.ones(3,3)).sum(axis=1), forward_only=True) - - def test_sum_pad_collapse(self): - helper_test_op([], lambda: torch.nn.functional.pad(torch.ones(256,256), pad=(0,64,0,0)).sum(axis=1), lambda: Tensor.ones(256,256).pad(((0,0), (0,64))).sum(axis=1), forward_only=True) - - # this is more complex and won't fold for a while - def test_sum_cat_collapse(self): - helper_test_op([], lambda: torch.cat([torch.ones(256,256), torch.zeros(256,64)], dim=1).sum(axis=1), lambda: Tensor.cat(Tensor.ones(256,256), Tensor.zeros(256,64), dim=1).sum(axis=1), forward_only=True) - - def test_max_dont_collapse(self): - helper_test_op([], lambda: torch.ones(256,256).max(1)[0], lambda: Tensor.ones(256,256).max(1), forward_only=True) - - def test_where(self): - helper_test_op( - [(100,)], - lambda x: torch.where(x > 0.5, 4, 2), - lambda x: (x > 0.5).where(4, 2), forward_only=True) - - for shps in [[(8,),(1,),(1,)], [(10,10),(10,),(10,)], [(100,)]*3, [(10,10)]*3]: - helper_test_op( + def helper_test_exception( + self, shps, - lambda x, a, b: torch.where(x > 0.5, a, b), - lambda x, a, b: (x > 0.5).where(a, b), forward_only=True) + torch_fxn, + tinygrad_fxn, + expected, + exact=False, + vals=None, + a=-0.5, + b=3, + ): + if getenv("CUDACPU"): + self.skipTest("helper_test_exception fails in CUDACPU") + ts, tst = prepare_test_op(a, b, shps, vals) + with self.assertRaises(expected) as torch_cm: + torch_fxn(*ts) + with self.assertRaises(expected) as tinygrad_cm: + tinygrad_fxn(*tst) + if exact: + self.assertEqual(str(torch_cm.exception), str(tinygrad_cm.exception)) + if not CI: + print( + "\ntesting %40r torch/tinygrad exception: %s / %s" + % (shps, torch_cm.exception, tinygrad_cm.exception), + end="", + ) - def test_where_permute(self): - helper_test_op( - [(5, 5)], - lambda x: torch.where(x > 0.5, 4, 2).permute((1, 0)), - lambda x: (x > 0.5).where(4, 2).permute((1, 0)), forward_only=True) + def test_full_like(self): + a = Tensor([[1, 2, 3], [4, 5, 6]]) + b = torch.tensor([[1, 2, 3], [4, 5, 6]]) + helper_test_op( + [], + lambda: torch.full_like(b, 4), + lambda: Tensor.full_like(a, 4), + forward_only=True, + ) - def _test_cmp(self, fxn, reverse=True): - for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]: - helper_test_op(shps, fxn, fxn, forward_only=True) - helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0.,1,2], [2.,1,0]]) - helper_test_op(None, lambda x,y: fxn(x,2), lambda x,y: fxn(x,2), forward_only=True, vals=[[0.,1,2], [2.,1,0]]) - helper_test_op(None, fxn, fxn, forward_only=True, vals=[[True, True, False], [False,True,False]]) - if reverse: helper_test_op(None, lambda x,y: fxn(2,y), lambda x,y: fxn(2,y), forward_only=True, vals=[[0.,1,2], [2.,1,0]]) + def test_full(self): + helper_test_op( + [], + lambda: torch.full((45, 65), 4), + lambda: Tensor.full((45, 65), 4), + forward_only=True, + ) - def test_cmp_eq(self): self._test_cmp(lambda x,y: x==y, reverse=False) - def test_cmp_gt(self): self._test_cmp(lambda x,y: x>y) - def test_cmp_ge(self): self._test_cmp(lambda x,y: x>=y) - def test_cmp_lt(self): self._test_cmp(lambda x,y: x0, "no 1d dot for images") - def test_dot_1d(self): - helper_test_op([(65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - helper_test_op([(65), (65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - helper_test_op([(8,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - helper_test_op([(65), (8,65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - def test_dot(self): - helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - helper_test_op([(8,45,65), (8,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - with self.assertRaises(AssertionError): - a = Tensor(3.14) - a.matmul(a) + def test_eye(self): + helper_test_op( + [], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True + ) + helper_test_op( + [], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True + ) - def test_multinomial(self): - # NOTE: this is random, so it has a very large atol - helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1), lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.) + def test_chunk(self): + tor = torch.arange(13).repeat(8, 1).chunk(6, 1) + ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 1) + assert len(tor) == len(ten) + for i in range(len(tor)): + helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) - def test_small_cumsum(self): - helper_test_op([(10)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) - def test_simple_cumsum(self): - helper_test_op([(1022)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) - def test_cumsum(self): - helper_test_op([(20)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) - helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) - helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1), atol=1e-6) - helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2), atol=1e-6) - helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=-1), lambda x: Tensor.cumsum(x, axis=-1), atol=1e-6) + tor = torch.arange(13).repeat(8, 1).chunk(6, 0) + ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 0) + assert len(tor) == len(ten) + for i in range(len(tor)): + helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) - def test_argmax(self): - self.assertEqual(torch.Tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max - helper_test_op([(10,20)], lambda x: x.argmax(), lambda x: x.argmax(), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmax(0, False), lambda x: x.argmax(0, False), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmax(1, False), lambda x: x.argmax(1, False), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmax(1, True), lambda x: x.argmax(1, True), forward_only=True) - def test_argmin(self): - self.assertEqual(torch.Tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy()) - helper_test_op([(10,20)], lambda x: x.argmin(), lambda x: x.argmin(), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmin(0, False), lambda x: x.argmin(0, False), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmin(1, False), lambda x: x.argmin(1, False), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmin(1, True), lambda x: x.argmin(1, True), forward_only=True) + tor = torch.arange(13).repeat(8, 1).chunk(3, -1) + ten = Tensor.arange(13).repeat((8, 1)).chunk(3, -1) + assert len(tor) == len(ten) + for i in range(len(tor)): + helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) - def test_matmul_simple(self): - helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - def test_matmul(self): - helper_test_op([(64), (64,99)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) + tor = torch.arange(13).repeat(8, 3, 3).chunk(3, -2) + ten = Tensor.arange(13).repeat((8, 3, 3)).chunk(3, -2) + assert len(tor) == len(ten) + for i in range(len(tor)): + helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) - @unittest.skipIf(IMAGE>0, "no batched matmul on images") - def test_matmul_batched(self): - helper_test_op([(3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) + def test_arange(self): + helper_test_op( + [], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True + ) + helper_test_op( + [], + lambda: torch.arange(5, 10, 3), + lambda: Tensor.arange(5, 10, 3), + forward_only=True, + ) + helper_test_op( + [], + lambda: torch.arange(10, 5, -3), + lambda: Tensor.arange(10, 5, -3), + forward_only=True, + ) + helper_test_op( + [], + lambda: torch.arange(11, 5, -3), + lambda: Tensor.arange(11, 5, -3), + forward_only=True, + ) - @unittest.skipIf(IMAGE>0, "no batched matmul on images") - def test_matmul_batched_vector(self): - helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - def test_small_gemm(self): - helper_test_op([(8,8), (8,8)], lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3) - def test_small_gemm_eye(self): - helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)]) - def test_gemm(self): - helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3) - def test_big_gemm(self): - helper_test_op([(256,256), (256,256)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3) - def test_broadcastdot(self): - helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) - with self.assertRaises(AssertionError): - a = Tensor(3.14) - b = Tensor.ones(3,3) - a @ b - def test_multidot(self): - helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) - helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) - def test_sum_simple(self): - helper_test_op(None, lambda x: x.sum(), Tensor.sum, vals=[[1.,1.]]) - def test_sum_full(self): - helper_test_op([(16384)], lambda x: x.sum(), lambda x: x.sum()) - def test_sum_small_full(self): - helper_test_op([(45,5)], lambda x: x.sum(), Tensor.sum) - def test_sum_relu(self): - helper_test_op([(3,4,5)], lambda x: x.relu().sum().relu(), lambda x: x.relu().sum().relu()) - def test_sum(self): - helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=3), lambda x: Tensor.sum(x, axis=3)) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,3)), lambda x: Tensor.sum(x, axis=(1,3))) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2))) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2))) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1)) - helper_test_op([()], lambda x: x.sum(), Tensor.sum) - def test_min(self): - helper_test_op([(3,3)], lambda x: x.min(), Tensor.min) - helper_test_op([(45,3)], lambda x: x.min(), Tensor.min) - helper_test_op([(45,3)], lambda x: x.min().mul(0.5), lambda x: Tensor.min(x).mul(0.5)) - helper_test_op([()], lambda x: x.min(), Tensor.min) - def test_max(self): - helper_test_op([(45,3)], lambda x: x.max(), Tensor.max) - helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5)) - helper_test_op(None, lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), + def test_arange_simple(self): + helper_test_op( + [], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True + ) + + def test_arange_big(self): + helper_test_op( + [], lambda: torch.arange(256), lambda: Tensor.arange(256), forward_only=True + ) + + def test_sum_fake(self): + helper_test_op([(256, 1)], lambda x: x.sum(axis=1)) + + def test_sum_collapse(self): + helper_test_op( + [], + lambda: torch.ones(256, 256).sum(axis=1), + lambda: Tensor.ones(256, 256).sum(axis=1), + forward_only=True, + ) + + def test_sum_collapse_neg(self): + helper_test_op( + [], + lambda: (-torch.ones(3, 3)).sum(axis=1), + lambda: (-Tensor.ones(3, 3)).sum(axis=1), + forward_only=True, + ) + + def test_sum_pad_collapse(self): + helper_test_op( + [], + lambda: torch.nn.functional.pad( + torch.ones(256, 256), pad=(0, 64, 0, 0) + ).sum(axis=1), + lambda: Tensor.ones(256, 256).pad(((0, 0), (0, 64))).sum(axis=1), + forward_only=True, + ) + + # this is more complex and won't fold for a while + def test_sum_cat_collapse(self): + helper_test_op( + [], + lambda: torch.cat([torch.ones(256, 256), torch.zeros(256, 64)], dim=1).sum( + axis=1 + ), + lambda: Tensor.cat(Tensor.ones(256, 256), Tensor.zeros(256, 64), dim=1).sum( + axis=1 + ), + forward_only=True, + ) + + def test_max_dont_collapse(self): + helper_test_op( + [], + lambda: torch.ones(256, 256).max(1)[0], + lambda: Tensor.ones(256, 256).max(1), + forward_only=True, + ) + + def test_where(self): + helper_test_op( + [(100,)], + lambda x: torch.where(x > 0.5, 4, 2), + lambda x: (x > 0.5).where(4, 2), + forward_only=True, + ) + + for shps in [ + [(8,), (1,), (1,)], + [(10, 10), (10,), (10,)], + [(100,)] * 3, + [(10, 10)] * 3, + ]: + helper_test_op( + shps, + lambda x, a, b: torch.where(x > 0.5, a, b), + lambda x, a, b: (x > 0.5).where(a, b), + forward_only=True, + ) + + def test_where_permute(self): + helper_test_op( + [(5, 5)], + lambda x: torch.where(x > 0.5, 4, 2).permute((1, 0)), + lambda x: (x > 0.5).where(4, 2).permute((1, 0)), + forward_only=True, + ) + + def _test_cmp(self, fxn, reverse=True): + for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]: + helper_test_op(shps, fxn, fxn, forward_only=True) + helper_test_op( + None, fxn, fxn, forward_only=True, vals=[[0.0, 1, 2], [2.0, 1, 0]] + ) + helper_test_op( + None, + lambda x, y: fxn(x, 2), + lambda x, y: fxn(x, 2), + forward_only=True, + vals=[[0.0, 1, 2], [2.0, 1, 0]], + ) + helper_test_op( + None, + fxn, + fxn, + forward_only=True, + vals=[[True, True, False], [False, True, False]], + ) + if reverse: + helper_test_op( + None, + lambda x, y: fxn(2, y), + lambda x, y: fxn(2, y), + forward_only=True, + vals=[[0.0, 1, 2], [2.0, 1, 0]], + ) + + def test_cmp_eq(self): + self._test_cmp(lambda x, y: x == y, reverse=False) + + def test_cmp_gt(self): + self._test_cmp(lambda x, y: x > y) + + def test_cmp_ge(self): + self._test_cmp(lambda x, y: x >= y) + + def test_cmp_lt(self): + self._test_cmp(lambda x, y: x < y) + + def test_cmp_le(self): + self._test_cmp(lambda x, y: x <= y) + + def test_cmp_eq_backwards(self): + t1 = torch.ones(4, requires_grad=True) + t2 = torch.ones(4, requires_grad=True) + self.assertRaises(RuntimeError, (t1 == t2).sum().backward) + tt1 = Tensor.ones(4, requires_grad=True) + tt2 = Tensor.ones(4, requires_grad=True) + self.assertRaises(RuntimeError, (tt1 == tt2).sum().backward) + + def test_cmp_lt_backwards(self): + t1 = torch.ones(4, requires_grad=True) + t2 = torch.ones(4, requires_grad=True) + self.assertRaises(RuntimeError, (t1 < t2).sum().backward) + tt1 = Tensor.ones(4, requires_grad=True) + tt2 = Tensor.ones(4, requires_grad=True) + self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward) + + # @unittest.skip("this is broken with contiguous") + def test_trunc(self): + helper_test_op( + [(45, 65)], lambda x: torch.trunc(x), lambda x: x.trunc(), forward_only=True + ) + a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor( + [1.0, 2.1, 0.0, -5.0, -2.5] + ) + helper_test_op( + [], lambda: torch.trunc(b), lambda: Tensor.trunc(a), forward_only=True + ) + + # @unittest.skip("this is broken with contiguous") + def test_floor(self): + helper_test_op( + [(45, 65)], lambda x: torch.floor(x), lambda x: x.floor(), forward_only=True + ) + a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor( + [1.0, 2.1, 0.0, -5.0, -2.5] + ) + helper_test_op( + [], lambda: torch.floor(b), lambda: Tensor.floor(a), forward_only=True + ) + + # @unittest.skip("this is broken with contiguous") + def test_ceil(self): + helper_test_op( + [(45, 65)], lambda x: torch.ceil(x), lambda x: x.ceil(), forward_only=True + ) + a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor( + [1.0, 2.1, 0.0, -5.0, -2.5] + ) + helper_test_op( + [], lambda: torch.ceil(b), lambda: Tensor.ceil(a), forward_only=True + ) + + def test_tril(self): + helper_test_op([(3, 3)], lambda x: x.tril(), lambda x: x.tril()) + helper_test_op([(3, 3)], lambda x: x.tril(1), lambda x: x.tril(1)) + helper_test_op([(3, 3)], lambda x: x.tril(-1), lambda x: x.tril(-1)) + helper_test_op([(5, 3, 3)], lambda x: x.tril(), lambda x: x.tril()) + helper_test_op([(5, 3, 3)], lambda x: x.tril(1), lambda x: x.tril(1)) + + def test_triu(self): + helper_test_op([(3, 3)], lambda x: x.triu(), lambda x: x.triu()) + helper_test_op([(3, 3)], lambda x: x.triu(1), lambda x: x.triu(1)) + helper_test_op([(3, 3)], lambda x: x.triu(-1), lambda x: x.triu(-1)) + helper_test_op([(5, 3, 3)], lambda x: x.triu(), lambda x: x.triu()) + helper_test_op([(5, 3, 3)], lambda x: x.triu(1), lambda x: x.triu(1)) + + def test_maximum(self): + helper_test_op([(45, 65), (45, 65)], torch.maximum, Tensor.maximum) + helper_test_op([(), ()], torch.maximum, Tensor.maximum) + helper_test_op( + None, + torch.maximum, + Tensor.maximum, + vals=[[1.0, 0.0, 3.0, 4.0], [1.0, 2.0, 3.0, 0.0]], + ) + helper_test_op( + None, + torch.maximum, + Tensor.maximum, + vals=[[1, 0, 3, 4], [1, 2, 3, 0]], + forward_only=True, + ) + + def test_minimum(self): + helper_test_op([(45, 65), (45, 65)], torch.minimum, Tensor.minimum) + helper_test_op([(), ()], torch.minimum, Tensor.minimum) + + def test_add(self): + helper_test_op([(45, 68), (45, 68)], lambda x, y: x + y, Tensor.add) + + def test_add_number(self): + helper_test_op([(), ()], lambda x, y: x + y, Tensor.add) + + def test_add3(self): + helper_test_op([(45, 65), (45, 65), (45, 65)], lambda x, y, z: x + y + z) + + def test_add_simple(self): + helper_test_op( + [(256), (256)], lambda x, y: x + y, Tensor.add, forward_only=True + ) + + def test_broadcasted_add(self): + helper_test_op([(45, 65), (45, 1)], lambda x, y: x + y, lambda x, y: x + y) + helper_test_op([(45, 65), ()], lambda x, y: x + y, lambda x, y: x + y) + + def test_broadcasted_add_2(self): + helper_test_op([(45, 65), (65,)], lambda x, y: x + y, lambda x, y: x + y) + + def test_sub(self): + helper_test_op([(45, 65), (45, 65)], lambda x, y: x - y, Tensor.sub) + helper_test_op([(), ()], lambda x, y: x - y, Tensor.sub) + + def test_neg(self): + helper_test_op([(45, 65)], lambda x: -x) + helper_test_op([()], lambda x: -x) + + def test_mul(self): + helper_test_op([(64, 64), (64, 64)], lambda x, y: x * y, Tensor.mul) + + def test_mul_number(self): + helper_test_op([(), ()], lambda x, y: x * y, Tensor.mul) + + def test_mul_const(self): + helper_test_op([(45, 65)], lambda x: x * 2, lambda x: x * 2) + helper_test_op([(45, 65)], lambda x: x * -1, lambda x: x * -1) + helper_test_op([(45, 65)], lambda x: 255 * x, lambda x: 255 * x) + + def test_div(self): + helper_test_op([(45, 65), (45, 65)], lambda x, y: x / y, Tensor.div) + helper_test_op([(), ()], lambda x, y: x / y, Tensor.div) + helper_test_op( + None, lambda x, y: x / y, Tensor.div, forward_only=True, vals=[[5], [1]] + ) + + def test_div_int(self): + helper_test_op( + None, + lambda x: (x / 2).to(torch.int), + lambda x: x / 2, + forward_only=True, + vals=[[3]], + ) + + def test_div_const(self): + helper_test_op([(45, 65)], lambda x: x / 255, lambda x: x / 255) + helper_test_op([(45, 65)], lambda x: x / 1, lambda x: x / 1) + helper_test_op([(45, 65)], lambda x: 1 / x, lambda x: 1 / x) + helper_test_op([(45, 65)], lambda x: x / 2, lambda x: x / 2) + helper_test_op([(45, 65)], lambda x: 2 / x, lambda x: 2 / x) + helper_test_op([()], lambda x: x / 2, lambda x: x / 2) + helper_test_op([()], lambda x: 2 / x, lambda x: 2 / x) + + @unittest.skipIf( + Device.DEFAULT in ["METAL", "WEBGPU"], + "WEBGPU does not have support for inf/nan, METAL has issues with -inf", + ) + def test_mul_const_naninf(self): + helper_test_op( + [(45, 65)], lambda x: x * float("inf"), lambda x: x * float("inf") + ) + helper_test_op( + [(45, 65)], lambda x: x * -float("inf"), lambda x: x * -float("inf") + ) + helper_test_op( + [(45, 65)], lambda x: x * float("nan"), lambda x: x * float("nan") + ) + + @unittest.skipIf( + Device.DEFAULT in ["METAL", "WEBGPU"], + "WEBGPU does not have support for inf/nan, METAL has issues with -inf", + ) + def test_div_const_naninf(self): + helper_test_op( + [(45, 65)], lambda x: x / float("inf"), lambda x: x / float("inf") + ) + helper_test_op( + [(45, 65)], lambda x: x / -float("inf"), lambda x: x / -float("inf") + ) + helper_test_op( + [(45, 65)], lambda x: x / float("nan"), lambda x: x / float("nan") + ) + helper_test_op( + [(45, 65)], lambda x: float("inf") / x, lambda x: float("inf") / x + ) + helper_test_op( + [(45, 65)], lambda x: (-float("inf")) / x, lambda x: (-float("inf")) / x + ) + helper_test_op( + [(45, 65)], lambda x: float("nan") / x, lambda x: float("nan") / x + ) + + def test_pow_full(self): + helper_test_op([(45, 65), (45, 65)], lambda x, y: x**y, Tensor.pow, a=0) + + def test_pow(self): + # TODO: why is a=0 for these tests? + helper_test_op([(45, 65)], lambda x: x**2, lambda x: Tensor.pow(x, 2), a=0) + helper_test_op([(45, 65)], lambda x: x**3, lambda x: Tensor.pow(x, 3), a=0) + helper_test_op([(45, 65)], lambda x: x**-2, lambda x: Tensor.pow(x, -2), a=0) + helper_test_op([()], lambda x: x**2, lambda x: Tensor.pow(x, 2), a=0) + helper_test_op([()], lambda x: x**-2, lambda x: Tensor.pow(x, -2), a=0) + # Regression tests for https://github.com/tinygrad/tinygrad/issues/1151 + helper_test_op([(45, 65)], lambda x: x**3, lambda x: Tensor.pow(x, 3), a=-10) + helper_test_op([()], lambda x: x**3, lambda x: Tensor.pow(x, 3), a=-10) + # Regression tests for https://github.com/tinygrad/tinygrad/issues/1251 + helper_test_op( + [(45, 65)], lambda x: x**0.2, lambda x: Tensor.pow(x, 0.2), a=-10 + ) + helper_test_op( + [(45, 65)], lambda x: x**1.2, lambda x: Tensor.pow(x, 1.2), a=-10 + ) + helper_test_op([()], lambda x: x**0.2, lambda x: Tensor.pow(x, 0.2), a=-10) + helper_test_op([()], lambda x: x**1.2, lambda x: Tensor.pow(x, 1.2), a=-10) + a, b = Tensor([0.0], requires_grad=True), torch.tensor( + [0.0], requires_grad=True + ) + helper_test_op( + [], + lambda: b**1.1, + lambda: a**1.1, + ) + + def test_pow_const(self): + helper_test_op([(45, 65)], lambda x: x**1.0, lambda x: x**1.0) + helper_test_op([(45, 65)], lambda x: x**-1.0, lambda x: x**-1.0) + helper_test_op([(45, 65)], lambda x: 1.0**x, lambda x: 1.0**x) + helper_test_op([(45, 65)], lambda x: x**2.0, lambda x: x**2.0) + helper_test_op([(45, 65)], lambda x: 2.0**x, lambda x: 2.0**x) + helper_test_op([()], lambda x: x**2.0, lambda x: x**2.0) + helper_test_op([()], lambda x: 2.0**x, lambda x: 2.0**x) + + def test_sqrt(self): + helper_test_op([(45, 65)], lambda x: x.sqrt(), Tensor.sqrt, a=0) + helper_test_op([()], lambda x: x.sqrt(), Tensor.sqrt, a=0) + + def test_rsqrt(self): + helper_test_op([(45, 65)], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0) + helper_test_op([()], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0) + + def test_sin(self): + helper_test_op([(45, 65)], lambda x: x.sin(), Tensor.sin, a=0) + + def test_cos(self): + helper_test_op([(45, 65)], lambda x: x.cos(), Tensor.cos, a=0) + + def test_tan(self): + helper_test_op([(45, 65)], lambda x: x.tan(), Tensor.tan, a=0) + + def test_relu(self): + helper_test_op([(64, 64)], lambda x: x.relu(), Tensor.relu) + helper_test_op([()], lambda x: x.relu(), Tensor.relu) + + def test_relu_exact(self): + helper_test_op(None, lambda x: x.relu(), Tensor.relu, vals=[[-1.0, 0, 1]]) + + def test_relu_maximum_exact(self): + helper_test_op( + None, + lambda x: torch.maximum(x, torch.zeros_like(x, requires_grad=False)), + lambda x: Tensor.maximum(x, 0), + vals=[[-1.0, 0, 1]], + ) + + def test_leakyrelu(self): + helper_test_op( + [(45, 65)], + lambda x: torch.nn.functional.leaky_relu(x, 0.01), + Tensor.leakyrelu, + ) + helper_test_op( + [()], lambda x: torch.nn.functional.leaky_relu(x, 0.01), Tensor.leakyrelu + ) + + def test_celu(self): + for val in range(1, 5): + helper_test_op( + [(45, 65)], + lambda x: torch.nn.functional.celu(x, val), + lambda x: x.celu(val), + ) + helper_test_op( + [()], lambda x: torch.nn.functional.celu(x, val), lambda x: x.celu(val) + ) + + def test_abs(self): + helper_test_op([(45, 65)], lambda x: torch.abs(x), Tensor.abs) + helper_test_op([()], lambda x: torch.abs(x), Tensor.abs) + + def test_log(self): + helper_test_op([(45, 65)], lambda x: torch.log(x), Tensor.log) + helper_test_op([()], lambda x: torch.log(x), Tensor.log) + + def test_log2(self): + helper_test_op([(45, 65)], lambda x: torch.log2(x), Tensor.log2) + helper_test_op([()], lambda x: torch.log2(x), Tensor.log2) + + def test_exp(self): + helper_test_op([(45, 65)], lambda x: torch.exp(x), Tensor.exp) + helper_test_op([()], lambda x: torch.exp(x), Tensor.exp) + + def test_exp2(self): + helper_test_op([(45, 65)], lambda x: torch.exp2(x), Tensor.exp2) + helper_test_op([()], lambda x: torch.exp2(x), Tensor.exp2) + + def test_sign(self): + helper_test_op([(45, 65)], lambda x: torch.sign(x), Tensor.sign) + helper_test_op([()], lambda x: torch.sign(x), Tensor.sign) + + def test_softsign(self): + helper_test_op( + [(45, 65)], lambda x: torch.nn.functional.softsign(x), Tensor.softsign + ) + helper_test_op([()], lambda x: torch.nn.functional.softsign(x), Tensor.softsign) + + def test_sigmoid(self): + helper_test_op([(45, 65)], lambda x: x.sigmoid(), Tensor.sigmoid) + helper_test_op([(45, 65)], lambda x: x.sigmoid(), Tensor.sigmoid, a=100) + helper_test_op([(45, 65)], lambda x: x.sigmoid(), Tensor.sigmoid, a=-100) + helper_test_op([()], lambda x: x.sigmoid(), Tensor.sigmoid, forward_only=True) + + def test_softplus(self): + helper_test_op( + [(45, 65)], + lambda x: torch.nn.functional.softplus(x), + Tensor.softplus, + atol=1e-6, + grad_atol=1e-6, + ) + helper_test_op( + [()], + lambda x: torch.nn.functional.softplus(x), + Tensor.softplus, + atol=1e-6, + grad_atol=1e-6, + ) + + def test_gelu(self): + helper_test_op( + [(45, 65)], + lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + Tensor.gelu, + ) + # helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=100) + helper_test_op( + [(45, 65)], + lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + Tensor.gelu, + a=-100, + ) + + def test_quick_gelu(self): + helper_test_op( + [(45, 65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu + ) + helper_test_op( + [(45, 65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, a=100 + ) + helper_test_op( + [(45, 65)], + lambda x: x * torch.sigmoid(1.702 * x), + Tensor.quick_gelu, + a=-100, + ) + helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu) + + def test_elu(self): + helper_test_op([(45, 65)], lambda x: torch.nn.functional.elu(x), Tensor.elu) + helper_test_op( + [(45, 65)], + lambda x: torch.nn.functional.elu(x, alpha=0.1), + lambda x: Tensor.elu(x, alpha=0.1), + ) + helper_test_op([()], lambda x: torch.nn.functional.elu(x), Tensor.elu) + + def test_relu6(self): + helper_test_op([(45, 65)], lambda x: torch.nn.functional.relu6(x), Tensor.relu6) + helper_test_op([()], lambda x: torch.nn.functional.relu6(x), Tensor.relu6) + + def test_hardswish(self): + helper_test_op( + [(45, 65)], + lambda x: torch.nn.functional.hardswish(x), + Tensor.hardswish, + atol=1e-6, + grad_atol=1e-6, + ) + helper_test_op( + [()], + lambda x: torch.nn.functional.hardswish(x), + Tensor.hardswish, + atol=1e-6, + grad_atol=1e-6, + ) + + def test_mish(self): + def _mish_pytorch(x): + return x * torch.tanh(torch.nn.functional.softplus(x)) + + helper_test_op([(45, 65)], _mish_pytorch, Tensor.mish, atol=1e-4) + helper_test_op([()], _mish_pytorch, Tensor.mish, atol=1e-4) + + @unittest.skipIf(IMAGE > 0, "no 1d dot for images") + def test_dot_1d(self): + helper_test_op([(65), (65)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-4) + helper_test_op( + [(65), (65, 45)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-4 + ) + helper_test_op( + [(45, 65), (65)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-4 + ) + helper_test_op( + [(8, 45, 65), (65)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-4 + ) + helper_test_op( + [(65), (8, 65, 45)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-4 + ) + self.helper_test_exception( + [(4), (1, 2)], + lambda x, y: x.matmul(y), + Tensor.dot, + expected=(RuntimeError, AssertionError), + ) + self.helper_test_exception( + [(2, 1), (4)], + lambda x, y: x.matmul(y), + Tensor.dot, + expected=(RuntimeError, AssertionError), + ) + self.helper_test_exception( + [(1), (4)], + lambda x, y: x.matmul(y), + Tensor.dot, + expected=(RuntimeError, AssertionError), + ) + + def test_dot(self): + helper_test_op( + [(45, 65), (65, 100)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-4 + ) + helper_test_op( + [(8, 45, 65), (8, 65, 100)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-4 + ) + self.helper_test_exception( + [(2, 4), (1, 3)], + lambda x, y: x.matmul(y), + Tensor.dot, + expected=(RuntimeError, AssertionError), + ) + self.helper_test_exception( + [(2, 1), (4, 3)], + lambda x, y: x.matmul(y), + Tensor.dot, + expected=(RuntimeError, AssertionError), + ) + with self.assertRaises(AssertionError): + a = Tensor(3.14) + a.matmul(a) + + def test_multinomial(self): + # NOTE: this is random, so it has a very large atol + helper_test_op( + [(1000,)], + lambda x: torch.multinomial(x.clip(0, 1), num_samples=1), + lambda x: Tensor.multinomial(x.clip(0, 1)), + forward_only=True, + atol=1000.0, + ) + + def test_small_cumsum(self): + helper_test_op( + [(10)], + lambda x: torch.cumsum(x, dim=0), + lambda x: Tensor.cumsum(x, axis=0), + atol=1e-6, + ) + + def test_simple_cumsum(self): + helper_test_op( + [(1022)], + lambda x: torch.cumsum(x, dim=0), + lambda x: Tensor.cumsum(x, axis=0), + atol=1e-6, + ) + + def test_cumsum(self): + helper_test_op( + [(20)], + lambda x: torch.cumsum(x, dim=0), + lambda x: Tensor.cumsum(x, axis=0), + atol=1e-6, + ) + helper_test_op( + [(20, 30)], + lambda x: torch.cumsum(x, dim=0), + lambda x: Tensor.cumsum(x, axis=0), + atol=1e-6, + ) + helper_test_op( + [(20, 30)], + lambda x: torch.cumsum(x, dim=1), + lambda x: Tensor.cumsum(x, axis=1), + atol=1e-6, + ) + helper_test_op( + [(20, 30, 40)], + lambda x: torch.cumsum(x, dim=2), + lambda x: Tensor.cumsum(x, axis=2), + atol=1e-6, + ) + helper_test_op( + [(20, 30, 40)], + lambda x: torch.cumsum(x, dim=-1), + lambda x: Tensor.cumsum(x, axis=-1), + atol=1e-6, + ) + + def test_argmax(self): + self.assertEqual( + torch.Tensor([2, 2]).argmax().numpy(), Tensor([2, 2]).argmax().numpy() + ) # check if returns first index for same max + helper_test_op( + [(10, 20)], lambda x: x.argmax(), lambda x: x.argmax(), forward_only=True + ) + helper_test_op( + [(10, 20)], + lambda x: x.argmax(0, False), + lambda x: x.argmax(0, False), + forward_only=True, + ) + helper_test_op( + [(10, 20)], + lambda x: x.argmax(1, False), + lambda x: x.argmax(1, False), + forward_only=True, + ) + helper_test_op( + [(10, 20)], + lambda x: x.argmax(1, True), + lambda x: x.argmax(1, True), + forward_only=True, + ) + + def test_argmin(self): + self.assertEqual( + torch.Tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy() + ) + helper_test_op( + [(10, 20)], lambda x: x.argmin(), lambda x: x.argmin(), forward_only=True + ) + helper_test_op( + [(10, 20)], + lambda x: x.argmin(0, False), + lambda x: x.argmin(0, False), + forward_only=True, + ) + helper_test_op( + [(10, 20)], + lambda x: x.argmin(1, False), + lambda x: x.argmin(1, False), + forward_only=True, + ) + helper_test_op( + [(10, 20)], + lambda x: x.argmin(1, True), + lambda x: x.argmin(1, True), + forward_only=True, + ) + + def test_matmul_simple(self): + helper_test_op([(4), (4, 4)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-4) + + def test_matmul(self): + helper_test_op( + [(64), (64, 99)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-4 + ) + + @unittest.skipIf(IMAGE > 0, "no batched matmul on images") + def test_matmul_batched(self): + helper_test_op( + [(3), (1, 3, 3, 5)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-4 + ) + + @unittest.skipIf(IMAGE > 0, "no batched matmul on images") + def test_matmul_batched_vector(self): + helper_test_op( + [(4, 3), (1, 3, 3, 5)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-4 + ) + + def test_small_gemm(self): + helper_test_op( + [(8, 8), (8, 8)], lambda x, y: x.matmul(y), lambda x, y: x @ y, atol=1e-3 + ) + + def test_small_gemm_eye(self): + helper_test_op( + None, + lambda x, y: x.matmul(y), + lambda x, y: x @ y, + atol=1e-3, + vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)], + ) + + def test_gemm(self): + helper_test_op( + [(64, 64), (64, 64)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-3 + ) + + def test_big_gemm(self): + helper_test_op( + [(256, 256), (256, 256)], lambda x, y: x.matmul(y), Tensor.dot, atol=1e-3 + ) + + def test_broadcastdot(self): + helper_test_op( + [(10, 45, 65), (65, 45)], lambda x, y: x @ y, Tensor.dot, atol=1e-4 + ) + with self.assertRaises(AssertionError): + a = Tensor(3.14) + b = Tensor.ones(3, 3) + a @ b + + def test_multidot(self): + helper_test_op( + [(10, 45, 65), (10, 65, 45)], lambda x, y: x @ y, Tensor.dot, atol=1e-4 + ) + helper_test_op( + [(3, 3, 45, 65), (3, 3, 65, 45)], lambda x, y: x @ y, Tensor.dot, atol=1e-4 + ) + + def test_sum_simple(self): + helper_test_op(None, lambda x: x.sum(), Tensor.sum, vals=[[1.0, 1.0]]) + + def test_sum_full(self): + helper_test_op([(16384)], lambda x: x.sum(), lambda x: x.sum()) + + def test_sum_small_full(self): + helper_test_op([(45, 5)], lambda x: x.sum(), Tensor.sum) + + def test_sum_relu(self): + helper_test_op( + [(3, 4, 5)], + lambda x: x.relu().sum().relu(), + lambda x: x.relu().sum().relu(), + ) + + def test_sum(self): + helper_test_op([(45, 3)], lambda x: x.sum(), Tensor.sum) + helper_test_op( + [(3, 4, 5, 6)], lambda x: x.sum(axis=3), lambda x: Tensor.sum(x, axis=3) + ) + helper_test_op( + [(3, 4, 5, 6)], + lambda x: x.sum(axis=(1, 3)), + lambda x: Tensor.sum(x, axis=(1, 3)), + ) + helper_test_op( + [(3, 4, 5, 6)], + lambda x: x.sum(axis=(0, 2)), + lambda x: Tensor.sum(x, axis=(0, 2)), + ) + helper_test_op( + [(3, 4, 5, 6)], + lambda x: x.sum(axis=(1, 2)), + lambda x: Tensor.sum(x, axis=(1, 2)), + ) + helper_test_op( + [(3, 4, 5, 6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1) + ) + helper_test_op([()], lambda x: x.sum(), Tensor.sum) + + def test_min(self): + helper_test_op([(3, 3)], lambda x: x.min(), Tensor.min) + helper_test_op([(45, 3)], lambda x: x.min(), Tensor.min) + helper_test_op( + [(45, 3)], lambda x: x.min().mul(0.5), lambda x: Tensor.min(x).mul(0.5) + ) + helper_test_op([()], lambda x: x.min(), Tensor.min) + + def test_max(self): + helper_test_op([(45, 3)], lambda x: x.max(), Tensor.max) + helper_test_op( + [(45, 3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5) + ) + helper_test_op( + None, + lambda x: x.max().mul(0.5), + lambda x: Tensor.max(x).mul(0.5), vals=[ - [[1.0,1.0,0.0,1.0]], - ]) - helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1)) - helper_test_op([()], lambda x: x.max(), Tensor.max) - def test_mean(self): - helper_test_op([(3,4,5,6)], lambda x: x.mean()) - helper_test_op([()], lambda x: x.mean()) - def test_mean_axis(self): - helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2))) - def test_std(self): - helper_test_op([(45, 65, 85)], lambda x: torch.std(x), lambda x: Tensor.std(x)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None, correction=0), lambda x: Tensor.std(x, correction=0)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None, correction=5), lambda x: Tensor.std(x, correction=5)) - def test_std_axis(self): - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=0), lambda x: Tensor.std(x, axis=0)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=2), lambda x: Tensor.std(x, axis=2)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2])) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None), lambda x: Tensor.std(x, axis=None)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=0), lambda x: Tensor.std(x, axis=0, correction=0)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=2), lambda x: Tensor.std(x, axis=2, correction=0)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2], correction=0)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=None), lambda x: Tensor.std(x, axis=None, correction=0)) - def test_std_keepdim(self): - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None, keepdim=True), lambda x: Tensor.std(x, keepdim=True)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=0, keepdim=True, correction=0), lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0)) - def test_log_softmax(self): - helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) - helper_test_op([()], lambda x: torch.nn.LogSoftmax(dim=0)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) - def test_log_softmax_other_axis(self): - helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7) - helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7) - helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7) - def test_tanh(self): - helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6) - helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, a=-100) - helper_test_op([()], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6) - def test_hardtanh(self): - for val in range(10, 30, 5): - helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6) - helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6) - def test_topo_sort(self): - helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6) - helper_test_op([()], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6) - - def test_scalar_mul(self): - helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2) - helper_test_op([()], lambda x: x*2, lambda x: x*2) - def test_scalar_rmul(self): - helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x) - helper_test_op([()], lambda x: 2*x, lambda x: 2*x) - def test_scalar_sub(self): - helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2) - helper_test_op([()], lambda x: x-2, lambda x: x-2) - def test_scalar_rsub(self): - helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x) - helper_test_op([()], lambda x: 2-x, lambda x: 2-x) - def test_flip_eye_crash(self): - helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)), - lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True) - - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "this test uses more than 8 bufs passing the WEBGPU limit") #TODO: remove after #1461 - def test_broadcast_full(self): - for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul), - (torch.div, Tensor.div)]: #, (torch.pow, Tensor.pow)]: - for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]: - with self.subTest(op=torch_op.__name__, shapes=shapes): - helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0) - - def test_broadcast_simple(self): - helper_test_op([(45,65), (45,1)], lambda x,y: x/y, lambda x,y: x/y) - helper_test_op([(45,65), ()], lambda x,y: x/y, lambda x,y: x/y) - - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "this test uses more than 8 bufs passing the WEBGPU limit") #TODO: remove after #1461 - def test_broadcast_partial(self): - for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul), - (torch.div, Tensor.div)]: #, (torch.pow, Tensor.pow)]: - for shapes in [((1,32,32,32), (1,32,1,1)), ((5,13,24,16,2), (1,13,24,1,1)), - ((4,1), (4,5)), ((1,4), (5,4))]: - with self.subTest(op=torch_op.__name__, shapes=shapes): - # NOTE: ANE backwards? - helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0) - - def test_slice_in_bounds_1dim(self): - helper_test_op([(3)], lambda x: x[1:3], lambda x: x[1:3]) - helper_test_op([(3)], lambda x: x[0:2], lambda x: x[0:2]) - helper_test_op([(3)], lambda x: x[-2:2], lambda x: x[-2:2]) - - def test_slice_on_0dim_tensor(self): - helper_test_op([()], lambda x: x[None], lambda x: x[None]) - - with self.assertRaises(IndexError): - a = Tensor(3.14) - a[0] - - def test_slice_int_indexing(self): - helper_test_op([(3)], lambda x: x[1], lambda x: x[1]) - helper_test_op([(3)], lambda x: x[-2], lambda x: x[-2]) - helper_test_op([(10,10)], lambda x: x[1], lambda x: x[1]) - helper_test_op([(3,3,3)], lambda x: x[1,1,1], lambda x: x[1,1,1]) - - def test_slice_in_bounds_multidim(self): - helper_test_op([(3,3,3)], lambda x: x[1:2], lambda x: x[1:2]) - helper_test_op([(3,3,3)], lambda x: x[1:2, 2], lambda x: x[1:2, 2]) - helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2]) - helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1]) - - def test_slice_with_none(self): - helper_test_op([(3,3,3)], lambda x: x[None], lambda x: x[None]) - helper_test_op([(3,3,3)], lambda x: x[1:2, None], lambda x: x[1:2, None]) - helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2], lambda x: x[1:2, None, 1:2]) - helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1], lambda x: x[1:2, 1:2, None, -1]) - - def test_slice_one_endpoint_out_of_bounds(self): - helper_test_op([(3,3,3)], lambda x: x[0:4], lambda x: x[0:4]) - helper_test_op([(3,3,3)], lambda x: x[-6:4], lambda x: x[-6:4]) - helper_test_op([(3,3,3)], lambda x: x[1:50], lambda x: x[1:50]) - helper_test_op([(3,3,3)], lambda x: x[1:50, 1:2, -1], lambda x: x[1:50, 1:2, -1]) - - def test_slice_stride_gt_one(self): - helper_test_op([(7,5,10)], lambda x: x[::2, ::3, ::4], lambda x: x[::2, ::3, ::4]) - helper_test_op([(7,5,10)], lambda x: x[1:5:2, ::3, ::4], lambda x: x[1:5:2, ::3, ::4]) - helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4], lambda x: x[1:5:2, 3, ::4]) - helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, ::4], lambda x: x[1:5:2, None, None, 3, None, ::4]) - - def test_slice_negative_strides(self): - # Torch doesn't support slicing with negative steps - a = np.random.randn(10, 10, 10).astype(np.float32) - t = Tensor(a) - np.testing.assert_allclose(a[::-1], t[::-1].numpy()) - np.testing.assert_allclose(a[::-2], t[::-2].numpy()) - np.testing.assert_allclose(a[:, 2:0:-1], t[:, 2:0:-1].numpy()) - np.testing.assert_allclose(a[:, 2:0:-1, 3:1:-2], t[:, 2:0:-1, 3:1:-2].numpy()) - np.testing.assert_allclose(a[4:0:-3, 2:0:-1, -1:-5:-2], t[4:0:-3, 2:0:-1, -1:-5:-2].numpy()) - if Device.DEFAULT != "CPU": - # broken - np.testing.assert_allclose(a[2:5:-1, :, :], t[2:5:-1, :, :].numpy()) # shape = (0, 10, 10) - np.testing.assert_allclose(a[:, 2:5:-1, :], t[:, 2:5:-1, :].numpy()) # shape = (0, 10, 10) - np.testing.assert_allclose(a[:, :, 2:5:-1], t[:, :, 2:5:-1].numpy()) # shape = (0, 10, 10) - - def test_slice_both_endpoints_out_of_bounds(self): - helper_test_op([(3,3,3)], lambda x: x[5:10], lambda x: x[5:10], forward_only=True) - helper_test_op([(3,3,3)], lambda x: x[-15:-7], lambda x: x[-15:-7], forward_only=True) - - def test_slice_start_gt_end(self): - helper_test_op([(3,3,3)], lambda x: x[-2:2], lambda x: x[-2:2], forward_only=True) - helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True) - - def test_slice_empty(self): - helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True) - - def test_slice_zero_in_shape(self): - helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True) # x.shape = (0, 10) - helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True) # x.shape = (0, 3, 3) - - def test_slice_errors(self): - a = Tensor.ones(4, 3) - with self.assertRaises(IndexError): - a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds) - a[1, 77] # IndexError: (out of bounds). - a[0, -77] - a[..., ...] # IndexError: only single ellipsis - - def test_slice_ellipsis(self): - helper_test_op([(3,3,3,3)], lambda x: x[..., 0], lambda x: x[..., 0]) - helper_test_op([(3,3,3,3)], lambda x: x[0, ...], lambda x: x[0, ...]) - helper_test_op([(3,3,3,3)], lambda x: x[0, ..., 0], lambda x: x[0, ..., 0]) - helper_test_op([(3,3,3,3)], lambda x: x[0:3, ..., 2:3], lambda x: x[0:3, ..., 2:3]) - helper_test_op([(3,3,3,3)], lambda x: x[None, 0:3, ..., 0, None], lambda x: x[None, 0:3, ..., 0, None]) - - def test_pad2d(self): - helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4))) - helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4)), lambda x: x.pad2d(padding=(-1,2,-3,4))) - helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad2d(padding=(1,2,3,4),value=5)) - helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4), value=5), lambda x: x.pad2d(padding=(-1,2,-3,4),value=5)) - - def test_pad(self): - helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)),lambda x: x.pad(((3,4),(1,2)))) - helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(((3,4), (1,2)), value=5)) - helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=float("inf")), lambda x: x.pad(((3,4), (1,2)), value=float("inf"))) - helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=float("-inf")), lambda x: x.pad(((3,4), (1,2)), value=float("-inf"))) - helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,3,4), value=1), lambda x: x.pad(((3,4), None), value=1)) - helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0), value=1), lambda x: x.pad((None, None), value=1)) - - def test_pad_slice(self): - for value in 0., 3.456: - helper_test_op([(1)], lambda x: torch.nn.functional.pad(x,(1,0), value=value)[0], lambda x: x.pad(((1,0),), value=value)[0]) - helper_test_op([(4)], lambda x: torch.nn.functional.pad(x,(1,0), value=value)[0], lambda x: x.pad(((1,0),), value=value)[0]) - helper_test_op([(4)], lambda x: torch.nn.functional.pad(x,(3,0), value=value)[0:1], lambda x: x.pad(((3,0),), value=value)[0:1]) - helper_test_op([(4)], lambda x: torch.nn.functional.pad(x,(0,3), value=value)[6], lambda x: x.pad(((0,3),), value=value)[6]) - helper_test_op([(4)], lambda x: torch.nn.functional.pad(x,(0,3), value=value)[4:6], lambda x: x.pad(((0,3),), value=value)[4:6]) - helper_test_op([(5,5)], lambda x: torch.nn.functional.pad(x,(0,0,1,0), value=value)[0], lambda x: x.pad(((1,0),(0,0)), value=value)[0]) - helper_test_op([(2,2)], lambda x: torch.nn.functional.pad(x,(0,1,0,0), value=value)[0,2], lambda x: x.pad(((0,0),(0,1)), value=value)[0,2]) - helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(0,0,1,0), value=value)[0,2], lambda x: x.pad(((1,0),(0,0)), value=value)[0,2]) - helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(0,0,0,2), value=value)[5], lambda x: x.pad(((0,2),(0,0)), value=value)[5]) - helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(0,0,0,2), value=value)[3:5], lambda x: x.pad(((0,2),(0,0)), value=value)[3:5]) - helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,0,0,0), value=value)[1,0], lambda x: x.pad(((0,0),(3,0)), value=value)[1,0]) - helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,0,0,0), value=value)[1,0:4], lambda x: x.pad(((0,0),(3,0)), value=value)[1,0:4]) - helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,4,1,2), value=value)[0], lambda x: x.pad(((1,2),(3,4)), value=value)[0]) - helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,4,1,2), value=value)[:,1], lambda x: x.pad(((1,2),(3,4)), value=value)[:,1]) - helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,4,1,2), value=value)[:,4], lambda x: x.pad(((1,2),(3,4)), value=value)[:,4]) - helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x,(0,3,0,0), value=value)[:,4:6], lambda x: x.pad(((0,0),(0,3)), value=value)[:,4:6]) - helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x,(0,1,3,2), value=value)[0:2,:], lambda x: x.pad(((3,2),(0,1)), value=value)[0:2,:]) - helper_test_op([(3,3,3)], lambda x: torch.nn.functional.pad(x,(1,1,0,1,3,2), value=value)[0:2,:,:], lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[0:2,:,:]) - helper_test_op([(3,3,3)], lambda x: torch.nn.functional.pad(x,(1,1,0,1,3,2), value=value)[2:4,:,:], lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[2:4,:,:]) - - def test_stack_slice(self): - helper_test_op([(4)], lambda x: torch.stack([x for i in range(3)])[0,:], lambda x: Tensor.stack([x for i in range(3)])[0,:]) - helper_test_op([(5)], lambda x: torch.stack([x for i in range(3)])[0,0], lambda x: Tensor.stack([x for i in range(3)])[0,0]) - helper_test_op([(4,4)], lambda x: torch.stack([x for i in range(4)])[3], lambda x: Tensor.stack([x for i in range(4)])[3]) - - def test_transpose(self): - helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(1,2)) - helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2)) - helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.permute(order=(3,0,2,1))) - helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.permute(order=(3,2,1,0))) - helper_test_op([()], lambda x: x.permute(()), lambda x: x.permute(())) - - def test_reshape(self): - helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6))) - helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6))) - helper_test_op([()], lambda x: torch.reshape(x, []), lambda x: x.reshape([])) - helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([])) - helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1])) - - with self.assertRaises(ValueError): - x = Tensor.ones((4,3,6,6)) - x.reshape([]) - - def test_flip(self): - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,))) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1)), lambda x: x.flip(axis=(0,1))) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)), lambda x: x.flip(axis=(0,1,3))) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,))) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)).flip((0,)), lambda x: x.flip(axis=(0,1,3)).flip(0)) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(-1,))) - helper_test_op([()], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=())) - helper_test_op([(1,)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=())) - helper_test_op([(4, 3, 6, 6)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=())) - - def test_squeeze(self): - helper_test_op([(1,3,6,6)], lambda x: torch.squeeze(x, 0), lambda x: x.squeeze(dim=0)) - helper_test_op([(4,3,1,6)], lambda x: torch.squeeze(x, 1), lambda x: x.squeeze(dim=1)) - helper_test_op([(4,3,6,6)], lambda x: torch.squeeze(x, 3), lambda x: x.squeeze(dim=3)) - self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, 50), lambda x: x.squeeze(dim=50), expected=IndexError, exact=True) - self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, -50), lambda x: x.squeeze(dim=-50), expected=IndexError, exact=True) - helper_test_op([(4,3,6,1)], lambda x: torch.squeeze(x, -1), lambda x: x.squeeze(dim=-1)) - helper_test_op([(4,3,6,6)], lambda x: torch.squeeze(x), lambda x: x.squeeze()) - helper_test_op([(1,3,6,6)], lambda x: torch.squeeze(x), lambda x: x.squeeze()) - helper_test_op([(2,3,1)], lambda x: torch.squeeze(x), lambda x: x.squeeze()) - helper_test_op([()], lambda x: torch.squeeze(x, -1), lambda x: x.squeeze(dim=-1)) - helper_test_op([()], lambda x: torch.squeeze(x, 0), lambda x: x.squeeze(dim=0)) - self.helper_test_exception([()], lambda x: torch.squeeze(x, 10), lambda x: x.squeeze(dim=10), expected=IndexError, exact=True) - helper_test_op([()], lambda x: torch.squeeze(x), lambda x: x.squeeze()) - - def test_unsqueeze(self): - helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0)) - helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 4), lambda x: x.unsqueeze(dim=4)) - helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -1), lambda x: x.unsqueeze(dim=-1)) - helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -3), lambda x: x.unsqueeze(dim=-3)) - helper_test_op([()], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0)) - - def test_flatten(self): - for axis in range(3): - helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=axis), lambda x: x.flatten(axis)) - helper_test_op([()], lambda x: x.flatten(), lambda x: x.flatten()) - helper_test_op([(1,)], lambda x: x.flatten(), lambda x: x.flatten()) - - def test_detach(self): - helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True) - helper_test_op([()], lambda x: x.detach(), lambda x: x.detach(), forward_only=True) - - def test_expand(self): - arg = (4,3,2,6) - helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg)) - helper_test_op([()], lambda x: x.expand([]), lambda x: x.expand(shape=[])) - - @unittest.skip("very slow") - def test_sd_big_conv(self): - # internal shape (1, 1, 512, 62, 62, 512, 3, 3) overflows a int - helper_test_op([(1,256,64,64), (512,256,3,3)], - lambda x,w: torch.nn.functional.conv2d(x, w), - lambda x,w: x.conv2d(w), atol=1e-2) - - @unittest.skip("slow") - def test_large_bs_conv(self): - # large batch size can cause OpenCL image to exceed max image height on macOS - # (or cause the conv kernel to overflow short sampling coords) - helper_test_op([(4096,3,3,3), (1,3,3,3)], - lambda x,w: torch.nn.functional.conv2d(x, w), - lambda x,w: x.conv2d(w), atol=1e-4, rtol=1e-2) - - @unittest.skip("slow") - def test_large_ic_conv(self): - # large input channel count can cause OpenCL image to exceed max image width on macOS - helper_test_op([(1,2048,3,3), (1,2048,3,3)], - lambda x,w: torch.nn.functional.conv2d(x, w), - lambda x,w: x.conv2d(w), atol=1e-4) - - def test_biased_conv2d(self): - C = 8 - helper_test_op([(1,C,5,5), (C,C,1,1), (C,)], - lambda x,w,b: torch.nn.functional.conv2d(torch.nn.functional.conv2d(x,w,b).relu(),w,b), - lambda x,w,b: Tensor.conv2d(x,w,b).relu().conv2d(w,b), atol=1e-4) - - def test_simple_conv2d(self): - helper_test_op([(1,4,9,9), (4,4,3,3)], - lambda x,w: torch.nn.functional.conv2d(x,w).relu(), - lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5) - - @unittest.skipIf(IMAGE>0, "no conv3d on images") - def test_simple_conv3d(self): - helper_test_op([(1,4,9,9,9), (4,4,3,3,3)], - lambda x,w: torch.nn.functional.conv3d(x,w).relu(), - lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5) - - @unittest.skipIf(IMAGE>0, "no conv3d on images") - def test_padded_conv3d(self): - helper_test_op([(1,4,9,9,9), (4,4,3,3,3)], - lambda x,w: torch.nn.functional.conv3d(x,w,padding=1).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=[1,1,1,1,1,1]).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_simple_conv2d_m4(self): - helper_test_op([(1,16,18,18), (16,16,3,3)], - lambda x,w: torch.nn.functional.conv2d(x,w).relu(), - lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_simple_conv2d_1x1(self): - helper_test_op([(1,4,9,9), (4,4,1,1)], - lambda x,w: torch.nn.functional.conv2d(x,w).relu(), - lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_simple_conv2d_1x1_m4(self): - helper_test_op([(1,16,32,32), (16,16,1,1)], - lambda x,w: torch.nn.functional.conv2d(x,w).relu(), - lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_nested_conv2d(self): - helper_test_op([(1,32,9,9), (32,32,3,3), (32,32,3,3)], - lambda x,w1,w2: torch.nn.functional.conv2d(torch.nn.functional.conv2d(x,w1).relu(), w2).relu(), - lambda x,w1,w2: x.conv2d(w1).relu().conv2d(w2).relu(), atol=1e-4, grad_rtol=1e-5) - - # expect reduce nodes == 3 - def test_simple_conv2d_nhwc(self): - # weights (from tf): filter_height x filter_width x in_channels x out_channels - helper_test_op([(2,9,9,10), (3,3,10,20)], - lambda x,w: torch.nn.functional.conv2d(x.permute(0,3,1,2),w.permute(3,2,0,1)).relu(), - lambda x,w: Tensor.conv2d(x.permute(0,3,1,2),w.permute(3,2,0,1)).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_simple_conv2d_batched(self): - helper_test_op([(2,4,9,9), (4,4,3,3)], - lambda x,w: torch.nn.functional.conv2d(x,w).relu(), - lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5) - - # conv transpose - - def test_simple_conv_transpose2d(self): - helper_test_op([(2,4,9,9), (4,4,3,3)], - lambda x,w: torch.nn.functional.conv_transpose2d(x,w).relu(), - lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_bias_conv_transpose2d(self): - helper_test_op([(2,4,9,9), (4,4,3,3), (4,)], - lambda x,w,b: torch.nn.functional.conv_transpose2d(x,w,b).relu(), - lambda x,w,b: Tensor.conv_transpose2d(x,w,b).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_grouped_conv_transpose2d(self): - helper_test_op([(2,4,9,9), (4,4,3,3)], - lambda x,w: torch.nn.functional.conv_transpose2d(x,w,groups=2).relu(), - lambda x,w: Tensor.conv_transpose2d(x,w,groups=2).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_padded_conv_transpose2d(self): - for padding in [(1,2), (2,1), 2, 1, 0]: - helper_test_op([(2,4,9,9), (4,4,3,3)], - lambda x,w: torch.nn.functional.conv_transpose2d(x,w,padding=padding).relu(), - lambda x,w: Tensor.conv_transpose2d(x,w,padding=padding).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_dilated_conv_transpose2d(self): - for dilation in [(1,2), (2,1), 2, 1]: - helper_test_op([(2,4,9,9), (4,4,3,3)], - lambda x,w: torch.nn.functional.conv_transpose2d(x,w,dilation=dilation).relu(), - lambda x,w: Tensor.conv_transpose2d(x,w,dilation=dilation).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_strided_conv_transpose2d(self): - for stride in [(2,1), (1,2), 1]: - helper_test_op([(2,4,4,5), (4,4,3,3)], - lambda x,w: torch.nn.functional.conv_transpose2d(x,w, stride=stride).relu(), - lambda x,w: Tensor.conv_transpose2d(x,w,stride=stride).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_output_padded_conv_transpose2d(self): - for output_padding, stride in [((1,1), (2,3)), ((2,1), (3,2))]: - helper_test_op([(2,4,6,5), (4,4,3,3),(4,)], - lambda x,w,b: torch.nn.functional.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride).relu(), - lambda x,w,b: Tensor.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride).relu(), atol=1e-4, grad_rtol=1e-5) - - @unittest.skipIf(IMAGE>0, "no conv3d on images") - def test_simple_conv_transpose3d(self): - helper_test_op([(2,4,9,9,9), (4,4,3,3,3)], - lambda x,w: torch.nn.functional.conv_transpose3d(x,w).relu(), - lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5) - - @unittest.skipIf((IMAGE>0), "no conv1d on images") - def test_conv1d(self): - for bs in [1,8]: - for cin in [1,3]: - for H in [1,2,5]: - for groups in [1,3] if cin == 3 and H == 5 else [1]: - with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H): - helper_test_op([(bs,cin,11), (6,cin//groups,H)], - lambda x,w: torch.nn.functional.conv1d(x,w,groups=groups).relu(), - lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5) - - @unittest.skipIf(IMAGE>0, "no conv1d on images") - def test_simple_padding_conv1d(self): - bs = 6 - cin = 2 - groups = 1 - H = 5 - p = (1,1) - helper_test_op([(bs,cin,11), (6,cin//groups,H)], - lambda x,w: torch.nn.functional.conv1d(torch.nn.functional.pad(x, p),w).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4) - - @unittest.skipIf(IMAGE>0, "no conv1d on images") - def test_strided_conv1d_simple(self): - bs, H = 2, 3 - helper_test_op([(bs,1,5), (1,1,H)], - lambda x,w: torch.nn.functional.conv1d(x,w,stride=2).relu(), - lambda x,w: Tensor.conv2d(x,w,stride=2).relu(), atol=1e-4) - - @unittest.skipIf(IMAGE>0, "no conv1d on images") - def test_asymmetric_padding_conv1d(self): - for p in [(0,1), (2,1), (2,0)]: - with self.subTest(p): - for n in [3,4]: - for k in [2]: - helper_test_op([(1,1,n), (1,1,k)], - lambda x,w: torch.nn.functional.conv1d(torch.nn.functional.pad(x, p),w).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4) - helper_test_op([(1,1,n), (1,1,k)], - lambda x,w: torch.nn.functional.conv1d(torch.nn.functional.pad(x, p),w).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4) - - def _test_conv2d(self, bs=1, cin=1): - for H in [1,2,3]: - for W in [1,2,3,5]: - for groups in [1,3] if cin == 3 and H == 3 and W == 3 else [1]: - with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W): - helper_test_op([(bs,cin,11,7), (6,cin//groups,H,W)], - lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(), - lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5) - def test_conv2d(self): self._test_conv2d(bs=1, cin=3) - def test_conv2d_bs_4_cin_3(self): self._test_conv2d(bs=4, cin=3) - def test_conv2d_bs_1_cin_1(self): self._test_conv2d(bs=1, cin=1) - def test_conv2d_bs_4_cin_1(self): self._test_conv2d(bs=4, cin=1) - - def test_large_input_conv2d(self): - bs = 4 - cin = 16 - groups = 1 - H = 5 - W = 2 - helper_test_op([(bs,cin,64,64), (6,cin//groups,H,W)], - lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(), - # needed to relax tolerance on NVIDIA - lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-3, grad_rtol=1e-5) - - def test_simple_grouped_conv2d(self): - bs = 1 - groups = 2 - rcout = 1 - cin = 2 - helper_test_op([(bs,groups*cin,1,1), (groups*rcout,cin,1,1)], - lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(), - lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_medium_grouped_conv2d(self): - bs = 1 - groups = 2 - rcout = 2 - cin = 2 - helper_test_op([(bs,groups*cin,1,1), (groups*rcout,cin,1,1)], - lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(), - lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_depthwise_conv2d(self): - bs = 1 - groups = 32 - rcout = 1 - cin = 1 - helper_test_op([(bs,groups*cin,32,32), (groups*rcout,cin,1,1)], - lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(), - lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_grouped_conv2d(self): - bs = 4 - groups = 5 - rcout = 7 - cin = 3 - helper_test_op([(bs,groups*cin,5,5), (groups*rcout,cin,3,3)], - lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(), - lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_fancy_conv2d(self): - bs = 2 - cin = 3 - cout = 1 - groups = 3 - H,W = 3,3 - helper_test_op([(bs,cin,11,28), (groups*cout,cin//groups,H,W)], - lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(), - lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5) - - def test_strided_conv2d_simple(self): - bs,H,W = 2,3,1 - helper_test_op([(bs,1,5,1), (1,1,H,W)], - lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(), - lambda x,w: Tensor.conv2d(x,w,stride=2).relu(), atol=1e-4) - - def test_strided_conv2d(self): - bs = 4 - cin = 3 - H,W = 3,3 - with self.subTest(stride := 2): - helper_test_op([(bs,cin,11,28), (4,cin,H,W)], - lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(), - lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), atol=1e-4) - with self.subTest(stride := (2,1)): - helper_test_op([(bs,cin,11,28), (4,cin,H,W)], - lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(), - lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), atol=1e-4) - - def test_negative_padding_conv2d(self): - n,k = 10, 3 - helper_test_op([(1,1,n,n), (1,1,k,k)], - lambda x,w: torch.nn.functional.conv2d(x[:, :, 1:-1, 1:-1],w).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=-1).relu(), atol=1e-4) - helper_test_op([(1,1,n,n), (1,1,k,k)], - lambda x,w: torch.nn.functional.conv2d(x[:, :, 1:, 1:],w).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=(-1,0,-1,0)).relu(), atol=1e-4) - - def test_simple_padding_conv2d(self): - p = (1,1,1,1) - helper_test_op(None, - lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4, vals=[[[[[2.,3.]]]], [[[[1.]]]]]) - - def test_asymmetric_padding_conv2d(self): - for p in [(0,1,0,1), (2,1,2,1), (2,0,2,1)]: - with self.subTest(p): - for n in [3,4]: - for k in [2]: - helper_test_op([(1,1,n,n), (1,1,k,k)], - lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4) - helper_test_op([(1,1,n,n), (1,1,k,k)], - lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4) - - def test_padded_conv2d_p21(self): - bs,cin,H,W,padding = 4, 3, 3, 3, (2,1) - helper_test_op([(bs,cin,11,28), (4,cin,H,W)], - lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4) - - def test_padded_conv2d_p22(self): - bs,cin,H,W,padding = 4, 3, 3, 3, (2,2) - helper_test_op([(bs,cin,11,28), (4,cin,H,W)], - lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4) - - def test_padded_conv2d_1x1(self): - bs,cin,H,W,padding = 4, 3, 1, 1, 2 - helper_test_op([(bs,cin,11,28), (4,cin,H,W)], - lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4) - - def test_padded_conv2d_bs1(self): - bs,cin,H,W,padding = 1, 3, 3, 3, 1 - helper_test_op([(bs,cin,11,28), (4,cin,H,W)], - lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(), - lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4) - - def test_padding_add(self): - helper_test_op([(64,64), (60,60)], - lambda x,w: x+torch.nn.functional.pad(w, (2,2,2,2)), - lambda x,w: x+w.pad2d((2,2,2,2))) - - def test_dilated_conv2d(self): - bs = 4 - cin = 3 - H,W = 3,3 - for d in [2, (2,1)]: - with self.subTest(dilation := d): - helper_test_op([(bs,cin,11,28), (4,cin,H,W)], - lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation).relu(), - lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu(), atol=1e-4) - - def test_maxpool2d_simple(self): - ksz = (2,2) - helper_test_op([(1,1,2,3)], - lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz), - lambda x: Tensor.max_pool2d(x, kernel_size=ksz)) - - def test_maxpool2d(self): - for ksz in [(2,2), (3,3), 2, 3, (3,2), (5,5), (5,1)]: - with self.subTest(kernel_size=ksz): - helper_test_op([(32,2,110,28)], - lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz), - lambda x: Tensor.max_pool2d(x, kernel_size=ksz)) - - def test_maxpool2d_bigger_stride(self): - for stride in [(2,3), (3,2), 2, 3]: - with self.subTest(stride=stride): - helper_test_op([(32,2,110,28)], - lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride), - lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride)) - - @unittest.skipIf(Device.DEFAULT == "CUDA", "CUDA fails on this") - def test_maxpool2d_unit_stride(self): - helper_test_op([(32,2,110,28)], - lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=1), - lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=1)) - - def test_maxpool2d_smaller_stride(self): - for stride in [(2,3), (3,2), 2, 3]: - with self.subTest(stride=stride): - helper_test_op([(32,2,110,28)], - lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=stride), - lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=stride)) - - def test_maxpool2d_dilation(self): - for dilation in [(2, 3), (3, 2), 2, 3]: - helper_test_op([(32,2,110,28)], - lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), dilation=dilation), - lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), dilation=dilation)) - - def test_avgpool2d(self): - shape = (32,2,111,28) - for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: - with self.subTest(kernel_size=ksz): - helper_test_op([shape], - lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz), - lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5) - - def test_global_avgpool2d(self): - helper_test_op([(32,2,111,28)], - lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)), - lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5) - - def test_cat(self): - for dim in range(-2, 3): - helper_test_op([(45,65,9), (45,65,9), (45,65,9)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim)) - - with self.assertRaises(AssertionError): - a = Tensor(3.14) - a.cat(a) - - def test_multicat(self): - for dim in range(-1, 2): - helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim)) - - def test_stack(self): - x = Tensor.randn(45, 65, 3) - - for dim in range(-1, 3): - helper_test_op([(45, 65, 3), (45, 65, 3), (45, 65, 3)], lambda x, y, z: torch.stack((x, y, z), dim=dim), lambda x, y, z: Tensor.stack([x, y, z], dim=dim)) - - with self.assertRaises(IndexError): - Tensor.stack([x], dim=77) - - a = Tensor(3.14) - np.testing.assert_allclose(Tensor.stack([a, a]).numpy(), Tensor([3.14, 3.14]).numpy()) - - def test_repeat(self): - x = Tensor.randn(4, 6, 3) - base_repeats = [2, 4, 3] - - for reps in [[], [4], [2, 1], [3, 2, 2]]: - repeats = base_repeats + reps - helper_test_op([(4, 6, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats)) - helper_test_op([()], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats)) - - with self.assertRaises(ValueError): - x.repeat((2, 4)) - - np.testing.assert_allclose(x.repeat((2, 0, 4)).numpy(), Tensor.zeros(8, 0, 12).numpy()) - - def test_clip(self): - helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2)) - - def test_matvecmat(self): - helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z, atol=1e-4) - - def test_matvec(self): - helper_test_op([(1,128), (128,128)], lambda x,y: (x@y).relu(), atol=1e-4) - - # this was the failure in llama early realizing freqs_cis - def test_double_slice(self): - helper_test_op([(4,4)], lambda x: x[:, 1:2][1:2]) - helper_test_op([(4,4)], lambda x: x[1:3][1:2]) - helper_test_op([(4,4)], lambda x: x[:, 1:2][0:1]) - helper_test_op([(4,4)], lambda x: x[:, 1:2][:, 0:1]) - - @unittest.skip("this test is broken #862") - def test_max_inf(self): - n = Tensor([1, float("nan")]).max().numpy() - assert math.isnan(n.item()), f"{n.item()} is not nan" - - def test_inf_where(self): - x = Tensor.full((3, 3), float("inf")) - n = (x < 0).where(x, 1).numpy() - assert np.all(n == 1.) - - def _get_index_randoms(self): - # indices cannot have gradient - # TODO currently does not support IndexError for out of bounds idx values - a = torch.randint(low=-1, high=1, size=(2,1,1,1,1,1), dtype=torch.int64, requires_grad=False) - b = torch.randint(high=1, size=(1,3,1,1,1,1), dtype=torch.int64, requires_grad=False) - c = torch.randint(low=-5, high=5, size=(1,1,4,1,1,1), dtype=torch.int64, requires_grad=False) - d = torch.randint(high=4, size=(2,1,1,5,1,1), dtype=torch.int64, requires_grad=False) - e = torch.randint(high=1, size=(1,1,1,1,6,1), dtype=torch.int64, requires_grad=False) - i, j, k, o, p = [Tensor(tor.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False) for tor in [a,b,c,d,e]] - return a,b,c,d,e,i,j,k,o,p - - def test_slice_fancy_indexing_no_dim_collapse(self): - a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() - # no dim collapse from int or dim injection from None - helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,d,e], lambda x: x[i,j,k,o,p]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[:,b,c,d,:], lambda x: x[:,j,k,o,:]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,...], lambda x: x[i,j,...]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,...,e], lambda x: x[i,...,p]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[...,c,:,e], lambda x: x[...,k,:,p]) - - def test_slice_fancy_indexing_dim_collapse_int(self): - a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() - # dim collapse from int - helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,b,c,d,e], lambda x: x[1,j,k,o,p]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,3,d,e], lambda x: x[i,j,3,o,p]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,b,2,d,2], lambda x: x[1,j,2,o,2]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,2,2,2,e], lambda x: x[i,2,2,2,p]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,:,3:11:2,d,0:2], lambda x: x[1,:,3:11:2,o,0:2]) - - def test_slice_fancy_indexing_dim_inject_none(self): - a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() - # dim injection from None - helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,b,c,d,e], lambda x: x[None,j,k,o,p]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,d,None], lambda x: x[i,j,k,o,None]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,None,d,e], lambda x: x[i,j,None,o,p]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,b,c,d,None], lambda x: x[None,j,k,o,None]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,:,None,d,e], lambda x: x[i,:,None,o,p]) - - def test_slice_fancy_indexing_dim_inject_and_collapse(self): - a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() # noqa - # dim injection and collapse - helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,b,None,d,1], lambda x: x[1,j,None,o,1]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,b,2,d,None], lambda x: x[None,j,2,o,None]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[...,1,d,None], lambda x: x[...,1,o,None]) - - def test_slice_fancy_indexing_with_idx(self): - # indexing using idx with different dim - helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor(1)], lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor(1)]) - helper_test_op([(2,3)], lambda x: x[torch.tensor([1]), torch.tensor([[0,0,0],[0,0,0]])], lambda x: x[Tensor([1]), Tensor([[0,0,0],[0,0,0]])]) - - def test_slice_fancy_indexing_list_indices(self): - a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() - helper_test_op([(2,5,6,5,3,4)], lambda x: x[[0],b,c,d,:], lambda x: x[[0],j,k,o,:]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[[1],b,c,d,:], lambda x: x[[1],j,k,o,:]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[[1,0],b,c,d,:], lambda x: x[[1,0],j,k,o,:]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[1,2,3],...], lambda x: x[i,j,k,[1,2,3],...]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,[2,1,0],c,[2,1,0],e], lambda x: x[i,[2,1,0],k,[2,1,0],p]) - - def test_gather(self): - # indices cannot have gradient - # indices cannot be negative (torch gather) - b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False) - a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False) - helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0)) - helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=1), lambda x: x.gather(idx=a, dim=1)) - helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=2), lambda x: x.gather(idx=a, dim=2)) - helper_test_op([(3,4,5)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0)) - self.helper_test_exception([(4,5,6)], lambda x: x.gather(index=torch.tensor([1], dtype=torch.int64), dim=0), lambda x: x.gather(idx=Tensor([1], dtype=dtypes.int32), dim=0), expected=(RuntimeError, AssertionError)) - self.helper_test_exception([(2,1,1)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError)) - - def test_scaled_product_attention(self): - helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z)) - helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m)) - helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True)) - - def test_binary_crossentropy(self): - helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)), lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1))) - helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)), lambda x,y: x.binary_crossentropy_logits(y.clip(0,1))) - helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)), lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1))) - helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)), lambda x,y: x.binary_crossentropy_logits(y.clip(0,1))) - -if __name__ == '__main__': - np.random.seed(1337) - unittest.main(verbosity=2) + [[1.0, 1.0, 0.0, 1.0]], + ], + ) + helper_test_op( + [(3, 4, 5, 6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1) + ) + helper_test_op([()], lambda x: x.max(), Tensor.max) + + def test_mean(self): + helper_test_op([(3, 4, 5, 6)], lambda x: x.mean()) + helper_test_op([()], lambda x: x.mean()) + + def test_mean_axis(self): + helper_test_op( + [(3, 4, 5, 6)], + lambda x: x.mean(axis=(1, 2)), + lambda x: Tensor.mean(x, axis=(1, 2)), + ) + + def test_std(self): + helper_test_op([(45, 65, 85)], lambda x: torch.std(x), lambda x: Tensor.std(x)) + helper_test_op( + [(45, 65, 85)], + lambda x: torch.std(x, dim=None, correction=0), + lambda x: Tensor.std(x, correction=0), + ) + helper_test_op( + [(45, 65, 85)], + lambda x: torch.std(x, dim=None, correction=5), + lambda x: Tensor.std(x, correction=5), + ) + + def test_std_axis(self): + helper_test_op( + [(45, 65, 85)], + lambda x: torch.std(x, dim=0), + lambda x: Tensor.std(x, axis=0), + ) + helper_test_op( + [(45, 65, 85)], + lambda x: torch.std(x, dim=2), + lambda x: Tensor.std(x, axis=2), + ) + helper_test_op( + [(45, 65, 85)], + lambda x: torch.std(x, dim=[1, 2]), + lambda x: Tensor.std(x, axis=[1, 2]), + ) + helper_test_op( + [(45, 65, 85)], + lambda x: torch.std(x, dim=None), + lambda x: Tensor.std(x, axis=None), + ) + helper_test_op( + [(45, 65, 85)], + lambda x: torch.std(x, correction=0, dim=0), + lambda x: Tensor.std(x, axis=0, correction=0), + ) + helper_test_op( + [(45, 65, 85)], + lambda x: torch.std(x, correction=0, dim=2), + lambda x: Tensor.std(x, axis=2, correction=0), + ) + helper_test_op( + [(45, 65, 85)], + lambda x: torch.std(x, correction=0, dim=[1, 2]), + lambda x: Tensor.std(x, axis=[1, 2], correction=0), + ) + helper_test_op( + [(45, 65, 85)], + lambda x: torch.std(x, correction=0, dim=None), + lambda x: Tensor.std(x, axis=None, correction=0), + ) + + def test_std_keepdim(self): + helper_test_op( + [(45, 65, 85)], + lambda x: torch.std(x, dim=None, keepdim=True), + lambda x: Tensor.std(x, keepdim=True), + ) + helper_test_op( + [(45, 65, 85)], + lambda x: torch.std(x, dim=0, keepdim=True, correction=0), + lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0), + ) + + def test_log_softmax(self): + helper_test_op( + [(45, 65)], + lambda x: torch.nn.LogSoftmax(dim=1)(x), + Tensor.log_softmax, + atol=1e-7, + grad_atol=1e-7, + ) + helper_test_op( + [()], + lambda x: torch.nn.LogSoftmax(dim=0)(x), + Tensor.log_softmax, + atol=1e-7, + grad_atol=1e-7, + ) + + def test_log_softmax_other_axis(self): + helper_test_op( + [(10, 10, 10)], + lambda x: x.log_softmax(0), + lambda x: x.log_softmax(0), + atol=1e-7, + grad_atol=1e-7, + ) + helper_test_op( + [(10, 10, 10)], + lambda x: x.log_softmax(1), + lambda x: x.log_softmax(1), + atol=1e-7, + grad_atol=1e-7, + ) + helper_test_op( + [(10, 10, 10)], + lambda x: x.log_softmax(2), + lambda x: x.log_softmax(2), + atol=1e-7, + grad_atol=1e-7, + ) + + def test_tanh(self): + helper_test_op( + [(45, 65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6 + ) + helper_test_op( + [(45, 65)], + lambda x: x.tanh(), + Tensor.tanh, + atol=1e-6, + grad_atol=1e-6, + a=-100, + ) + helper_test_op([()], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6) + + def test_hardtanh(self): + for val in range(10, 30, 5): + helper_test_op( + [(45, 65)], + lambda x: torch.nn.functional.hardtanh(x, -val, val), + lambda x: x.hardtanh(-val, val), + atol=1e-6, + grad_atol=1e-6, + ) + helper_test_op( + [()], + lambda x: torch.nn.functional.hardtanh(x, -val, val), + lambda x: x.hardtanh(-val, val), + atol=1e-6, + grad_atol=1e-6, + ) + + def test_topo_sort(self): + helper_test_op( + [(45, 65)], + lambda x: (x + x) * x, + lambda x: x.add(x).mul(x), + atol=1e-6, + grad_atol=1e-6, + ) + helper_test_op( + [()], + lambda x: (x + x) * x, + lambda x: x.add(x).mul(x), + atol=1e-6, + grad_atol=1e-6, + ) + + def test_scalar_mul(self): + helper_test_op([(45, 65)], lambda x: x * 2, lambda x: x * 2) + helper_test_op([()], lambda x: x * 2, lambda x: x * 2) + + def test_scalar_rmul(self): + helper_test_op([(45, 65)], lambda x: 2 * x, lambda x: 2 * x) + helper_test_op([()], lambda x: 2 * x, lambda x: 2 * x) + + def test_scalar_sub(self): + helper_test_op([(45, 65)], lambda x: x - 2, lambda x: x - 2) + helper_test_op([()], lambda x: x - 2, lambda x: x - 2) + + def test_scalar_rsub(self): + helper_test_op([(45, 65)], lambda x: 2 - x, lambda x: 2 - x) + helper_test_op([()], lambda x: 2 - x, lambda x: 2 - x) + + def test_flip_eye_crash(self): + helper_test_op( + [], + lambda: (torch.eye(10) @ torch.eye(10).flip(0)), + lambda: (Tensor.eye(10) @ Tensor.eye(10).flip(0)), + forward_only=True, + ) + + @unittest.skipIf( + Device.DEFAULT == "WEBGPU", + "this test uses more than 8 bufs passing the WEBGPU limit", + ) # TODO: remove after #1461 + def test_broadcast_full(self): + for torch_op, tinygrad_op in [ + (torch.add, Tensor.add), + (torch.sub, Tensor.sub), + (torch.mul, Tensor.mul), + (torch.div, Tensor.div), + ]: # , (torch.pow, Tensor.pow)]: + for shapes in [ + ((5, 13, 24, 16), (5, 1, 24, 1)), + ((1, 3, 1, 7, 1), (2, 1, 5, 1, 8)), + ]: + with self.subTest(op=torch_op.__name__, shapes=shapes): + helper_test_op( + shapes, + torch_op, + tinygrad_op, + a=-0.5 if tinygrad_op != Tensor.pow else 0.0, + ) + + def test_broadcast_simple(self): + helper_test_op([(45, 65), (45, 1)], lambda x, y: x / y, lambda x, y: x / y) + helper_test_op([(45, 65), ()], lambda x, y: x / y, lambda x, y: x / y) + + @unittest.skipIf( + Device.DEFAULT == "WEBGPU", + "this test uses more than 8 bufs passing the WEBGPU limit", + ) # TODO: remove after #1461 + def test_broadcast_partial(self): + for torch_op, tinygrad_op in [ + (torch.add, Tensor.add), + (torch.sub, Tensor.sub), + (torch.mul, Tensor.mul), + (torch.div, Tensor.div), + ]: # , (torch.pow, Tensor.pow)]: + for shapes in [ + ((1, 32, 32, 32), (1, 32, 1, 1)), + ((5, 13, 24, 16, 2), (1, 13, 24, 1, 1)), + ((4, 1), (4, 5)), + ((1, 4), (5, 4)), + ]: + with self.subTest(op=torch_op.__name__, shapes=shapes): + # NOTE: ANE backwards? + helper_test_op( + shapes, + torch_op, + tinygrad_op, + a=-0.5 if tinygrad_op != Tensor.pow else 0.0, + ) + + def test_slice_in_bounds_1dim(self): + helper_test_op([(3)], lambda x: x[1:3], lambda x: x[1:3]) + helper_test_op([(3)], lambda x: x[0:2], lambda x: x[0:2]) + helper_test_op([(3)], lambda x: x[-2:2], lambda x: x[-2:2]) + + def test_slice_on_0dim_tensor(self): + helper_test_op([()], lambda x: x[None], lambda x: x[None]) + + with self.assertRaises(IndexError): + a = Tensor(3.14) + a[0] + + def test_slice_int_indexing(self): + helper_test_op([(3)], lambda x: x[1], lambda x: x[1]) + helper_test_op([(3)], lambda x: x[-2], lambda x: x[-2]) + helper_test_op([(10, 10)], lambda x: x[1], lambda x: x[1]) + helper_test_op([(3, 3, 3)], lambda x: x[1, 1, 1], lambda x: x[1, 1, 1]) + + def test_slice_in_bounds_multidim(self): + helper_test_op([(3, 3, 3)], lambda x: x[1:2], lambda x: x[1:2]) + helper_test_op([(3, 3, 3)], lambda x: x[1:2, 2], lambda x: x[1:2, 2]) + helper_test_op([(3, 3, 3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2]) + helper_test_op( + [(3, 3, 3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1] + ) + + def test_slice_with_none(self): + helper_test_op([(3, 3, 3)], lambda x: x[None], lambda x: x[None]) + helper_test_op([(3, 3, 3)], lambda x: x[1:2, None], lambda x: x[1:2, None]) + helper_test_op( + [(3, 3, 3)], lambda x: x[1:2, None, 1:2], lambda x: x[1:2, None, 1:2] + ) + helper_test_op( + [(3, 3, 3)], + lambda x: x[1:2, 1:2, None, -1], + lambda x: x[1:2, 1:2, None, -1], + ) + + def test_slice_one_endpoint_out_of_bounds(self): + helper_test_op([(3, 3, 3)], lambda x: x[0:4], lambda x: x[0:4]) + helper_test_op([(3, 3, 3)], lambda x: x[-6:4], lambda x: x[-6:4]) + helper_test_op([(3, 3, 3)], lambda x: x[1:50], lambda x: x[1:50]) + helper_test_op( + [(3, 3, 3)], lambda x: x[1:50, 1:2, -1], lambda x: x[1:50, 1:2, -1] + ) + + def test_slice_stride_gt_one(self): + helper_test_op( + [(7, 5, 10)], lambda x: x[::2, ::3, ::4], lambda x: x[::2, ::3, ::4] + ) + helper_test_op( + [(7, 5, 10)], lambda x: x[1:5:2, ::3, ::4], lambda x: x[1:5:2, ::3, ::4] + ) + helper_test_op( + [(7, 5, 10)], lambda x: x[1:5:2, 3, ::4], lambda x: x[1:5:2, 3, ::4] + ) + helper_test_op( + [(7, 5, 10)], + lambda x: x[1:5:2, None, None, 3, None, ::4], + lambda x: x[1:5:2, None, None, 3, None, ::4], + ) + + def test_slice_negative_strides(self): + # Torch doesn't support slicing with negative steps + a = np.random.randn(10, 10, 10).astype(np.float32) + t = Tensor(a) + np.testing.assert_allclose(a[::-1], t[::-1].numpy()) + np.testing.assert_allclose(a[::-2], t[::-2].numpy()) + np.testing.assert_allclose(a[:, 2:0:-1], t[:, 2:0:-1].numpy()) + np.testing.assert_allclose(a[:, 2:0:-1, 3:1:-2], t[:, 2:0:-1, 3:1:-2].numpy()) + np.testing.assert_allclose( + a[4:0:-3, 2:0:-1, -1:-5:-2], t[4:0:-3, 2:0:-1, -1:-5:-2].numpy() + ) + if Device.DEFAULT != "CPU": + # broken + np.testing.assert_allclose( + a[2:5:-1, :, :], t[2:5:-1, :, :].numpy() + ) # shape = (0, 10, 10) + np.testing.assert_allclose( + a[:, 2:5:-1, :], t[:, 2:5:-1, :].numpy() + ) # shape = (0, 10, 10) + np.testing.assert_allclose( + a[:, :, 2:5:-1], t[:, :, 2:5:-1].numpy() + ) # shape = (0, 10, 10) + + def test_slice_both_endpoints_out_of_bounds(self): + helper_test_op( + [(3, 3, 3)], lambda x: x[5:10], lambda x: x[5:10], forward_only=True + ) + helper_test_op( + [(3, 3, 3)], lambda x: x[-15:-7], lambda x: x[-15:-7], forward_only=True + ) + + def test_slice_start_gt_end(self): + helper_test_op( + [(3, 3, 3)], lambda x: x[-2:2], lambda x: x[-2:2], forward_only=True + ) + helper_test_op( + [(3, 3, 3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True + ) + + def test_slice_empty(self): + helper_test_op( + [(10, 10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True + ) + + def test_slice_zero_in_shape(self): + helper_test_op( + [(10, 10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True + ) # x.shape = (0, 10) + helper_test_op( + [(3, 3, 3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True + ) # x.shape = (0, 3, 3) + + def test_slice_errors(self): + a = Tensor.ones(4, 3) + with self.assertRaises(IndexError): + a[ + 1, 77, 77, 77 + ] # IndexError: (finds too many indices before the out of bounds) + a[1, 77] # IndexError: (out of bounds). + a[0, -77] + a[..., ...] # IndexError: only single ellipsis + + def test_slice_ellipsis(self): + helper_test_op([(3, 3, 3, 3)], lambda x: x[..., 0], lambda x: x[..., 0]) + helper_test_op([(3, 3, 3, 3)], lambda x: x[0, ...], lambda x: x[0, ...]) + helper_test_op([(3, 3, 3, 3)], lambda x: x[0, ..., 0], lambda x: x[0, ..., 0]) + helper_test_op( + [(3, 3, 3, 3)], lambda x: x[0:3, ..., 2:3], lambda x: x[0:3, ..., 2:3] + ) + helper_test_op( + [(3, 3, 3, 3)], + lambda x: x[None, 0:3, ..., 0, None], + lambda x: x[None, 0:3, ..., 0, None], + ) + + def test_pad2d(self): + helper_test_op( + [(3, 3, 3, 3)], + lambda x: torch.nn.functional.pad(x, (1, 2, 3, 4)), + lambda x: x.pad2d(padding=(1, 2, 3, 4)), + ) + helper_test_op( + [(3, 3, 3, 3)], + lambda x: torch.nn.functional.pad(x, (-1, 2, -3, 4)), + lambda x: x.pad2d(padding=(-1, 2, -3, 4)), + ) + helper_test_op( + [(3, 3, 3, 3)], + lambda x: torch.nn.functional.pad(x, (1, 2, 3, 4), value=5), + lambda x: x.pad2d(padding=(1, 2, 3, 4), value=5), + ) + helper_test_op( + [(3, 3, 3, 3)], + lambda x: torch.nn.functional.pad(x, (-1, 2, -3, 4), value=5), + lambda x: x.pad2d(padding=(-1, 2, -3, 4), value=5), + ) + + def test_pad(self): + helper_test_op( + [(3, 3)], + lambda x: torch.nn.functional.pad(x, (1, 2, 3, 4)), + lambda x: x.pad(((3, 4), (1, 2))), + ) + helper_test_op( + [(3, 3)], + lambda x: torch.nn.functional.pad(x, (1, 2, 3, 4), value=5), + lambda x: x.pad(((3, 4), (1, 2)), value=5), + ) + helper_test_op( + [(3, 3)], + lambda x: torch.nn.functional.pad(x, (1, 2, 3, 4), value=float("inf")), + lambda x: x.pad(((3, 4), (1, 2)), value=float("inf")), + ) + helper_test_op( + [(3, 3)], + lambda x: torch.nn.functional.pad(x, (1, 2, 3, 4), value=float("-inf")), + lambda x: x.pad(((3, 4), (1, 2)), value=float("-inf")), + ) + helper_test_op( + [(3, 3)], + lambda x: torch.nn.functional.pad(x, (0, 0, 3, 4), value=1), + lambda x: x.pad(((3, 4), None), value=1), + ) + helper_test_op( + [(3, 3)], + lambda x: torch.nn.functional.pad(x, (0, 0, 0, 0), value=1), + lambda x: x.pad((None, None), value=1), + ) + + def test_pad_slice(self): + for value in 0.0, 3.456: + helper_test_op( + [(1)], + lambda x: torch.nn.functional.pad(x, (1, 0), value=value)[0], + lambda x: x.pad(((1, 0),), value=value)[0], + ) + helper_test_op( + [(4)], + lambda x: torch.nn.functional.pad(x, (1, 0), value=value)[0], + lambda x: x.pad(((1, 0),), value=value)[0], + ) + helper_test_op( + [(4)], + lambda x: torch.nn.functional.pad(x, (3, 0), value=value)[0:1], + lambda x: x.pad(((3, 0),), value=value)[0:1], + ) + helper_test_op( + [(4)], + lambda x: torch.nn.functional.pad(x, (0, 3), value=value)[6], + lambda x: x.pad(((0, 3),), value=value)[6], + ) + helper_test_op( + [(4)], + lambda x: torch.nn.functional.pad(x, (0, 3), value=value)[4:6], + lambda x: x.pad(((0, 3),), value=value)[4:6], + ) + helper_test_op( + [(5, 5)], + lambda x: torch.nn.functional.pad(x, (0, 0, 1, 0), value=value)[0], + lambda x: x.pad(((1, 0), (0, 0)), value=value)[0], + ) + helper_test_op( + [(2, 2)], + lambda x: torch.nn.functional.pad(x, (0, 1, 0, 0), value=value)[0, 2], + lambda x: x.pad(((0, 0), (0, 1)), value=value)[0, 2], + ) + helper_test_op( + [(4, 4)], + lambda x: torch.nn.functional.pad(x, (0, 0, 1, 0), value=value)[0, 2], + lambda x: x.pad(((1, 0), (0, 0)), value=value)[0, 2], + ) + helper_test_op( + [(4, 4)], + lambda x: torch.nn.functional.pad(x, (0, 0, 0, 2), value=value)[5], + lambda x: x.pad(((0, 2), (0, 0)), value=value)[5], + ) + helper_test_op( + [(4, 4)], + lambda x: torch.nn.functional.pad(x, (0, 0, 0, 2), value=value)[3:5], + lambda x: x.pad(((0, 2), (0, 0)), value=value)[3:5], + ) + helper_test_op( + [(4, 4)], + lambda x: torch.nn.functional.pad(x, (3, 0, 0, 0), value=value)[1, 0], + lambda x: x.pad(((0, 0), (3, 0)), value=value)[1, 0], + ) + helper_test_op( + [(4, 4)], + lambda x: torch.nn.functional.pad(x, (3, 0, 0, 0), value=value)[1, 0:4], + lambda x: x.pad(((0, 0), (3, 0)), value=value)[1, 0:4], + ) + helper_test_op( + [(4, 4)], + lambda x: torch.nn.functional.pad(x, (3, 4, 1, 2), value=value)[0], + lambda x: x.pad(((1, 2), (3, 4)), value=value)[0], + ) + helper_test_op( + [(4, 4)], + lambda x: torch.nn.functional.pad(x, (3, 4, 1, 2), value=value)[:, 1], + lambda x: x.pad(((1, 2), (3, 4)), value=value)[:, 1], + ) + helper_test_op( + [(4, 4)], + lambda x: torch.nn.functional.pad(x, (3, 4, 1, 2), value=value)[:, 4], + lambda x: x.pad(((1, 2), (3, 4)), value=value)[:, 4], + ) + helper_test_op( + [(3, 3)], + lambda x: torch.nn.functional.pad(x, (0, 3, 0, 0), value=value)[:, 4:6], + lambda x: x.pad(((0, 0), (0, 3)), value=value)[:, 4:6], + ) + helper_test_op( + [(3, 3)], + lambda x: torch.nn.functional.pad(x, (0, 1, 3, 2), value=value)[0:2, :], + lambda x: x.pad(((3, 2), (0, 1)), value=value)[0:2, :], + ) + helper_test_op( + [(3, 3, 3)], + lambda x: torch.nn.functional.pad(x, (1, 1, 0, 1, 3, 2), value=value)[ + 0:2, :, : + ], + lambda x: x.pad(((3, 2), (0, 1), (1, 1)), value=value)[0:2, :, :], + ) + helper_test_op( + [(3, 3, 3)], + lambda x: torch.nn.functional.pad(x, (1, 1, 0, 1, 3, 2), value=value)[ + 2:4, :, : + ], + lambda x: x.pad(((3, 2), (0, 1), (1, 1)), value=value)[2:4, :, :], + ) + + def test_stack_slice(self): + helper_test_op( + [(4)], + lambda x: torch.stack([x for i in range(3)])[0, :], + lambda x: Tensor.stack([x for i in range(3)])[0, :], + ) + helper_test_op( + [(5)], + lambda x: torch.stack([x for i in range(3)])[0, 0], + lambda x: Tensor.stack([x for i in range(3)])[0, 0], + ) + helper_test_op( + [(4, 4)], + lambda x: torch.stack([x for i in range(4)])[3], + lambda x: Tensor.stack([x for i in range(4)])[3], + ) + + def test_transpose(self): + helper_test_op( + [(3, 3, 3)], lambda x: x.transpose(1, 2), lambda x: x.transpose(1, 2) + ) + helper_test_op( + [(3, 3, 3)], lambda x: x.transpose(0, 2), lambda x: x.transpose(0, 2) + ) + helper_test_op( + [(1, 2, 3, 4)], + lambda x: x.movedim((3, 0, 2, 1), (0, 1, 2, 3)), + lambda x: x.permute(order=(3, 0, 2, 1)), + ) + helper_test_op( + [(3, 4, 5, 6)], + lambda x: x.movedim((3, 2, 1, 0), (0, 1, 2, 3)), + lambda x: x.permute(order=(3, 2, 1, 0)), + ) + helper_test_op([()], lambda x: x.permute(()), lambda x: x.permute(())) + + def test_reshape(self): + helper_test_op( + [(4, 3, 6, 6)], + lambda x: torch.reshape(x, (-1, 3, 6, 6)), + lambda x: x.reshape(shape=(-1, 3, 6, 6)), + ) + helper_test_op( + [(4, 3, 6, 6)], + lambda x: torch.reshape(x, (-1, 1, 6, 6)), + lambda x: x.reshape(shape=(-1, 1, 6, 6)), + ) + helper_test_op([()], lambda x: torch.reshape(x, []), lambda x: x.reshape([])) + helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([])) + helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1])) + + with self.assertRaises(ValueError): + x = Tensor.ones((4, 3, 6, 6)) + x.reshape([]) + + def test_flip(self): + helper_test_op( + [(4, 3, 6, 6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,)) + ) + helper_test_op( + [(4, 3, 6, 6)], + lambda x: torch.flip(x, (0, 1)), + lambda x: x.flip(axis=(0, 1)), + ) + helper_test_op( + [(4, 3, 6, 6)], + lambda x: torch.flip(x, (0, 1, 3)), + lambda x: x.flip(axis=(0, 1, 3)), + ) + helper_test_op( + [(4, 3, 6, 6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,)) + ) + helper_test_op( + [(4, 3, 6, 6)], + lambda x: torch.flip(x, (0, 1, 3)).flip((0,)), + lambda x: x.flip(axis=(0, 1, 3)).flip(0), + ) + helper_test_op( + [(4, 3, 6, 6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(-1,)) + ) + helper_test_op([()], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=())) + helper_test_op([(1,)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=())) + helper_test_op( + [(4, 3, 6, 6)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()) + ) + + def test_squeeze(self): + helper_test_op( + [(1, 3, 6, 6)], lambda x: torch.squeeze(x, 0), lambda x: x.squeeze(dim=0) + ) + helper_test_op( + [(4, 3, 1, 6)], lambda x: torch.squeeze(x, 1), lambda x: x.squeeze(dim=1) + ) + helper_test_op( + [(4, 3, 6, 6)], lambda x: torch.squeeze(x, 3), lambda x: x.squeeze(dim=3) + ) + self.helper_test_exception( + [(4, 3, 6, 6)], + lambda x: torch.squeeze(x, 50), + lambda x: x.squeeze(dim=50), + expected=IndexError, + exact=True, + ) + self.helper_test_exception( + [(4, 3, 6, 6)], + lambda x: torch.squeeze(x, -50), + lambda x: x.squeeze(dim=-50), + expected=IndexError, + exact=True, + ) + helper_test_op( + [(4, 3, 6, 1)], lambda x: torch.squeeze(x, -1), lambda x: x.squeeze(dim=-1) + ) + helper_test_op( + [(4, 3, 6, 6)], lambda x: torch.squeeze(x), lambda x: x.squeeze() + ) + helper_test_op( + [(1, 3, 6, 6)], lambda x: torch.squeeze(x), lambda x: x.squeeze() + ) + helper_test_op([(2, 3, 1)], lambda x: torch.squeeze(x), lambda x: x.squeeze()) + helper_test_op( + [()], lambda x: torch.squeeze(x, -1), lambda x: x.squeeze(dim=-1) + ) + helper_test_op([()], lambda x: torch.squeeze(x, 0), lambda x: x.squeeze(dim=0)) + self.helper_test_exception( + [()], + lambda x: torch.squeeze(x, 10), + lambda x: x.squeeze(dim=10), + expected=IndexError, + exact=True, + ) + helper_test_op([()], lambda x: torch.squeeze(x), lambda x: x.squeeze()) + + def test_unsqueeze(self): + helper_test_op( + [(4, 3, 6, 6)], + lambda x: torch.unsqueeze(x, 0), + lambda x: x.unsqueeze(dim=0), + ) + helper_test_op( + [(4, 3, 6, 6)], + lambda x: torch.unsqueeze(x, 4), + lambda x: x.unsqueeze(dim=4), + ) + helper_test_op( + [(4, 3, 6, 6)], + lambda x: torch.unsqueeze(x, -1), + lambda x: x.unsqueeze(dim=-1), + ) + helper_test_op( + [(4, 3, 6, 6)], + lambda x: torch.unsqueeze(x, -3), + lambda x: x.unsqueeze(dim=-3), + ) + helper_test_op( + [()], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0) + ) + + def test_flatten(self): + for axis in range(3): + helper_test_op( + [(4, 3, 6, 6)], + lambda x: torch.flatten(x, start_dim=axis), + lambda x: x.flatten(axis), + ) + helper_test_op([()], lambda x: x.flatten(), lambda x: x.flatten()) + helper_test_op([(1,)], lambda x: x.flatten(), lambda x: x.flatten()) + + def test_detach(self): + helper_test_op( + [(4, 3, 6, 6)], + lambda x: x.detach(), + lambda x: x.detach(), + forward_only=True, + ) + helper_test_op( + [()], lambda x: x.detach(), lambda x: x.detach(), forward_only=True + ) + + def test_expand(self): + arg = (4, 3, 2, 6) + helper_test_op( + [(4, 3, 1, 6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg) + ) + helper_test_op([()], lambda x: x.expand([]), lambda x: x.expand(shape=[])) + + @unittest.skip("very slow") + def test_sd_big_conv(self): + # internal shape (1, 1, 512, 62, 62, 512, 3, 3) overflows a int + helper_test_op( + [(1, 256, 64, 64), (512, 256, 3, 3)], + lambda x, w: torch.nn.functional.conv2d(x, w), + lambda x, w: x.conv2d(w), + atol=1e-2, + ) + + @unittest.skip("slow") + def test_large_bs_conv(self): + # large batch size can cause OpenCL image to exceed max image height on macOS + # (or cause the conv kernel to overflow short sampling coords) + helper_test_op( + [(4096, 3, 3, 3), (1, 3, 3, 3)], + lambda x, w: torch.nn.functional.conv2d(x, w), + lambda x, w: x.conv2d(w), + atol=1e-4, + rtol=1e-2, + ) + + @unittest.skip("slow") + def test_large_ic_conv(self): + # large input channel count can cause OpenCL image to exceed max image width on macOS + helper_test_op( + [(1, 2048, 3, 3), (1, 2048, 3, 3)], + lambda x, w: torch.nn.functional.conv2d(x, w), + lambda x, w: x.conv2d(w), + atol=1e-4, + ) + + def test_biased_conv2d(self): + C = 8 + helper_test_op( + [(1, C, 5, 5), (C, C, 1, 1), (C,)], + lambda x, w, b: torch.nn.functional.conv2d( + torch.nn.functional.conv2d(x, w, b).relu(), w, b + ), + lambda x, w, b: Tensor.conv2d(x, w, b).relu().conv2d(w, b), + atol=1e-4, + ) + + def test_simple_conv2d(self): + helper_test_op( + [(1, 4, 9, 9), (4, 4, 3, 3)], + lambda x, w: torch.nn.functional.conv2d(x, w).relu(), + lambda x, w: Tensor.conv2d(x, w).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + @unittest.skipIf(IMAGE > 0, "no conv3d on images") + def test_simple_conv3d(self): + helper_test_op( + [(1, 4, 9, 9, 9), (4, 4, 3, 3, 3)], + lambda x, w: torch.nn.functional.conv3d(x, w).relu(), + lambda x, w: Tensor.conv2d(x, w).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + @unittest.skipIf(IMAGE > 0, "no conv3d on images") + def test_padded_conv3d(self): + helper_test_op( + [(1, 4, 9, 9, 9), (4, 4, 3, 3, 3)], + lambda x, w: torch.nn.functional.conv3d(x, w, padding=1).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=[1, 1, 1, 1, 1, 1]).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_simple_conv2d_m4(self): + helper_test_op( + [(1, 16, 18, 18), (16, 16, 3, 3)], + lambda x, w: torch.nn.functional.conv2d(x, w).relu(), + lambda x, w: Tensor.conv2d(x, w).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_simple_conv2d_1x1(self): + helper_test_op( + [(1, 4, 9, 9), (4, 4, 1, 1)], + lambda x, w: torch.nn.functional.conv2d(x, w).relu(), + lambda x, w: Tensor.conv2d(x, w).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_simple_conv2d_1x1_m4(self): + helper_test_op( + [(1, 16, 32, 32), (16, 16, 1, 1)], + lambda x, w: torch.nn.functional.conv2d(x, w).relu(), + lambda x, w: Tensor.conv2d(x, w).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_nested_conv2d(self): + helper_test_op( + [(1, 32, 9, 9), (32, 32, 3, 3), (32, 32, 3, 3)], + lambda x, w1, w2: torch.nn.functional.conv2d( + torch.nn.functional.conv2d(x, w1).relu(), w2 + ).relu(), + lambda x, w1, w2: x.conv2d(w1).relu().conv2d(w2).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + # expect reduce nodes == 3 + def test_simple_conv2d_nhwc(self): + # weights (from tf): filter_height x filter_width x in_channels x out_channels + helper_test_op( + [(2, 9, 9, 10), (3, 3, 10, 20)], + lambda x, w: torch.nn.functional.conv2d( + x.permute(0, 3, 1, 2), w.permute(3, 2, 0, 1) + ).relu(), + lambda x, w: Tensor.conv2d( + x.permute(0, 3, 1, 2), w.permute(3, 2, 0, 1) + ).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_simple_conv2d_batched(self): + helper_test_op( + [(2, 4, 9, 9), (4, 4, 3, 3)], + lambda x, w: torch.nn.functional.conv2d(x, w).relu(), + lambda x, w: Tensor.conv2d(x, w).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + # conv transpose + + def test_simple_conv_transpose2d(self): + helper_test_op( + [(2, 4, 9, 9), (4, 4, 3, 3)], + lambda x, w: torch.nn.functional.conv_transpose2d(x, w).relu(), + lambda x, w: Tensor.conv_transpose2d(x, w).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_bias_conv_transpose2d(self): + helper_test_op( + [(2, 4, 9, 9), (4, 4, 3, 3), (4,)], + lambda x, w, b: torch.nn.functional.conv_transpose2d(x, w, b).relu(), + lambda x, w, b: Tensor.conv_transpose2d(x, w, b).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_grouped_conv_transpose2d(self): + helper_test_op( + [(2, 4, 9, 9), (4, 4, 3, 3)], + lambda x, w: torch.nn.functional.conv_transpose2d(x, w, groups=2).relu(), + lambda x, w: Tensor.conv_transpose2d(x, w, groups=2).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_padded_conv_transpose2d(self): + for padding in [(1, 2), (2, 1), 2, 1, 0]: + helper_test_op( + [(2, 4, 9, 9), (4, 4, 3, 3)], + lambda x, w: torch.nn.functional.conv_transpose2d( + x, w, padding=padding + ).relu(), + lambda x, w: Tensor.conv_transpose2d(x, w, padding=padding).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_dilated_conv_transpose2d(self): + for dilation in [(1, 2), (2, 1), 2, 1]: + helper_test_op( + [(2, 4, 9, 9), (4, 4, 3, 3)], + lambda x, w: torch.nn.functional.conv_transpose2d( + x, w, dilation=dilation + ).relu(), + lambda x, w: Tensor.conv_transpose2d(x, w, dilation=dilation).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_strided_conv_transpose2d(self): + for stride in [(2, 1), (1, 2), 1]: + helper_test_op( + [(2, 4, 4, 5), (4, 4, 3, 3)], + lambda x, w: torch.nn.functional.conv_transpose2d( + x, w, stride=stride + ).relu(), + lambda x, w: Tensor.conv_transpose2d(x, w, stride=stride).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_output_padded_conv_transpose2d(self): + for output_padding, stride in [((1, 1), (2, 3)), ((2, 1), (3, 2))]: + helper_test_op( + [(2, 4, 6, 5), (4, 4, 3, 3), (4,)], + lambda x, w, b: torch.nn.functional.conv_transpose2d( + x, w, b, output_padding=output_padding, stride=stride + ).relu(), + lambda x, w, b: Tensor.conv_transpose2d( + x, w, b, output_padding=output_padding, stride=stride + ).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + @unittest.skipIf(IMAGE > 0, "no conv3d on images") + def test_simple_conv_transpose3d(self): + helper_test_op( + [(2, 4, 9, 9, 9), (4, 4, 3, 3, 3)], + lambda x, w: torch.nn.functional.conv_transpose3d(x, w).relu(), + lambda x, w: Tensor.conv_transpose2d(x, w).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + @unittest.skipIf((IMAGE > 0), "no conv1d on images") + def test_conv1d(self): + for bs in [1, 8]: + for cin in [1, 3]: + for H in [1, 2, 5]: + for groups in [1, 3] if cin == 3 and H == 5 else [1]: + with self.subTest( + batch_size=bs, channels=cin, groups=groups, height=H + ): + helper_test_op( + [(bs, cin, 11), (6, cin // groups, H)], + lambda x, w: torch.nn.functional.conv1d( + x, w, groups=groups + ).relu(), + lambda x, w: Tensor.conv2d(x, w, groups=groups).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + @unittest.skipIf(IMAGE > 0, "no conv1d on images") + def test_simple_padding_conv1d(self): + bs = 6 + cin = 2 + groups = 1 + H = 5 + p = (1, 1) + helper_test_op( + [(bs, cin, 11), (6, cin // groups, H)], + lambda x, w: torch.nn.functional.conv1d( + torch.nn.functional.pad(x, p), w + ).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=p).relu(), + atol=1e-4, + ) + + @unittest.skipIf(IMAGE > 0, "no conv1d on images") + def test_strided_conv1d_simple(self): + bs, H = 2, 3 + helper_test_op( + [(bs, 1, 5), (1, 1, H)], + lambda x, w: torch.nn.functional.conv1d(x, w, stride=2).relu(), + lambda x, w: Tensor.conv2d(x, w, stride=2).relu(), + atol=1e-4, + ) + + @unittest.skipIf(IMAGE > 0, "no conv1d on images") + def test_asymmetric_padding_conv1d(self): + for p in [(0, 1), (2, 1), (2, 0)]: + with self.subTest(p): + for n in [3, 4]: + for k in [2]: + helper_test_op( + [(1, 1, n), (1, 1, k)], + lambda x, w: torch.nn.functional.conv1d( + torch.nn.functional.pad(x, p), w + ).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=p).relu(), + atol=1e-4, + ) + helper_test_op( + [(1, 1, n), (1, 1, k)], + lambda x, w: torch.nn.functional.conv1d( + torch.nn.functional.pad(x, p), w + ).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=p).relu(), + atol=1e-4, + ) + + def _test_conv2d(self, bs=1, cin=1): + for H in [1, 2, 3]: + for W in [1, 2, 3, 5]: + for groups in [1, 3] if cin == 3 and H == 3 and W == 3 else [1]: + with self.subTest( + batch_size=bs, channels=cin, groups=groups, height=H, width=W + ): + helper_test_op( + [(bs, cin, 11, 7), (6, cin // groups, H, W)], + lambda x, w: torch.nn.functional.conv2d( + x, w, groups=groups + ).relu(), + lambda x, w: Tensor.conv2d(x, w, groups=groups).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_conv2d(self): + self._test_conv2d(bs=1, cin=3) + + def test_conv2d_bs_4_cin_3(self): + self._test_conv2d(bs=4, cin=3) + + def test_conv2d_bs_1_cin_1(self): + self._test_conv2d(bs=1, cin=1) + + def test_conv2d_bs_4_cin_1(self): + self._test_conv2d(bs=4, cin=1) + + def test_large_input_conv2d(self): + bs = 4 + cin = 16 + groups = 1 + H = 5 + W = 2 + helper_test_op( + [(bs, cin, 64, 64), (6, cin // groups, H, W)], + lambda x, w: torch.nn.functional.conv2d(x, w, groups=groups).relu(), + # needed to relax tolerance on NVIDIA + lambda x, w: Tensor.conv2d(x, w, groups=groups).relu(), + atol=1e-3, + grad_rtol=1e-5, + ) + + def test_simple_grouped_conv2d(self): + bs = 1 + groups = 2 + rcout = 1 + cin = 2 + helper_test_op( + [(bs, groups * cin, 1, 1), (groups * rcout, cin, 1, 1)], + lambda x, w: torch.nn.functional.conv2d(x, w, groups=groups).relu(), + lambda x, w: Tensor.conv2d(x, w, groups=groups).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_medium_grouped_conv2d(self): + bs = 1 + groups = 2 + rcout = 2 + cin = 2 + helper_test_op( + [(bs, groups * cin, 1, 1), (groups * rcout, cin, 1, 1)], + lambda x, w: torch.nn.functional.conv2d(x, w, groups=groups).relu(), + lambda x, w: Tensor.conv2d(x, w, groups=groups).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_depthwise_conv2d(self): + bs = 1 + groups = 32 + rcout = 1 + cin = 1 + helper_test_op( + [(bs, groups * cin, 32, 32), (groups * rcout, cin, 1, 1)], + lambda x, w: torch.nn.functional.conv2d(x, w, groups=groups).relu(), + lambda x, w: Tensor.conv2d(x, w, groups=groups).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_grouped_conv2d(self): + bs = 4 + groups = 5 + rcout = 7 + cin = 3 + helper_test_op( + [(bs, groups * cin, 5, 5), (groups * rcout, cin, 3, 3)], + lambda x, w: torch.nn.functional.conv2d(x, w, groups=groups).relu(), + lambda x, w: Tensor.conv2d(x, w, groups=groups).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_fancy_conv2d(self): + bs = 2 + cin = 3 + cout = 1 + groups = 3 + H, W = 3, 3 + helper_test_op( + [(bs, cin, 11, 28), (groups * cout, cin // groups, H, W)], + lambda x, w: torch.nn.functional.conv2d(x, w, groups=groups).relu(), + lambda x, w: Tensor.conv2d(x, w, groups=groups).relu(), + atol=1e-4, + grad_rtol=1e-5, + ) + + def test_strided_conv2d_simple(self): + bs, H, W = 2, 3, 1 + helper_test_op( + [(bs, 1, 5, 1), (1, 1, H, W)], + lambda x, w: torch.nn.functional.conv2d(x, w, stride=2).relu(), + lambda x, w: Tensor.conv2d(x, w, stride=2).relu(), + atol=1e-4, + ) + + def test_strided_conv2d(self): + bs = 4 + cin = 3 + H, W = 3, 3 + with self.subTest(stride := 2): + helper_test_op( + [(bs, cin, 11, 28), (4, cin, H, W)], + lambda x, w: torch.nn.functional.conv2d(x, w, stride=2).relu(), + lambda x, w: Tensor.conv2d(x, w, stride=stride).relu(), + atol=1e-4, + ) + with self.subTest(stride := (2, 1)): + helper_test_op( + [(bs, cin, 11, 28), (4, cin, H, W)], + lambda x, w: torch.nn.functional.conv2d(x, w, stride=stride).relu(), + lambda x, w: Tensor.conv2d(x, w, stride=(2, 1)).relu(), + atol=1e-4, + ) + + def test_negative_padding_conv2d(self): + n, k = 10, 3 + helper_test_op( + [(1, 1, n, n), (1, 1, k, k)], + lambda x, w: torch.nn.functional.conv2d(x[:, :, 1:-1, 1:-1], w).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=-1).relu(), + atol=1e-4, + ) + helper_test_op( + [(1, 1, n, n), (1, 1, k, k)], + lambda x, w: torch.nn.functional.conv2d(x[:, :, 1:, 1:], w).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=(-1, 0, -1, 0)).relu(), + atol=1e-4, + ) + + def test_simple_padding_conv2d(self): + p = (1, 1, 1, 1) + helper_test_op( + None, + lambda x, w: torch.nn.functional.conv2d( + torch.nn.functional.pad(x, p), w + ).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=p).relu(), + atol=1e-4, + vals=[[[[[2.0, 3.0]]]], [[[[1.0]]]]], + ) + + def test_asymmetric_padding_conv2d(self): + for p in [(0, 1, 0, 1), (2, 1, 2, 1), (2, 0, 2, 1)]: + with self.subTest(p): + for n in [3, 4]: + for k in [2]: + helper_test_op( + [(1, 1, n, n), (1, 1, k, k)], + lambda x, w: torch.nn.functional.conv2d( + torch.nn.functional.pad(x, p), w + ).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=p).relu(), + atol=1e-4, + ) + helper_test_op( + [(1, 1, n, n), (1, 1, k, k)], + lambda x, w: torch.nn.functional.conv2d( + torch.nn.functional.pad(x, p), w + ).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=p).relu(), + atol=1e-4, + ) + + def test_padded_conv2d_p21(self): + bs, cin, H, W, padding = 4, 3, 3, 3, (2, 1) + helper_test_op( + [(bs, cin, 11, 28), (4, cin, H, W)], + lambda x, w: torch.nn.functional.conv2d(x, w, padding=padding).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=padding).relu(), + atol=1e-4, + ) + + def test_padded_conv2d_p22(self): + bs, cin, H, W, padding = 4, 3, 3, 3, (2, 2) + helper_test_op( + [(bs, cin, 11, 28), (4, cin, H, W)], + lambda x, w: torch.nn.functional.conv2d(x, w, padding=padding).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=padding).relu(), + atol=1e-4, + ) + + def test_padded_conv2d_1x1(self): + bs, cin, H, W, padding = 4, 3, 1, 1, 2 + helper_test_op( + [(bs, cin, 11, 28), (4, cin, H, W)], + lambda x, w: torch.nn.functional.conv2d(x, w, padding=padding).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=padding).relu(), + atol=1e-4, + ) + + def test_padded_conv2d_bs1(self): + bs, cin, H, W, padding = 1, 3, 3, 3, 1 + helper_test_op( + [(bs, cin, 11, 28), (4, cin, H, W)], + lambda x, w: torch.nn.functional.conv2d(x, w, padding=padding).relu(), + lambda x, w: Tensor.conv2d(x, w, padding=padding).relu(), + atol=1e-4, + ) + + def test_padding_add(self): + helper_test_op( + [(64, 64), (60, 60)], + lambda x, w: x + torch.nn.functional.pad(w, (2, 2, 2, 2)), + lambda x, w: x + w.pad2d((2, 2, 2, 2)), + ) + + def test_dilated_conv2d(self): + bs = 4 + cin = 3 + H, W = 3, 3 + for d in [2, (2, 1)]: + with self.subTest(dilation := d): + helper_test_op( + [(bs, cin, 11, 28), (4, cin, H, W)], + lambda x, w: torch.nn.functional.conv2d( + x, w, dilation=dilation + ).relu(), + lambda x, w: Tensor.conv2d(x, w, dilation=dilation).relu(), + atol=1e-4, + ) + + def test_maxpool2d_simple(self): + ksz = (2, 2) + helper_test_op( + [(1, 1, 2, 3)], + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz), + lambda x: Tensor.max_pool2d(x, kernel_size=ksz), + ) + + def test_maxpool2d(self): + for ksz in [(2, 2), (3, 3), 2, 3, (3, 2), (5, 5), (5, 1)]: + with self.subTest(kernel_size=ksz): + helper_test_op( + [(32, 2, 110, 28)], + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz), + lambda x: Tensor.max_pool2d(x, kernel_size=ksz), + ) + + def test_maxpool2d_bigger_stride(self): + for stride in [(2, 3), (3, 2), 2, 3]: + with self.subTest(stride=stride): + helper_test_op( + [(32, 2, 110, 28)], + lambda x: torch.nn.functional.max_pool2d( + x, kernel_size=(2, 2), stride=stride + ), + lambda x: Tensor.max_pool2d(x, kernel_size=(2, 2), stride=stride), + ) + + @unittest.skipIf(Device.DEFAULT == "CUDA", "CUDA fails on this") + def test_maxpool2d_unit_stride(self): + helper_test_op( + [(32, 2, 110, 28)], + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5, 5), stride=1), + lambda x: Tensor.max_pool2d(x, kernel_size=(5, 5), stride=1), + ) + + def test_maxpool2d_smaller_stride(self): + for stride in [(2, 3), (3, 2), 2, 3]: + with self.subTest(stride=stride): + helper_test_op( + [(32, 2, 110, 28)], + lambda x: torch.nn.functional.max_pool2d( + x, kernel_size=(5, 5), stride=stride + ), + lambda x: Tensor.max_pool2d(x, kernel_size=(5, 5), stride=stride), + ) + + def test_maxpool2d_dilation(self): + for dilation in [(2, 3), (3, 2), 2, 3]: + helper_test_op( + [(32, 2, 110, 28)], + lambda x: torch.nn.functional.max_pool2d( + x, kernel_size=(5, 5), dilation=dilation + ), + lambda x: Tensor.max_pool2d(x, kernel_size=(5, 5), dilation=dilation), + ) + + def test_avgpool2d(self): + shape = (32, 2, 111, 28) + for ksz in [(2, 2), (3, 3), (3, 2), (5, 5), (5, 1)]: + with self.subTest(kernel_size=ksz): + helper_test_op( + [shape], + lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz), + lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), + rtol=1e-5, + ) + + def test_global_avgpool2d(self): + helper_test_op( + [(32, 2, 111, 28)], + lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111, 28)), + lambda x: Tensor.avg_pool2d(x, kernel_size=(111, 28)), + rtol=1e-5, + ) + + def test_cat(self): + for dim in range(-2, 3): + helper_test_op( + [(45, 65, 9), (45, 65, 9), (45, 65, 9)], + lambda x, y, z: torch.cat((x, y, z), dim), + lambda x, y, z: x.cat(y, z, dim=dim), + ) + + with self.assertRaises(AssertionError): + a = Tensor(3.14) + a.cat(a) + + def test_multicat(self): + for dim in range(-1, 2): + helper_test_op( + [(45, 65), (45, 65), (45, 65)], + lambda x, y, z: torch.cat((x, y, z), dim), + lambda x, y, z: x.cat(y, z, dim=dim), + ) + + def test_stack(self): + x = Tensor.randn(45, 65, 3) + + for dim in range(-1, 3): + helper_test_op( + [(45, 65, 3), (45, 65, 3), (45, 65, 3)], + lambda x, y, z: torch.stack((x, y, z), dim=dim), + lambda x, y, z: Tensor.stack([x, y, z], dim=dim), + ) + + with self.assertRaises(IndexError): + Tensor.stack([x], dim=77) + + a = Tensor(3.14) + np.testing.assert_allclose( + Tensor.stack([a, a]).numpy(), Tensor([3.14, 3.14]).numpy() + ) + + def test_repeat(self): + x = Tensor.randn(4, 6, 3) + base_repeats = [2, 4, 3] + + for reps in [[], [4], [2, 1], [3, 2, 2]]: + repeats = base_repeats + reps + helper_test_op( + [(4, 6, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats) + ) + helper_test_op( + [()], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats) + ) + + with self.assertRaises(ValueError): + x.repeat((2, 4)) + + np.testing.assert_allclose( + x.repeat((2, 0, 4)).numpy(), Tensor.zeros(8, 0, 12).numpy() + ) + + def test_clip(self): + helper_test_op( + [(45, 65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2) + ) + + def test_matvecmat(self): + helper_test_op( + [(1, 128), (128, 128), (128, 128)], + lambda x, y, z: (x @ y).relu() @ z, + atol=1e-4, + ) + + def test_matvec(self): + helper_test_op([(1, 128), (128, 128)], lambda x, y: (x @ y).relu(), atol=1e-4) + + # this was the failure in llama early realizing freqs_cis + def test_double_slice(self): + helper_test_op([(4, 4)], lambda x: x[:, 1:2][1:2]) + helper_test_op([(4, 4)], lambda x: x[1:3][1:2]) + helper_test_op([(4, 4)], lambda x: x[:, 1:2][0:1]) + helper_test_op([(4, 4)], lambda x: x[:, 1:2][:, 0:1]) + + @unittest.skip("this test is broken #862") + def test_max_inf(self): + n = Tensor([1, float("nan")]).max().numpy() + assert math.isnan(n.item()), f"{n.item()} is not nan" + + def test_inf_where(self): + x = Tensor.full((3, 3), float("inf")) + n = (x < 0).where(x, 1).numpy() + assert np.all(n == 1.0) + + def _get_index_randoms(self): + # indices cannot have gradient + # TODO currently does not support IndexError for out of bounds idx values + a = torch.randint( + low=-1, + high=1, + size=(2, 1, 1, 1, 1, 1), + dtype=torch.int64, + requires_grad=False, + ) + b = torch.randint( + high=1, size=(1, 3, 1, 1, 1, 1), dtype=torch.int64, requires_grad=False + ) + c = torch.randint( + low=-5, + high=5, + size=(1, 1, 4, 1, 1, 1), + dtype=torch.int64, + requires_grad=False, + ) + d = torch.randint( + high=4, size=(2, 1, 1, 5, 1, 1), dtype=torch.int64, requires_grad=False + ) + e = torch.randint( + high=1, size=(1, 1, 1, 1, 6, 1), dtype=torch.int64, requires_grad=False + ) + i, j, k, o, p = [ + Tensor( + tor.detach().numpy().astype(np.int32), + dtype=dtypes.int32, + requires_grad=False, + ) + for tor in [a, b, c, d, e] + ] + return a, b, c, d, e, i, j, k, o, p + + def test_slice_fancy_indexing_no_dim_collapse(self): + a, b, c, d, e, i, j, k, o, p = self._get_index_randoms() + # no dim collapse from int or dim injection from None + helper_test_op( + [(2, 5, 6, 5, 3, 4)], lambda x: x[a, b, c, d, e], lambda x: x[i, j, k, o, p] + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], lambda x: x[:, b, c, d, :], lambda x: x[:, j, k, o, :] + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], lambda x: x[a, b, ...], lambda x: x[i, j, ...] + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], lambda x: x[a, ..., e], lambda x: x[i, ..., p] + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], lambda x: x[..., c, :, e], lambda x: x[..., k, :, p] + ) + + def test_slice_fancy_indexing_dim_collapse_int(self): + a, b, c, d, e, i, j, k, o, p = self._get_index_randoms() + # dim collapse from int + helper_test_op( + [(2, 5, 6, 5, 3, 4)], lambda x: x[1, b, c, d, e], lambda x: x[1, j, k, o, p] + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], lambda x: x[a, b, 3, d, e], lambda x: x[i, j, 3, o, p] + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], lambda x: x[1, b, 2, d, 2], lambda x: x[1, j, 2, o, 2] + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], lambda x: x[a, 2, 2, 2, e], lambda x: x[i, 2, 2, 2, p] + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[1, :, 3:11:2, d, 0:2], + lambda x: x[1, :, 3:11:2, o, 0:2], + ) + + def test_slice_fancy_indexing_dim_inject_none(self): + a, b, c, d, e, i, j, k, o, p = self._get_index_randoms() + # dim injection from None + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[None, b, c, d, e], + lambda x: x[None, j, k, o, p], + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[a, b, c, d, None], + lambda x: x[i, j, k, o, None], + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[a, b, None, d, e], + lambda x: x[i, j, None, o, p], + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[None, b, c, d, None], + lambda x: x[None, j, k, o, None], + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[a, :, None, d, e], + lambda x: x[i, :, None, o, p], + ) + + def test_slice_fancy_indexing_dim_inject_and_collapse(self): + a, b, c, d, e, i, j, k, o, p = self._get_index_randoms() # noqa + # dim injection and collapse + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[1, b, None, d, 1], + lambda x: x[1, j, None, o, 1], + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[None, b, 2, d, None], + lambda x: x[None, j, 2, o, None], + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[..., 1, d, None], + lambda x: x[..., 1, o, None], + ) + + def test_slice_fancy_indexing_with_idx(self): + # indexing using idx with different dim + helper_test_op( + [(2, 3)], + lambda x: x[torch.tensor([[0, 0, 0], [0, 0, 0]]), torch.tensor(1)], + lambda x: x[Tensor([[0, 0, 0], [0, 0, 0]]), Tensor(1)], + ) + helper_test_op( + [(2, 3)], + lambda x: x[torch.tensor([1]), torch.tensor([[0, 0, 0], [0, 0, 0]])], + lambda x: x[Tensor([1]), Tensor([[0, 0, 0], [0, 0, 0]])], + ) + + def test_slice_fancy_indexing_list_indices(self): + a, b, c, d, e, i, j, k, o, p = self._get_index_randoms() + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[[0], b, c, d, :], + lambda x: x[[0], j, k, o, :], + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[[1], b, c, d, :], + lambda x: x[[1], j, k, o, :], + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[[1, 0], b, c, d, :], + lambda x: x[[1, 0], j, k, o, :], + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[a, b, c, [1, 2, 3], ...], + lambda x: x[i, j, k, [1, 2, 3], ...], + ) + helper_test_op( + [(2, 5, 6, 5, 3, 4)], + lambda x: x[a, [2, 1, 0], c, [2, 1, 0], e], + lambda x: x[i, [2, 1, 0], k, [2, 1, 0], p], + ) + + def test_gather(self): + # indices cannot have gradient + # indices cannot be negative (torch gather) + b = torch.randint(3, size=[3, 4, 5], dtype=torch.int64, requires_grad=False) + a = Tensor( + b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False + ) + helper_test_op( + [(4, 5, 6)], + lambda x: x.gather(index=b, dim=0), + lambda x: x.gather(idx=a, dim=0), + ) + helper_test_op( + [(4, 5, 6)], + lambda x: x.gather(index=b, dim=1), + lambda x: x.gather(idx=a, dim=1), + ) + helper_test_op( + [(4, 5, 6)], + lambda x: x.gather(index=b, dim=2), + lambda x: x.gather(idx=a, dim=2), + ) + helper_test_op( + [(3, 4, 5)], + lambda x: x.gather(index=b, dim=0), + lambda x: x.gather(idx=a, dim=0), + ) + self.helper_test_exception( + [(4, 5, 6)], + lambda x: x.gather(index=torch.tensor([1], dtype=torch.int64), dim=0), + lambda x: x.gather(idx=Tensor([1], dtype=dtypes.int32), dim=0), + expected=(RuntimeError, AssertionError), + ) + self.helper_test_exception( + [(2, 1, 1)], + lambda x: x.gather(index=b, dim=0), + lambda x: x.gather(idx=a, dim=0), + expected=(RuntimeError, AssertionError), + ) + + def test_scaled_product_attention(self): + helper_test_op( + [(32, 8, 16, 64), (32, 8, 16, 64), (32, 8, 16, 64)], + lambda x, y, z: torch.nn.functional.scaled_dot_product_attention(x, y, z), + lambda x, y, z: Tensor.scaled_dot_product_attention(x, y, z), + ) + helper_test_op( + [(32, 8, 16, 64), (32, 8, 16, 64), (32, 8, 16, 64), (32, 8, 16, 16)], + lambda x, y, z, m: torch.nn.functional.scaled_dot_product_attention( + x, y, z, attn_mask=m + ), + lambda x, y, z, m: Tensor.scaled_dot_product_attention( + x, y, z, attn_mask=m + ), + ) + helper_test_op( + [(32, 8, 16, 64), (32, 8, 16, 64), (32, 8, 16, 64)], + lambda x, y, z: torch.nn.functional.scaled_dot_product_attention( + x, y, z, is_causal=True + ), + lambda x, y, z: Tensor.scaled_dot_product_attention( + x, y, z, is_causal=True + ), + ) + + def test_binary_crossentropy(self): + helper_test_op( + [(32, 10), (32, 10)], + lambda x, y: torch.nn.functional.binary_cross_entropy( + x.sigmoid(), torch.clip(y, 0, 1) + ), + lambda x, y: x.sigmoid().binary_crossentropy(y.clip(0, 1)), + ) + helper_test_op( + [(32, 10), (32, 10)], + lambda x, y: torch.nn.functional.binary_cross_entropy_with_logits( + x, torch.clip(y, 0, 1) + ), + lambda x, y: x.binary_crossentropy_logits(y.clip(0, 1)), + ) + helper_test_op( + [(32, 10), (32, 10)], + lambda x, y: torch.nn.functional.binary_cross_entropy_with_logits( + x, torch.clip(y, 0, 1) + ), + lambda x, y: x.sigmoid().binary_crossentropy(y.clip(0, 1)), + ) + helper_test_op( + [(32, 10), (32, 10)], + lambda x, y: torch.nn.functional.binary_cross_entropy( + x.sigmoid(), torch.clip(y, 0, 1) + ), + lambda x, y: x.binary_crossentropy_logits(y.clip(0, 1)), + ) + + +if __name__ == "__main__": + np.random.seed(1337) + unittest.main(verbosity=2) diff --git a/test/test_optim.py b/test/test_optim.py index df1e53f36..6a363f38f 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -8,91 +8,152 @@ import pytest pytestmark = pytest.mark.exclude_cuda np.random.seed(1337) -x_init = np.random.randn(1,4).astype(np.float32) -W_init = np.random.randn(4,4).astype(np.float32) -m_init = np.random.randn(1,4).astype(np.float32) +x_init = np.random.randn(1, 4).astype(np.float32) +W_init = np.random.randn(4, 4).astype(np.float32) +m_init = np.random.randn(1, 4).astype(np.float32) + class TinyNet: - def __init__(self, tensor): - self.x = tensor(x_init.copy(), requires_grad=True) - self.W = tensor(W_init.copy(), requires_grad=True) - self.m = tensor(m_init.copy()) + def __init__(self, tensor): + self.x = tensor(x_init.copy(), requires_grad=True) + self.W = tensor(W_init.copy(), requires_grad=True) + self.m = tensor(m_init.copy()) + + def forward(self): + out = self.x.matmul(self.W).relu() + # print(out.detach().numpy()) + out = out.log_softmax(1) + out = out.mul(self.m).add(self.m).sum() + return out - def forward(self): - out = self.x.matmul(self.W).relu() - # print(out.detach().numpy()) - out = out.log_softmax(1) - out = out.mul(self.m).add(self.m).sum() - return out def step(tensor, optim, steps=1, kwargs={}): - net = TinyNet(tensor) - optim = optim([net.x, net.W], **kwargs) - for _ in range(steps): - out = net.forward() - optim.zero_grad() - out.backward() - optim.step() - return net.x.detach().numpy(), net.W.detach().numpy() + net = TinyNet(tensor) + optim = optim([net.x, net.W], **kwargs) + for _ in range(steps): + out = net.forward() + optim.zero_grad() + out.backward() + optim.step() + return net.x.detach().numpy(), net.W.detach().numpy() + class TestOptim(unittest.TestCase): + def _test_optim(self, tinygrad_optim, torch_optim, steps, opts, atol, rtol): + for x, y in zip( + step(Tensor, tinygrad_optim, steps, kwargs=opts), + step(torch.tensor, torch_optim, steps, kwargs=opts), + ): + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol) - def _test_optim(self, tinygrad_optim, torch_optim, steps, opts, atol, rtol): - for x,y in zip(step(Tensor, tinygrad_optim, steps, kwargs=opts), - step(torch.tensor, torch_optim, steps, kwargs=opts)): - np.testing.assert_allclose(x, y, atol=atol, rtol=rtol) + def _test_sgd(self, steps, opts, atol, rtol): + self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol) - def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol) - def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol) - def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol) + def _test_adam(self, steps, opts, atol, rtol): + self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol) - def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0) - def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5) - def test_sgd_wd(self): self._test_sgd(1, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0) - def test_sgd_high_lr_wd(self): self._test_sgd(1, {'lr': 10, 'weight_decay': 0.1}, 1e-6, 1e-5) + def _test_adamw(self, steps, opts, atol, rtol): + self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol) - def test_multistep_sgd(self): self._test_sgd(10, {'lr': 0.001}, 1e-6, 0) - def test_multistep_sgd_high_lr(self): self._test_sgd(10, {'lr': 10}, 1e-6, 3e-4) - def test_multistep_sgd_wd(self): self._test_sgd(10, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0) - def test_multistep_sgd_high_lr_wd(self): self._test_sgd(10, {'lr': 9, 'weight_decay': 0.1}, 1e-6, 3e-4) + def test_sgd(self): + self._test_sgd(1, {"lr": 0.001}, 1e-6, 0) - def test_multistep_sgd_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9}, 1e-6, 0) - def test_multistep_sgd_high_lr_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9}, 1e-5, 3e-4) - def test_multistep_sgd_momentum_wd(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.1}, 1e-6, 0) - def test_multistep_sgd_high_lr_momentum_wd(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'weight_decay': 0.1}, 1e-5, 3e-4) + def test_sgd_high_lr(self): + self._test_sgd(1, {"lr": 10}, 1e-6, 1e-5) - def test_multistep_sgd_nesterov_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True}, 1e-5, 0) - def test_multistep_sgd_high_lr_nesterov_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'nesterov': True}, 1e-5, 3e-4) - def test_multistep_sgd_nesterov_momentum_wd(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 0) - def test_multistep_sgd_high_lr_nesterov_momentum_wd(self): self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4) + def test_sgd_wd(self): + self._test_sgd(1, {"lr": 0.001, "weight_decay": 0.1}, 1e-6, 0) - def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0) - def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-4, 1e-4) - def test_adamw(self): self._test_adamw(1, {'lr': 0.001}, 1e-5, 0) - def test_adamw_high_lr(self): self._test_adamw(1, {'lr': 10}, 1e-4, 1e-4) + def test_sgd_high_lr_wd(self): + self._test_sgd(1, {"lr": 10, "weight_decay": 0.1}, 1e-6, 1e-5) - def test_multistep_adam(self): self._test_adam(10, {'lr': 0.001}, 1e-5, 0) - def test_multistep_adam_high_lr(self): self._test_adam(10, {'lr': 10}, 2e-4, 5e-4) + def test_multistep_sgd(self): + self._test_sgd(10, {"lr": 0.001}, 1e-6, 0) - def test_multistep_adamw(self): self._test_adamw(10, {'lr': 0.001}, 1e-5, 0) - def test_multistep_adamw_high_lr(self): self._test_adamw(10, {'lr': 10}, 5e-4, 2e-3) + def test_multistep_sgd_high_lr(self): + self._test_sgd(10, {"lr": 10}, 1e-6, 3e-4) - def test_duped_weights(self): - for Opt in [Adam, AdamW, SGD]: - losses = [] - for i in range(2): - w = Tensor(x_init.copy()) - opt = Opt([w], lr=0.1) if i == 0 else Opt([w, w], lr=0.1) + def test_multistep_sgd_wd(self): + self._test_sgd(10, {"lr": 0.001, "weight_decay": 0.1}, 1e-6, 0) - loss = None - for _ in range(3): - loss = w.sum() - opt.zero_grad() - loss.backward() - opt.step() - losses.append(loss.numpy()) + def test_multistep_sgd_high_lr_wd(self): + self._test_sgd(10, {"lr": 9, "weight_decay": 0.1}, 1e-6, 3e-4) - np.testing.assert_allclose(losses[0], losses[1], atol=1e-4, rtol=0) + def test_multistep_sgd_momentum(self): + self._test_sgd(10, {"lr": 0.001, "momentum": 0.9}, 1e-6, 0) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + def test_multistep_sgd_high_lr_momentum(self): + self._test_sgd(10, {"lr": 10, "momentum": 0.9}, 1e-5, 3e-4) + + def test_multistep_sgd_momentum_wd(self): + self._test_sgd(10, {"lr": 0.001, "momentum": 0.9, "weight_decay": 0.1}, 1e-6, 0) + + def test_multistep_sgd_high_lr_momentum_wd(self): + self._test_sgd(10, {"lr": 10, "momentum": 0.9, "weight_decay": 0.1}, 1e-5, 3e-4) + + def test_multistep_sgd_nesterov_momentum(self): + self._test_sgd(10, {"lr": 0.001, "momentum": 0.9, "nesterov": True}, 1e-5, 0) + + def test_multistep_sgd_high_lr_nesterov_momentum(self): + self._test_sgd(10, {"lr": 10, "momentum": 0.9, "nesterov": True}, 1e-5, 3e-4) + + def test_multistep_sgd_nesterov_momentum_wd(self): + self._test_sgd( + 10, + {"lr": 0.001, "momentum": 0.9, "nesterov": True, "weight_decay": 0.1}, + 1e-5, + 0, + ) + + def test_multistep_sgd_high_lr_nesterov_momentum_wd(self): + self._test_sgd( + 10, + {"lr": 9, "momentum": 0.9, "nesterov": True, "weight_decay": 0.1}, + 1e-5, + 3e-4, + ) + + def test_adam(self): + self._test_adam(1, {"lr": 0.001}, 1e-5, 0) + + def test_adam_high_lr(self): + self._test_adam(1, {"lr": 10}, 1e-4, 1e-4) + + def test_adamw(self): + self._test_adamw(1, {"lr": 0.001}, 1e-5, 0) + + def test_adamw_high_lr(self): + self._test_adamw(1, {"lr": 10}, 1e-4, 1e-4) + + def test_multistep_adam(self): + self._test_adam(10, {"lr": 0.001}, 1e-5, 0) + + def test_multistep_adam_high_lr(self): + self._test_adam(10, {"lr": 10}, 2e-4, 5e-4) + + def test_multistep_adamw(self): + self._test_adamw(10, {"lr": 0.001}, 1e-5, 0) + + def test_multistep_adamw_high_lr(self): + self._test_adamw(10, {"lr": 10}, 5e-4, 2e-3) + + def test_duped_weights(self): + for Opt in [Adam, AdamW, SGD]: + losses = [] + for i in range(2): + w = Tensor(x_init.copy()) + opt = Opt([w], lr=0.1) if i == 0 else Opt([w, w], lr=0.1) + + loss = None + for _ in range(3): + loss = w.sum() + opt.zero_grad() + loss.backward() + opt.step() + losses.append(loss.numpy()) + + np.testing.assert_allclose(losses[0], losses[1], atol=1e-4, rtol=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_randomness.py b/test/test_randomness.py index d9b5a9079..54c6865f6 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -7,135 +7,245 @@ import tinygrad.nn as nn from tinygrad.helpers import dtypes from functools import partial + # https://gist.github.com/devries/11405101 def ksprob(a): - fac, total, termbf = 2.0, 0.0, 0.0 - a2 = -2.0 * a * a - for j in range(1, 101): - term = fac * math.exp(a2 * j * j) - total += term - if math.fabs(term) <= 0.001 * termbf or math.fabs(term) <= 1e-8 * total: - return total - fac = -fac - termbf = math.fabs(term) - return 1.0 + fac, total, termbf = 2.0, 0.0, 0.0 + a2 = -2.0 * a * a + for j in range(1, 101): + term = fac * math.exp(a2 * j * j) + total += term + if math.fabs(term) <= 0.001 * termbf or math.fabs(term) <= 1e-8 * total: + return total + fac = -fac + termbf = math.fabs(term) + return 1.0 + def kstest(l1, l2): - n1, n2 = len(l1), len(l2) - l1.sort() - l2.sort() - j1, j2, d, fn1, fn2 = 0, 0, 0.0, 0.0, 0.0 - while j1 < n1 and j2 < n2: - d1, d2 = l1[j1], l2[j2] - if d1 <= d2: - fn1 = (float(j1) + 1.0) / float(n1) - j1 += 1 - if d2 <= d1: - fn2 = (float(j2) + 1.0) / float(n2) - j2 += 1 - dtemp = math.fabs(fn2 - fn1) - if dtemp > d: - d = dtemp - ne = float(n1 * n2) / float(n1 + n2) - nesq = math.sqrt(ne) - prob = ksprob((nesq + 0.12 + 0.11 / nesq) * d) - return prob + n1, n2 = len(l1), len(l2) + l1.sort() + l2.sort() + j1, j2, d, fn1, fn2 = 0, 0, 0.0, 0.0, 0.0 + while j1 < n1 and j2 < n2: + d1, d2 = l1[j1], l2[j2] + if d1 <= d2: + fn1 = (float(j1) + 1.0) / float(n1) + j1 += 1 + if d2 <= d1: + fn2 = (float(j2) + 1.0) / float(n2) + j2 += 1 + dtemp = math.fabs(fn2 - fn1) + if dtemp > d: + d = dtemp + ne = float(n1 * n2) / float(n1 + n2) + nesq = math.sqrt(ne) + prob = ksprob((nesq + 0.12 + 0.11 / nesq) * d) + return prob -def equal_distribution(tiny_func, torch_func=None, numpy_func=None, shape=(20, 23), alpha=0.05): - Tensor.manual_seed(1337) - torch.manual_seed(1337) - np.random.seed(1337) - assert not (torch_func is None and numpy_func is None), "no function to compare with" - x = tiny_func(*shape).numpy().flatten() - if numpy_func is not None: y = numpy_func(shape).flatten() - if torch_func is not None: z = torch_func(shape).numpy().flatten() - return (numpy_func is None or kstest(x, y) >= alpha) and (torch_func is None or kstest(x, z) >= alpha) -def normal_test(func, shape=(20, 23), alpha=0.05): return equal_distribution(func, numpy_func=lambda x: np.random.randn(*x), shape=shape, alpha=alpha) +def equal_distribution( + tiny_func, torch_func=None, numpy_func=None, shape=(20, 23), alpha=0.05 +): + Tensor.manual_seed(1337) + torch.manual_seed(1337) + np.random.seed(1337) + assert not ( + torch_func is None and numpy_func is None + ), "no function to compare with" + x = tiny_func(*shape).numpy().flatten() + if numpy_func is not None: + y = numpy_func(shape).flatten() + if torch_func is not None: + z = torch_func(shape).numpy().flatten() + return (numpy_func is None or kstest(x, y) >= alpha) and ( + torch_func is None or kstest(x, z) >= alpha + ) + + +def normal_test(func, shape=(20, 23), alpha=0.05): + return equal_distribution( + func, numpy_func=lambda x: np.random.randn(*x), shape=shape, alpha=alpha + ) + class TestRandomness(unittest.TestCase): - def test_rand(self): - self.assertFalse(normal_test(Tensor.rand)) - self.assertTrue(equal_distribution(Tensor.rand, torch.rand, lambda x: np.random.rand(*x))) + def test_rand(self): + self.assertFalse(normal_test(Tensor.rand)) + self.assertTrue( + equal_distribution(Tensor.rand, torch.rand, lambda x: np.random.rand(*x)) + ) - def test_randn(self): - self.assertTrue(normal_test(Tensor.randn)) - self.assertTrue(equal_distribution(Tensor.randn, torch.randn, lambda x: np.random.randn(*x))) + def test_randn(self): + self.assertTrue(normal_test(Tensor.randn)) + self.assertTrue( + equal_distribution(Tensor.randn, torch.randn, lambda x: np.random.randn(*x)) + ) - def test_normal(self): - self.assertTrue(normal_test(Tensor.normal)) - self.assertTrue(equal_distribution(Tensor.normal, lambda x: torch.nn.init.normal_(torch.empty(x), mean=0, std=1), lambda x: np.random.normal(loc=0, scale=1, size=x))) + def test_normal(self): + self.assertTrue(normal_test(Tensor.normal)) + self.assertTrue( + equal_distribution( + Tensor.normal, + lambda x: torch.nn.init.normal_(torch.empty(x), mean=0, std=1), + lambda x: np.random.normal(loc=0, scale=1, size=x), + ) + ) - def test_uniform(self): - self.assertFalse(normal_test(Tensor.uniform)) - self.assertTrue(equal_distribution(Tensor.uniform, lambda x: torch.nn.init.uniform_(torch.empty(x)), lambda x: np.random.uniform(size=x))) - self.assertTrue(equal_distribution(partial(Tensor.uniform, low=-100, high=100, dtype=dtypes.int32), numpy_func=lambda x: np.random.randint(low=-100, high=100, size=x))) + def test_uniform(self): + self.assertFalse(normal_test(Tensor.uniform)) + self.assertTrue( + equal_distribution( + Tensor.uniform, + lambda x: torch.nn.init.uniform_(torch.empty(x)), + lambda x: np.random.uniform(size=x), + ) + ) + self.assertTrue( + equal_distribution( + partial(Tensor.uniform, low=-100, high=100, dtype=dtypes.int32), + numpy_func=lambda x: np.random.randint(low=-100, high=100, size=x), + ) + ) - def test_scaled_uniform(self): - self.assertFalse(normal_test(Tensor.scaled_uniform)) - self.assertTrue(equal_distribution(Tensor.scaled_uniform, lambda x: torch.nn.init.uniform_(torch.empty(x), a=-1, b=1) / math.sqrt(math.prod(x)), lambda x: np.random.uniform(-1, 1, size=x) / math.sqrt(math.prod(x)))) + def test_scaled_uniform(self): + self.assertFalse(normal_test(Tensor.scaled_uniform)) + self.assertTrue( + equal_distribution( + Tensor.scaled_uniform, + lambda x: torch.nn.init.uniform_(torch.empty(x), a=-1, b=1) + / math.sqrt(math.prod(x)), + lambda x: np.random.uniform(-1, 1, size=x) / math.sqrt(math.prod(x)), + ) + ) - def test_glorot_uniform(self): - self.assertFalse(normal_test(Tensor.glorot_uniform)) - self.assertTrue(equal_distribution(Tensor.glorot_uniform, lambda x: torch.nn.init.xavier_uniform_(torch.empty(x)), lambda x: np.random.uniform(-1, 1, size=x) * math.sqrt(6 / (x[0] + math.prod(x[1:]))))) + def test_glorot_uniform(self): + self.assertFalse(normal_test(Tensor.glorot_uniform)) + self.assertTrue( + equal_distribution( + Tensor.glorot_uniform, + lambda x: torch.nn.init.xavier_uniform_(torch.empty(x)), + lambda x: np.random.uniform(-1, 1, size=x) + * math.sqrt(6 / (x[0] + math.prod(x[1:]))), + ) + ) - def test_kaiming_uniform(self): - Tensor.manual_seed(1337) - torch.manual_seed(1337) - np.random.seed(1337) - for shape in [(128, 64, 3, 3), (20, 24)]: - self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape)) + def test_kaiming_uniform(self): + Tensor.manual_seed(1337) + torch.manual_seed(1337) + np.random.seed(1337) + for shape in [(128, 64, 3, 3), (20, 24)]: + self.assertTrue( + equal_distribution( + Tensor.kaiming_uniform, + lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), + shape=shape, + ) + ) - def test_kaiming_normal(self): - Tensor.manual_seed(1337) - torch.manual_seed(1337) - np.random.seed(1337) - for shape in [(128, 64, 3, 3), (20, 24)]: - self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape)) + def test_kaiming_normal(self): + Tensor.manual_seed(1337) + torch.manual_seed(1337) + np.random.seed(1337) + for shape in [(128, 64, 3, 3), (20, 24)]: + self.assertTrue( + equal_distribution( + Tensor.kaiming_normal, + lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), + shape=shape, + ) + ) - def test_multinomial(self): - self.assertRaises(AssertionError, lambda: Tensor(2).multinomial(1, replacement=False)) - self.assertRaises(AssertionError, lambda: Tensor([1, 9]).multinomial(0, replacement=False)) - def _check_with_torch(w, num_samples, replacement): - tiny_res = Tensor(w).multinomial(num_samples, replacement=replacement) - torch_res = torch.tensor(w).multinomial(num_samples, replacement=replacement) - self.assertEqual(tiny_res.shape, torch_res.shape) - if torch_res.ndim == 1: - tiny_res = tiny_res.unsqueeze(0) - torch_res = torch_res.unsqueeze(0) - for i in range(torch_res.shape[0]): - self.assertTrue(equal_distribution(lambda *_: tiny_res[i], lambda _: torch_res[i])) - _check_with_torch(w=[0.231, 0., 1., 0.5], num_samples=2000, replacement=True) - _check_with_torch(w=[[0.2, 0.8]], num_samples=2000, replacement=True) # 2D but only 1 row - _check_with_torch(w=[[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]], num_samples=2000, replacement=True) - # no-replacement isn't supported, unless taking only one sample - w = [0.1, 0.9] - self.assertRaises(AssertionError, lambda: Tensor(w).multinomial(100, replacement=False)) - tiny_samples = [Tensor(w).multinomial(1, replacement=False).numpy().item() for _ in range(1000)] - torch_samples = [torch.tensor(w).multinomial(1, replacement=False).item() for _ in range(1000)] - self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_samples), lambda _: torch.tensor(torch_samples))) + def test_multinomial(self): + self.assertRaises( + AssertionError, lambda: Tensor(2).multinomial(1, replacement=False) + ) + self.assertRaises( + AssertionError, lambda: Tensor([1, 9]).multinomial(0, replacement=False) + ) - def test_multinomial_counterexample(self): - tiny_res = Tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True) - torch_res = torch.tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True) - self.assertTrue(equal_distribution(lambda *_: tiny_res, lambda _: torch_res)) - torch_res = torch.tensor([0.2, 0.7, 0.1]).multinomial(2000, replacement=True) - self.assertFalse(equal_distribution(lambda *_: tiny_res, lambda _: torch_res)) + def _check_with_torch(w, num_samples, replacement): + tiny_res = Tensor(w).multinomial(num_samples, replacement=replacement) + torch_res = torch.tensor(w).multinomial( + num_samples, replacement=replacement + ) + self.assertEqual(tiny_res.shape, torch_res.shape) + if torch_res.ndim == 1: + tiny_res = tiny_res.unsqueeze(0) + torch_res = torch_res.unsqueeze(0) + for i in range(torch_res.shape[0]): + self.assertTrue( + equal_distribution(lambda *_: tiny_res[i], lambda _: torch_res[i]) + ) - def test_conv2d_init(self): - params = (128, 256, (3,3)) - assert equal_distribution(lambda *_: nn.Conv2d(*params).weight, lambda _: torch.nn.Conv2d(*params).weight.detach()) - assert equal_distribution(lambda *_: nn.Conv2d(*params).bias, lambda _: torch.nn.Conv2d(*params).bias.detach()) + _check_with_torch(w=[0.231, 0.0, 1.0, 0.5], num_samples=2000, replacement=True) + _check_with_torch( + w=[[0.2, 0.8]], num_samples=2000, replacement=True + ) # 2D but only 1 row + _check_with_torch( + w=[[0.453, 0.0, 1.0, 0.81], [0.1, 0.8, 0.0, 0.1]], + num_samples=2000, + replacement=True, + ) + # no-replacement isn't supported, unless taking only one sample + w = [0.1, 0.9] + self.assertRaises( + AssertionError, lambda: Tensor(w).multinomial(100, replacement=False) + ) + tiny_samples = [ + Tensor(w).multinomial(1, replacement=False).numpy().item() + for _ in range(1000) + ] + torch_samples = [ + torch.tensor(w).multinomial(1, replacement=False).item() + for _ in range(1000) + ] + self.assertTrue( + equal_distribution( + lambda *_: Tensor(tiny_samples), lambda _: torch.tensor(torch_samples) + ) + ) - def test_linear_init(self): - params = (64, 64) - assert equal_distribution(lambda *_: nn.Linear(*params).weight, lambda _: torch.nn.Linear(*params).weight.detach()) - assert equal_distribution(lambda *_: nn.Linear(*params).bias, lambda _: torch.nn.Linear(*params).bias.detach()) + def test_multinomial_counterexample(self): + tiny_res = Tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True) + torch_res = torch.tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True) + self.assertTrue(equal_distribution(lambda *_: tiny_res, lambda _: torch_res)) + torch_res = torch.tensor([0.2, 0.7, 0.1]).multinomial(2000, replacement=True) + self.assertFalse(equal_distribution(lambda *_: tiny_res, lambda _: torch_res)) + + def test_conv2d_init(self): + params = (128, 256, (3, 3)) + assert equal_distribution( + lambda *_: nn.Conv2d(*params).weight, + lambda _: torch.nn.Conv2d(*params).weight.detach(), + ) + assert equal_distribution( + lambda *_: nn.Conv2d(*params).bias, + lambda _: torch.nn.Conv2d(*params).bias.detach(), + ) + + def test_linear_init(self): + params = (64, 64) + assert equal_distribution( + lambda *_: nn.Linear(*params).weight, + lambda _: torch.nn.Linear(*params).weight.detach(), + ) + assert equal_distribution( + lambda *_: nn.Linear(*params).bias, + lambda _: torch.nn.Linear(*params).bias.detach(), + ) + + def test_bn_init(self): + params = (64,) + assert equal_distribution( + lambda *_: nn.BatchNorm2d(*params).weight, + lambda _: torch.nn.BatchNorm2d(*params).weight.detach(), + ) + assert equal_distribution( + lambda *_: nn.BatchNorm2d(*params).bias, + lambda _: torch.nn.BatchNorm2d(*params).bias.detach(), + ) - def test_bn_init(self): - params = (64,) - assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).weight, lambda _: torch.nn.BatchNorm2d(*params).weight.detach()) - assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).bias, lambda _: torch.nn.BatchNorm2d(*params).bias.detach()) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/test/test_sample.py b/test/test_sample.py index 3bb4cf76f..6435a09c9 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -3,18 +3,23 @@ import numpy as np from tinygrad.tensor import Tensor from tinygrad.shape.symbolic import Variable -class TestSample(unittest.TestCase): - def test_sample(self): - X = Tensor.rand(10000, 50).realize() - BS = 16 - idxs = np.random.randint(0, X.shape[0], size=(BS)) - # this uncovered a bug with arg sort order - batch = [Variable(f'idx{i}', 0, X.shape[0]-1).bind(s) for i,s in enumerate(idxs.tolist())] - x = Tensor.cat(*[X.shrink(((batch[i], batch[i]+1), None)) for i in range(BS)]) - print(idxs) - ret = x.numpy() - base = X.numpy()[idxs] - np.testing.assert_equal(ret, base) -if __name__ == '__main__': - unittest.main() \ No newline at end of file +class TestSample(unittest.TestCase): + def test_sample(self): + X = Tensor.rand(10000, 50).realize() + BS = 16 + idxs = np.random.randint(0, X.shape[0], size=(BS)) + # this uncovered a bug with arg sort order + batch = [ + Variable(f"idx{i}", 0, X.shape[0] - 1).bind(s) + for i, s in enumerate(idxs.tolist()) + ] + x = Tensor.cat(*[X.shrink(((batch[i], batch[i] + 1), None)) for i in range(BS)]) + print(idxs) + ret = x.numpy() + base = X.numpy()[idxs] + np.testing.assert_equal(ret, base) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_schedule.py b/test/test_schedule.py index 952f99674..23b5a8de7 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -12,330 +12,375 @@ from tinygrad.codegen.linearizer import Linearizer from tinygrad.graph import log_schedule_item, print_tree from tinygrad import nn -def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True): - seen = set() - if to_prerealize: - for pre in to_prerealize: - for s in pre.lazydata.schedule(seen.copy()): + +def check_schedule( + t: Tensor, + allowed: int, + to_prerealize: Optional[List[Tensor]] = None, + filter_loadops=True, +): + seen = set() + if to_prerealize: + for pre in to_prerealize: + for s in pre.lazydata.schedule(seen.copy()): + log_schedule_item(s) + seen.add(s.out) + sched = t.lazydata.schedule(seen) + for s in sched: log_schedule_item(s) - seen.add(s.out) - sched = t.lazydata.schedule(seen) - for s in sched: log_schedule_item(s) - if filter_loadops: sched = [s for s in sched if s.ast.op not in LoadOps] - if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}") - if len(sched) != allowed or DEBUG >= 3: - for i, s in enumerate(sched): - print("op", i) - print_tree(s.ast) - assert len(sched) == allowed - # test the (non loadops) ops linearize - for s in sched: - if s.ast.op in LoadOps: continue - l = Linearizer(s.ast) - l.hand_coded_optimizations() - l.linearize() + if filter_loadops: + sched = [s for s in sched if s.ast.op not in LoadOps] + if len(sched) != allowed: + print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}") + if len(sched) != allowed or DEBUG >= 3: + for i, s in enumerate(sched): + print("op", i) + print_tree(s.ast) + assert len(sched) == allowed + # test the (non loadops) ops linearize + for s in sched: + if s.ast.op in LoadOps: + continue + l = Linearizer(s.ast) + l.hand_coded_optimizations() + l.linearize() + class TestSchedule(unittest.TestCase): - def test_basic_binop_fusion(self): - a = Tensor.empty(10) - b = Tensor.empty(10) - c = Tensor.empty(10) - d = a+b+c - check_schedule(d, 1) + def test_basic_binop_fusion(self): + a = Tensor.empty(10) + b = Tensor.empty(10) + c = Tensor.empty(10) + d = a + b + c + check_schedule(d, 1) - def test_basic_binop_fusion_deep(self): - a = Tensor.empty(10) - b = Tensor.empty(10) - c = Tensor.empty(10) - d = Tensor.empty(10) - e = a+b+c+d - check_schedule(e, 1) + def test_basic_binop_fusion_deep(self): + a = Tensor.empty(10) + b = Tensor.empty(10) + c = Tensor.empty(10) + d = Tensor.empty(10) + e = a + b + c + d + check_schedule(e, 1) - def test_mulacc_fusion(self): - a = Tensor.empty(10) - b = Tensor.empty(10) - c = (a*b).sum() - check_schedule(c, 1) + def test_mulacc_fusion(self): + a = Tensor.empty(10) + b = Tensor.empty(10) + c = (a * b).sum() + check_schedule(c, 1) - def test_mulacc_relu_fusion(self): - a = Tensor.empty(10) - b = Tensor.empty(10) - c = (a*b).sum().relu() - check_schedule(c, 1) + def test_mulacc_relu_fusion(self): + a = Tensor.empty(10) + b = Tensor.empty(10) + c = (a * b).sum().relu() + check_schedule(c, 1) - def test_binop_reshape_fusion(self): - a = Tensor.empty(10) - b = Tensor.empty(10) - c = Tensor.empty(5,2) - d = (a+b).reshape(5,2)+c - check_schedule(d, 1) + def test_binop_reshape_fusion(self): + a = Tensor.empty(10) + b = Tensor.empty(10) + c = Tensor.empty(5, 2) + d = (a + b).reshape(5, 2) + c + check_schedule(d, 1) - def test_binop_permute_fusion(self): - a = Tensor.empty(2,5) - b = Tensor.empty(2,5) - c = Tensor.empty(5,2) - d = (a+b).permute(1,0)+c - check_schedule(d, 1) + def test_binop_permute_fusion(self): + a = Tensor.empty(2, 5) + b = Tensor.empty(2, 5) + c = Tensor.empty(5, 2) + d = (a + b).permute(1, 0) + c + check_schedule(d, 1) - @unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM", "only test for compiled backends") - def test_constants_are_embedded(self): - a = Tensor.empty(3,3) * 2 - check_schedule(a, 2, filter_loadops=False) + @unittest.skipIf( + not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM", + "only test for compiled backends", + ) + def test_constants_are_embedded(self): + a = Tensor.empty(3, 3) * 2 + check_schedule(a, 2, filter_loadops=False) - def test_binop_elu_fusion(self): - a = Tensor.empty(10) - b = a.elu() - check_schedule(b, 1) + def test_binop_elu_fusion(self): + a = Tensor.empty(10) + b = a.elu() + check_schedule(b, 1) - def test_binop_reshape_reduce_fusion(self): - a = Tensor.empty(100) - b = Tensor.empty(100) - c = (a+b).reshape(10, 10).sum(axis=0, keepdim=True) - check_schedule(c, 1) + def test_binop_reshape_reduce_fusion(self): + a = Tensor.empty(100) + b = Tensor.empty(100) + c = (a + b).reshape(10, 10).sum(axis=0, keepdim=True) + check_schedule(c, 1) - def test_reduce_reshape_binop_fusion(self): - a = Tensor.empty(10,10) - b = Tensor.empty(10) - c = a.sum(axis=0) + b - check_schedule(c, 1) + def test_reduce_reshape_binop_fusion(self): + a = Tensor.empty(10, 10) + b = Tensor.empty(10) + c = a.sum(axis=0) + b + check_schedule(c, 1) - @unittest.skip("not pushing permutes through reduces") - def test_reduce_permute_binop_fusion(self): - a = Tensor.empty(10,10,10) - b = Tensor.empty(10,10,1) - c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b - check_schedule(c, 1) + @unittest.skip("not pushing permutes through reduces") + def test_reduce_permute_binop_fusion(self): + a = Tensor.empty(10, 10, 10) + b = Tensor.empty(10, 10, 1) + c = a.sum(axis=0, keepdim=True).permute(2, 1, 0) + b + check_schedule(c, 1) - def test_binop_early_reshape_reduce_fusion(self): - a = Tensor.empty(100) - b = Tensor.empty(100) - c = Tensor.empty(10,10) - d = ((a+b).reshape(10,10) + c).sum(axis=0) - check_schedule(d, 1) + def test_binop_early_reshape_reduce_fusion(self): + a = Tensor.empty(100) + b = Tensor.empty(100) + c = Tensor.empty(10, 10) + d = ((a + b).reshape(10, 10) + c).sum(axis=0) + check_schedule(d, 1) - def test_diamond_folded(self): - a = Tensor.empty(10) - b = Tensor.empty(10) - c = Tensor.empty(10) - d = Tensor.empty(10) - ab = a+b - e = (ab+c) + (ab+d) - check_schedule(e, 1) + def test_diamond_folded(self): + a = Tensor.empty(10) + b = Tensor.empty(10) + c = Tensor.empty(10) + d = Tensor.empty(10) + ab = a + b + e = (ab + c) + (ab + d) + check_schedule(e, 1) - def test_cache_binaryop(self): - a = Tensor.empty(10) - b = Tensor.empty(10) - c = a+b - d = a+b - check_schedule(d, 0, [c]) + def test_cache_binaryop(self): + a = Tensor.empty(10) + b = Tensor.empty(10) + c = a + b + d = a + b + check_schedule(d, 0, [c]) - @unittest.skip("failing in old lazy") - def test_cache_binaryop_reshaped(self): - a = Tensor.empty(10) - b = Tensor.empty(10) - c = a+b - d = a.reshape(10,1)+b.reshape(10,1) - check_schedule(d, 0, [c]) + @unittest.skip("failing in old lazy") + def test_cache_binaryop_reshaped(self): + a = Tensor.empty(10) + b = Tensor.empty(10) + c = a + b + d = a.reshape(10, 1) + b.reshape(10, 1) + check_schedule(d, 0, [c]) - def test_cache_binaryop_transpose(self): - a = Tensor.empty(10,10) - b = Tensor.empty(10,10) - c = (a.T*b.T).T #.contiguous() - d = a*b - check_schedule(d, 0, [c]) + def test_cache_binaryop_transpose(self): + a = Tensor.empty(10, 10) + b = Tensor.empty(10, 10) + c = (a.T * b.T).T # .contiguous() + d = a * b + check_schedule(d, 0, [c]) - def test_cache_two_reduceops(self): - a = Tensor.empty(10) - b = a.sum() - c = a.sum() - bc = b+c - check_schedule(bc, 1) + def test_cache_two_reduceops(self): + a = Tensor.empty(10) + b = a.sum() + c = a.sum() + bc = b + c + check_schedule(bc, 1) - def test_fold_double_unary(self): - y = Tensor.empty(2) - out = y.sum(keepdim=True).sqrt().__neg__() - check_schedule(out, 1) + def test_fold_double_unary(self): + y = Tensor.empty(2) + out = y.sum(keepdim=True).sqrt().__neg__() + check_schedule(out, 1) - #@unittest.skip("may want to reconsider this") - def test_fold_batchnorm(self): - with Tensor.train(): - img = Tensor.empty(1,32,4,4) - bn = nn.BatchNorm2d(32, track_running_stats=False) - out = bn(img) - check_schedule(out, 3) + # @unittest.skip("may want to reconsider this") + def test_fold_batchnorm(self): + with Tensor.train(): + img = Tensor.empty(1, 32, 4, 4) + bn = nn.BatchNorm2d(32, track_running_stats=False) + out = bn(img) + check_schedule(out, 3) - def test_fold_conv_relu(self): - c1 = nn.Conv2d(3,16,3) + def test_fold_conv_relu(self): + c1 = nn.Conv2d(3, 16, 3) - # run - img = Tensor.ones(2,3,64,64) - out = c1(img).relu() - check_schedule(out, 1, [c1.weight, c1.bias]) + # run + img = Tensor.ones(2, 3, 64, 64) + out = c1(img).relu() + check_schedule(out, 1, [c1.weight, c1.bias]) - def test_fold_conv_elu(self): - c1 = nn.Conv2d(3,16,3) + def test_fold_conv_elu(self): + c1 = nn.Conv2d(3, 16, 3) - # run - img = Tensor.rand(2,3,64,64) - out = c1(img).elu() - check_schedule(out, 1, [c1.weight, c1.bias]) + # run + img = Tensor.rand(2, 3, 64, 64) + out = c1(img).elu() + check_schedule(out, 1, [c1.weight, c1.bias]) - def test_two_sum(self): - img = Tensor.empty(64,64) - x = (img.sum(0) + img.sum(1)) - out = x.relu() - del x # is 3 without this - check_schedule(out, 2) + def test_two_sum(self): + img = Tensor.empty(64, 64) + x = img.sum(0) + img.sum(1) + out = x.relu() + del x # is 3 without this + check_schedule(out, 2) - @unittest.skip("failing in old lazy") - def test_push_permute_through_reshape(self): - a = Tensor.empty(16,16) - b = Tensor.empty(16,16) - c = (a+b).reshape(4,4,4,4).permute(2,3,0,1).contiguous() - check_schedule(c, 1) + @unittest.skip("failing in old lazy") + def test_push_permute_through_reshape(self): + a = Tensor.empty(16, 16) + b = Tensor.empty(16, 16) + c = (a + b).reshape(4, 4, 4, 4).permute(2, 3, 0, 1).contiguous() + check_schedule(c, 1) - @unittest.skip("failing in old lazy") - def test_push_permute_through_reshape_alt(self): - a = Tensor.empty(4,4,4,4) - b = Tensor.empty(4,4,4,4) - c = (a+b).reshape(16,16).permute(1,0).contiguous() - check_schedule(c, 1) + @unittest.skip("failing in old lazy") + def test_push_permute_through_reshape_alt(self): + a = Tensor.empty(4, 4, 4, 4) + b = Tensor.empty(4, 4, 4, 4) + c = (a + b).reshape(16, 16).permute(1, 0).contiguous() + check_schedule(c, 1) - def test_no_binop_rerun(self): - a = Tensor.empty(16) - b = Tensor.empty(16) - c = a+b - d = (a+b).reshape(16,1) - check_schedule(d, 0, [c]) + def test_no_binop_rerun(self): + a = Tensor.empty(16) + b = Tensor.empty(16) + c = a + b + d = (a + b).reshape(16, 1) + check_schedule(d, 0, [c]) - def test_multi_permute_should_collapse(self): - a = Tensor.empty(4,4,4,4) - b = Tensor.empty(16) - c = a.sum((0,1)).cast(dtypes.float16).permute(1,0).reshape(4,4,1).permute(1,0,2).reshape(16) + b - check_schedule(c, 1) + def test_multi_permute_should_collapse(self): + a = Tensor.empty(4, 4, 4, 4) + b = Tensor.empty(16) + c = ( + a.sum((0, 1)) + .cast(dtypes.float16) + .permute(1, 0) + .reshape(4, 4, 1) + .permute(1, 0, 2) + .reshape(16) + + b + ) + check_schedule(c, 1) - @unittest.skip("failing in old lazy") - def test_fancy_reshape_fusion(self): - a = Tensor.empty(10) - b = Tensor.empty(10) - c = a+b - d = a.reshape(10,1)+b.reshape(10,1) - out = c.sum() + d.sum() - check_schedule(out, 1) + @unittest.skip("failing in old lazy") + def test_fancy_reshape_fusion(self): + a = Tensor.empty(10) + b = Tensor.empty(10) + c = a + b + d = a.reshape(10, 1) + b.reshape(10, 1) + out = c.sum() + d.sum() + check_schedule(out, 1) - # NOTE: for this to pass, LazyViews must be children of LazyBuffers so the (a+b) runs first - @unittest.skip("not real world") - def test_children_dont_push(self): - a = Tensor.empty(10, 10, 1) - b = Tensor.empty(10, 10, 1) - d = (a+b).expand(10, 10, 10) - e = (a+b).permute(2,1,0) - f = d+e - check_schedule(f, 2) + # NOTE: for this to pass, LazyViews must be children of LazyBuffers so the (a+b) runs first + @unittest.skip("not real world") + def test_children_dont_push(self): + a = Tensor.empty(10, 10, 1) + b = Tensor.empty(10, 10, 1) + d = (a + b).expand(10, 10, 10) + e = (a + b).permute(2, 1, 0) + f = d + e + check_schedule(f, 2) - def test_dont_fuse_binops_with_children(self): - a = Tensor.empty(10) - b = Tensor.empty(10) - c = Tensor.empty(10) - keep_me = a+b - e = keep_me.sum() # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse) - d = keep_me+c - check_schedule(d, 2) - check_schedule(keep_me, 0, [d]) + def test_dont_fuse_binops_with_children(self): + a = Tensor.empty(10) + b = Tensor.empty(10) + c = Tensor.empty(10) + keep_me = a + b + e = ( + keep_me.sum() + ) # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse) + d = keep_me + c + check_schedule(d, 2) + check_schedule(keep_me, 0, [d]) - @unittest.skip("failing in old lazy") - def test_permute_breaks_fusion(self): - a = Tensor.empty(10, 10, 10) - b = Tensor.empty(10, 10) - c = (a.sum(axis=2) + b).permute(1,0) - d = c.permute(1,0) - check_schedule(d, 1) + @unittest.skip("failing in old lazy") + def test_permute_breaks_fusion(self): + a = Tensor.empty(10, 10, 10) + b = Tensor.empty(10, 10) + c = (a.sum(axis=2) + b).permute(1, 0) + d = c.permute(1, 0) + check_schedule(d, 1) - def test_some_permute_fusion(self): - a = Tensor.empty(8192, 16) - b = Tensor.empty(1, 16) - d = (a.T + b.expand(8192, 16).T) - c = a + b.expand(8192, 16) - e = d.T - check_schedule(c, 1) - check_schedule(e, 1) + def test_some_permute_fusion(self): + a = Tensor.empty(8192, 16) + b = Tensor.empty(1, 16) + d = a.T + b.expand(8192, 16).T + c = a + b.expand(8192, 16) + e = d.T + check_schedule(c, 1) + check_schedule(e, 1) - # this is the failing case in openpilot...it's very simple like this - @unittest.skip("failing in old lazy") - def test_image_conv_fusion(self): - from tinygrad.features.image import image_conv2d - w1 = Tensor.empty(16, 16, 1, 1) - b1 = Tensor.empty(16) - w2 = Tensor.empty(16, 16, 1, 1) - b2 = Tensor.empty(16) - w3 = Tensor.empty(16, 16, 1, 1) - b3 = Tensor.empty(16) + # this is the failing case in openpilot...it's very simple like this + @unittest.skip("failing in old lazy") + def test_image_conv_fusion(self): + from tinygrad.features.image import image_conv2d - x = Tensor.empty(1, 16, 32, 32) - x = base = image_conv2d(x, w1, b1) - x = image_conv2d(x, w2, b2) + base - x = image_conv2d(x, w3, b3) + w1 = Tensor.empty(16, 16, 1, 1) + b1 = Tensor.empty(16) + w2 = Tensor.empty(16, 16, 1, 1) + b2 = Tensor.empty(16) + w3 = Tensor.empty(16, 16, 1, 1) + b3 = Tensor.empty(16) - # NOOP, 3 convs, contiguous - check_schedule(x, 5) + x = Tensor.empty(1, 16, 32, 32) + x = base = image_conv2d(x, w1, b1) + x = image_conv2d(x, w2, b2) + base + x = image_conv2d(x, w3, b3) - def test_image_conv_fusion_minimal(self): - b1 = Tensor.empty(16) - b2 = Tensor.empty(16) - def p(x): return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0) + # NOOP, 3 convs, contiguous + check_schedule(x, 5) - x = Tensor.empty(16, 32) - x = base = p(x) + b1.reshape(16,1) - x = p(x) - x = x + b2.reshape(16,1) - x = x + base - del base - x = p(x) - check_schedule(x, 4) + def test_image_conv_fusion_minimal(self): + b1 = Tensor.empty(16) + b2 = Tensor.empty(16) - def test_image_conv_fusion_more_minimal(self): - b1 = Tensor.empty(16) - def p(x): return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0) + def p(x): + return ( + x.permute(1, 0) + .contiguous() + .reshape(32, 16, 1) + .expand(32, 16, 16) + .sum(axis=2) + .permute(1, 0) + ) - x = Tensor.empty(16, 32) - x = base = p(x) + b1.reshape(16,1) - x = p(x) - del base - check_schedule(x, 3) + x = Tensor.empty(16, 32) + x = base = p(x) + b1.reshape(16, 1) + x = p(x) + x = x + b2.reshape(16, 1) + x = x + base + del base + x = p(x) + check_schedule(x, 4) - def test_resnet_block(self): - from extra.models.resnet import BasicBlock - Tensor.training = False - bb = BasicBlock(64,64) + def test_image_conv_fusion_more_minimal(self): + b1 = Tensor.empty(16) - x = Tensor.empty(1, 64, 32, 32) - out = bb(x) - check_schedule(out, 4) + def p(x): + return ( + x.permute(1, 0) + .contiguous() + .reshape(32, 16, 1) + .expand(32, 16, 16) + .sum(axis=2) + .permute(1, 0) + ) - def test_contiguous_while_contiguous(self): - x = Tensor.empty(1, 64, 32, 32) - out = x.contiguous() - check_schedule(out, 1, filter_loadops=False) + x = Tensor.empty(16, 32) + x = base = p(x) + b1.reshape(16, 1) + x = p(x) + del base + check_schedule(x, 3) - def test_contiguous_while_not_contiguous(self): - x = Tensor.empty(1, 64, 32, 32) - out = x.permute(0,2,3,1).contiguous() - check_schedule(out, 2, filter_loadops=False) + def test_resnet_block(self): + from extra.models.resnet import BasicBlock - def test_double_from(self): - x = Tensor([1,2,3,4]) - out = x.to('cpu') - check_schedule(out, 0, filter_loadops=False) + Tensor.training = False + bb = BasicBlock(64, 64) - def test_pow_const_tensor(self): - x = Tensor([1,2,3,4]) - out = x ** Tensor(2) - check_schedule(out, 1) + x = Tensor.empty(1, 64, 32, 32) + out = bb(x) + check_schedule(out, 4) - def test_zero_size(self): - x = Tensor.rand(2, 3, 0) - out = x + 1 - check_schedule(out, 0, filter_loadops=False) + def test_contiguous_while_contiguous(self): + x = Tensor.empty(1, 64, 32, 32) + out = x.contiguous() + check_schedule(out, 1, filter_loadops=False) -if __name__ == '__main__': - unittest.main(verbosity=2) + def test_contiguous_while_not_contiguous(self): + x = Tensor.empty(1, 64, 32, 32) + out = x.permute(0, 2, 3, 1).contiguous() + check_schedule(out, 2, filter_loadops=False) + + def test_double_from(self): + x = Tensor([1, 2, 3, 4]) + out = x.to("cpu") + check_schedule(out, 0, filter_loadops=False) + + def test_pow_const_tensor(self): + x = Tensor([1, 2, 3, 4]) + out = x ** Tensor(2) + check_schedule(out, 1) + + def test_zero_size(self): + x = Tensor.rand(2, 3, 0) + out = x + 1 + check_schedule(out, 0, filter_loadops=False) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/test_search.py b/test/test_search.py index c4fefbfe0..e967eac7c 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -6,15 +6,24 @@ from tinygrad.device import Compiled, Device, Buffer from tinygrad.ops import LoadOps from tinygrad.tensor import Tensor + class TestTimeLinearizer(unittest.TestCase): - def setUp(self) -> None: - if not isinstance(Device[Device.DEFAULT], Compiled): raise unittest.SkipTest("only test for compiled backends") + def setUp(self) -> None: + if not isinstance(Device[Device.DEFAULT], Compiled): + raise unittest.SkipTest("only test for compiled backends") - def test_reasonable_time(self): - si = [si for si in Tensor([1,2,3,4]).add(1).lazydata.schedule() if si.ast.op not in LoadOps][0] - rawbufs = [Buffer(Device.DEFAULT, si.out.st.size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.size(), x.dtype) for x in si.inputs] - tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10) - assert tm > 0 and tm != float('inf') + def test_reasonable_time(self): + si = [ + si + for si in Tensor([1, 2, 3, 4]).add(1).lazydata.schedule() + if si.ast.op not in LoadOps + ][0] + rawbufs = [Buffer(Device.DEFAULT, si.out.st.size(), si.out.dtype)] + [ + Buffer(Device.DEFAULT, x.st.size(), x.dtype) for x in si.inputs + ] + tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10) + assert tm > 0 and tm != float("inf") -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_specific_conv.py b/test/test_specific_conv.py index c8bab9716..201641224 100644 --- a/test/test_specific_conv.py +++ b/test/test_specific_conv.py @@ -2,54 +2,71 @@ import unittest from tinygrad.tensor import Tensor from tinygrad.helpers import CI, dtypes from tinygrad import Device + # similar to test/external/external_test_gpu_ast.py, but universal + @unittest.skipIf(Device.DEFAULT == "CUDA" and CI, "slow on CUDA CI") class TestSpecific(unittest.TestCase): - # from openpilot + # from openpilot - # 1x1 6 <- 24 - def test_1x1_6_24(self): - x = Tensor.randn(1, 24*4, 32, 64) - w = Tensor.randn(6*4, 24*4, 1, 1) - x.conv2d(w).permute(0,2,3,1).reshape(32, 384, 4).contiguous().realize() + # 1x1 6 <- 24 + def test_1x1_6_24(self): + x = Tensor.randn(1, 24 * 4, 32, 64) + w = Tensor.randn(6 * 4, 24 * 4, 1, 1) + x.conv2d(w).permute(0, 2, 3, 1).reshape(32, 384, 4).contiguous().realize() - def test_vec_mul(self): - # this forces it to be an image... - x = Tensor.ones(1, 512, 4).contiguous().reshape(1, 2048) - w = Tensor.randn(2048, 512) - (x @ w).reshape(1, 128, 4).contiguous().realize() + def test_vec_mul(self): + # this forces it to be an image... + x = Tensor.ones(1, 512, 4).contiguous().reshape(1, 2048) + w = Tensor.randn(2048, 512) + (x @ w).reshape(1, 128, 4).contiguous().realize() - @unittest.skipIf(Device.DEFAULT in ["LLVM", "WEBGPU", "GPU", "CUDA"], "Broken on LLVM and webgpu, GPU requires cl_khr_fp16") - def test_big_vec_mul(self): - # from LLaMA - # 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)] - # 1 buffer<4096, dtypes.float> [View((1024, 1024, 4, 4), (0, 4, 1, 0), 0, None)] - # 2 buffer<16777216, dtypes.half> [View((1024, 1024, 4, 4), (16384, 4, 1, 4096), 0, None)] - x = Tensor.randn(4096).realize() - w = Tensor.randn(4096, 4096, device='cpu').cast(dtypes.float16).to(Device.DEFAULT).realize() - (x @ w.T).realize() + @unittest.skipIf( + Device.DEFAULT in ["LLVM", "WEBGPU", "GPU", "CUDA"], + "Broken on LLVM and webgpu, GPU requires cl_khr_fp16", + ) + def test_big_vec_mul(self): + # from LLaMA + # 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)] + # 1 buffer<4096, dtypes.float> [View((1024, 1024, 4, 4), (0, 4, 1, 0), 0, None)] + # 2 buffer<16777216, dtypes.half> [View((1024, 1024, 4, 4), (16384, 4, 1, 4096), 0, None)] + x = Tensor.randn(4096).realize() + w = ( + Tensor.randn(4096, 4096, device="cpu") + .cast(dtypes.float16) + .to(Device.DEFAULT) + .realize() + ) + (x @ w.T).realize() - # from https://dl.acm.org/doi/pdf/10.1145/3495243.3517020 + # from https://dl.acm.org/doi/pdf/10.1145/3495243.3517020 - # ~260 GFLOPS on Adreno 640, should be 260*(720/890)*(596/710) = 176.5 on downclocked 630 - # we get 170 - def test_1x1_28_28(self): - x = Tensor.randn(1, 256, 28, 28) - w = Tensor.randn(256, 256, 1, 1) - x.conv2d(w).permute(0,2,3,1).reshape(28, 28*256//4, 4).contiguous().realize() + # ~260 GFLOPS on Adreno 640, should be 260*(720/890)*(596/710) = 176.5 on downclocked 630 + # we get 170 + def test_1x1_28_28(self): + x = Tensor.randn(1, 256, 28, 28) + w = Tensor.randn(256, 256, 1, 1) + x.conv2d(w).permute(0, 2, 3, 1).reshape( + 28, 28 * 256 // 4, 4 + ).contiguous().realize() - # 132 GFLOPS on Adreno 640, should be 132*(720/890)*(596/710) = 90 on downclocked 630 - # gets 54 with broken opt, 74 without opt, and 146 if we pad and opt 3! - def test_3x3_28_28_stride_2(self): - x = Tensor.randn(1, 288, 36, 36) - w = Tensor.randn(384, 288, 3, 3) - x.conv2d(w, stride=2).permute(0,2,3,1).reshape(17, 17*384//4, 4).contiguous().realize() + # 132 GFLOPS on Adreno 640, should be 132*(720/890)*(596/710) = 90 on downclocked 630 + # gets 54 with broken opt, 74 without opt, and 146 if we pad and opt 3! + def test_3x3_28_28_stride_2(self): + x = Tensor.randn(1, 288, 36, 36) + w = Tensor.randn(384, 288, 3, 3) + x.conv2d(w, stride=2).permute(0, 2, 3, 1).reshape( + 17, 17 * 384 // 4, 4 + ).contiguous().realize() - def test_3x3_28_28_stride_2_padded(self): - x = Tensor.randn(1, 288, 36, 36) - w = Tensor.randn(384, 288, 3, 3) - x.conv2d(w, stride=2, padding=1).permute(0,2,3,1).reshape(18, 18*384//4, 4).contiguous().realize() + def test_3x3_28_28_stride_2_padded(self): + x = Tensor.randn(1, 288, 36, 36) + w = Tensor.randn(384, 288, 3, 3) + x.conv2d(w, stride=2, padding=1).permute(0, 2, 3, 1).reshape( + 18, 18 * 384 // 4, 4 + ).contiguous().realize() -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py index fe6333e0d..b2b7bfd56 100644 --- a/test/test_speed_v_torch.py +++ b/test/test_speed_v_torch.py @@ -1,13 +1,16 @@ import os + os.environ["NVIDIA_TF32_OVERRIDE"] = "0" os.environ["MKL_NUM_THREADS"] = "1" os.environ["NUMEXPR_NUM_THREADS"] = "1" os.environ["OMP_NUM_THREADS"] = "1" import unittest import torch + torch.set_num_threads(1) import time import numpy as np + np.set_printoptions(linewidth=160) from tinygrad import Device from tinygrad.helpers import GlobalCounters @@ -17,272 +20,498 @@ from tinygrad.helpers import colored, getenv, CI from tinygrad.jit import TinyJit import pytest -pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu, pytest.mark.exclude_clang] +pytestmark = [ + pytest.mark.exclude_cuda, + pytest.mark.exclude_gpu, + pytest.mark.exclude_clang, +] IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")] torch_dt = torch.float16 if getenv("HALF", 0) else torch.float32 -torch_device = torch.device('mps' if getenv("MPS", 0) else ('cuda' if getenv("TORCHCUDA", 0) else 'cpu')) +torch_device = torch.device( + "mps" if getenv("MPS", 0) else ("cuda" if getenv("TORCHCUDA", 0) else "cpu") +) if str(torch_device) == "mps": - import torch.mps - def sync(): torch.mps.synchronize() + import torch.mps + + def sync(): + torch.mps.synchronize() + elif str(torch_device) == "cuda": - import torch.cuda - def sync(): torch.cuda.synchronize() + import torch.cuda + + def sync(): + torch.cuda.synchronize() + else: - def sync(): pass + + def sync(): + pass + def colorize_float(x): - ret = f"{x:7.2f}x" - if x < 0.75: - return colored(ret, 'green') - elif x > 1.15: - return colored(ret, 'red') - else: - return colored(ret, 'yellow') + ret = f"{x:7.2f}x" + if x < 0.75: + return colored(ret, "green") + elif x > 1.15: + return colored(ret, "red") + else: + return colored(ret, "yellow") + save_ops, save_mem = 0, 0 CNT = getenv("CNT", 8) + + def helper_test_speed(f1, *args): - global save_ops, save_mem - ets = [] - ret = None - cache_defeat = np.zeros((2048,2048)) - for i in range(CNT): - del ret + global save_ops, save_mem + ets = [] + ret = None + cache_defeat = np.zeros((2048, 2048)) + for i in range(CNT): + del ret - # operation cache defeats - args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args] + # operation cache defeats + args = [ + (x + 1).realize() + if isinstance(x, Tensor) + else (None if x is None else (x + 1)) + for x in args + ] - # force syncing - [x.numpy() if isinstance(x, Tensor) or str(torch_device) == "cpu" else x.cpu().numpy() for x in args if x is not None] + # force syncing + [ + x.numpy() + if isinstance(x, Tensor) or str(torch_device) == "cpu" + else x.cpu().numpy() + for x in args + if x is not None + ] - # clear 32MB global memory cache (CPU and global memory only) - cache_defeat += 1 + # clear 32MB global memory cache (CPU and global memory only) + cache_defeat += 1 - # manual pre sync - if isinstance(args[0], Tensor): Device[args[0].device].synchronize() - else: sync() + # manual pre sync + if isinstance(args[0], Tensor): + Device[args[0].device].synchronize() + else: + sync() + + GlobalCounters.global_ops = 0 + GlobalCounters.global_mem = 0 + st = time.perf_counter() + ret = f1(*args) + if isinstance(ret, Tensor): + Device[ret.device].synchronize() + else: + sync() + et = (time.perf_counter() - st) * 1000 + if i >= 1: + ets.append(et) + if GlobalCounters.global_ops: + save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem + return ret.numpy() if isinstance(ret, Tensor) else ret.cpu().numpy(), np.min(ets) - GlobalCounters.global_ops = 0 - GlobalCounters.global_mem = 0 - st = time.perf_counter() - ret = f1(*args) - if isinstance(ret, Tensor): Device[ret.device].synchronize() - else: sync() - et = (time.perf_counter() - st) * 1000 - if i >= 1: ets.append(et) - if GlobalCounters.global_ops: - save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem - return ret.numpy() if isinstance(ret, Tensor) else ret.cpu().numpy(), np.min(ets) def helper_test_generic_square(name, N, f1, f2, onearg=False): - torch.manual_seed(0) - torch_a = (torch.rand(N, N, dtype=torch_dt) - 0.5).to(torch_device) - torch_b = (torch.rand(N, N, dtype=torch_dt) - 0.5).to(torch_device) if not onearg else None + torch.manual_seed(0) + torch_a = (torch.rand(N, N, dtype=torch_dt) - 0.5).to(torch_device) + torch_b = ( + (torch.rand(N, N, dtype=torch_dt) - 0.5).to(torch_device) + if not onearg + else None + ) - tiny_a = Tensor(torch_a.cpu().numpy()) - tiny_b = Tensor(torch_b.cpu().numpy()) if not onearg else None + tiny_a = Tensor(torch_a.cpu().numpy()) + tiny_b = Tensor(torch_b.cpu().numpy()) if not onearg else None + + helper_test_generic( + f"{name:30s} {N:5d}x{N:5d}", + f1, + (torch_a, torch_b), + TinyJit(lambda a, b: f2(a, b).realize()), + (tiny_a, tiny_b), + ) - helper_test_generic(f"{name:30s} {N:5d}x{N:5d}", f1, (torch_a, torch_b), TinyJit(lambda a,b:f2(a,b).realize()), (tiny_a, tiny_b)) def helper_test_matvec(name, N, M): - torch.manual_seed(0) - torch_a = (torch.rand(N, dtype=torch_dt) - 0.5).to(torch_device) - torch_b = (torch.rand(N, M, dtype=torch_dt) - 0.5).to(torch_device) + torch.manual_seed(0) + torch_a = (torch.rand(N, dtype=torch_dt) - 0.5).to(torch_device) + torch_b = (torch.rand(N, M, dtype=torch_dt) - 0.5).to(torch_device) - tiny_a = Tensor(torch_a.cpu().numpy()) - tiny_b = Tensor(torch_b.cpu().numpy()) + tiny_a = Tensor(torch_a.cpu().numpy()) + tiny_b = Tensor(torch_b.cpu().numpy()) + + helper_test_generic( + f"{name:30s} {N:5d}x{M:5d}", + lambda a, b: a @ b, + (torch_a, torch_b), + TinyJit(lambda a, b: (a @ b).realize()), + (tiny_a, tiny_b), + ) - helper_test_generic(f"{name:30s} {N:5d}x{M:5d}", lambda a,b: a@b, (torch_a, torch_b), TinyJit(lambda a,b:(a@b).realize()), (tiny_a, tiny_b)) prefix = None -def helper_test_generic(name, f1, f1_args, f2, f2_args): - global prefix - with torch.no_grad(): - val_torch, et_torch = helper_test_speed(f1, *f1_args) - val_tinygrad, et_tinygrad = helper_test_speed(f2, *f2_args) - desc = "faster" if et_torch > et_tinygrad else "slower" - flops = save_ops*1e-6 - mem = save_mem*1e-6 - print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB") - np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-3, rtol=1e-3) + +def helper_test_generic(name, f1, f1_args, f2, f2_args): + global prefix + with torch.no_grad(): + val_torch, et_torch = helper_test_speed(f1, *f1_args) + val_tinygrad, et_tinygrad = helper_test_speed(f2, *f2_args) + + desc = "faster" if et_torch > et_tinygrad else "slower" + flops = save_ops * 1e-6 + mem = save_mem * 1e-6 + print( + ("\r" if not CI else "") + + f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB" + ) + np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-3, rtol=1e-3) + def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x): - torch.manual_seed(0) - torch_dat = torch.rand(bs, in_chans, img_size_y, img_size_x, dtype=torch_dt).to(torch_device) - torch_conv = torch.nn.Conv2d(in_chans, out_chans, kernel_size, bias=None, dtype=torch_dt).to(torch_device) + torch.manual_seed(0) + torch_dat = torch.rand(bs, in_chans, img_size_y, img_size_x, dtype=torch_dt).to( + torch_device + ) + torch_conv = torch.nn.Conv2d( + in_chans, out_chans, kernel_size, bias=None, dtype=torch_dt + ).to(torch_device) - tiny_dat = Tensor(torch_dat.cpu().numpy()) - tiny_conv = Conv2d(in_chans, out_chans, kernel_size, bias=None) - tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy()) + tiny_dat = Tensor(torch_dat.cpu().numpy()) + tiny_conv = Conv2d(in_chans, out_chans, kernel_size, bias=None) + tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy()) + + def f1(torch_dat): + return torch_conv(torch_dat) + + def f2(tiny_dat): + return tiny_conv(tiny_dat).realize() + + helper_test_generic( + f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:{kernel_size}", + f1, + (torch_dat,), + TinyJit(f2), + (tiny_dat,), + ) - def f1(torch_dat): return torch_conv(torch_dat) - def f2(tiny_dat): return tiny_conv(tiny_dat).realize() - helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:{kernel_size}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,)) @unittest.skipIf(getenv("BIG") == 0, "no big tests") class TestBigSpeed(unittest.TestCase): - def test_add(self): - def f(a, b): return a+b - helper_test_generic_square('add', 8192, f, f) - def test_exp(self): - def f(a, b): return a.exp() - helper_test_generic_square('exp', 8192, f, f, onearg=True) - def test_gemm_2048(self): - def f(a, b): return a @ b - helper_test_generic_square('gemm', 2048, f, f) - def test_gemm_4096(self): - def f(a, b): return a @ b - helper_test_generic_square('gemm', 4096, f, f) - def test_large_conv_1x1(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=1, img_size_y=128, img_size_x=128) - def test_large_conv_3x3(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130) - def test_large_conv_5x5(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=5, img_size_y=132, img_size_x=132) - def test_matvec_4096_16384(self): helper_test_matvec('matvec_4096_16384', 4096, 16384) - def test_matvec_16384_4096(self): helper_test_matvec('matvec_16384_4096', 16384, 4096) + def test_add(self): + def f(a, b): + return a + b + + helper_test_generic_square("add", 8192, f, f) + + def test_exp(self): + def f(a, b): + return a.exp() + + helper_test_generic_square("exp", 8192, f, f, onearg=True) + + def test_gemm_2048(self): + def f(a, b): + return a @ b + + helper_test_generic_square("gemm", 2048, f, f) + + def test_gemm_4096(self): + def f(a, b): + return a @ b + + helper_test_generic_square("gemm", 4096, f, f) + + def test_large_conv_1x1(self): + helper_test_conv( + bs=32, + in_chans=128, + out_chans=128, + kernel_size=1, + img_size_y=128, + img_size_x=128, + ) + + def test_large_conv_3x3(self): + helper_test_conv( + bs=4, + in_chans=128, + out_chans=128, + kernel_size=3, + img_size_y=130, + img_size_x=130, + ) + + def test_large_conv_5x5(self): + helper_test_conv( + bs=4, + in_chans=128, + out_chans=128, + kernel_size=5, + img_size_y=132, + img_size_x=132, + ) + + def test_matvec_4096_16384(self): + helper_test_matvec("matvec_4096_16384", 4096, 16384) + + def test_matvec_16384_4096(self): + helper_test_matvec("matvec_16384_4096", 16384, 4096) + @unittest.skipIf(getenv("BIG") == 1, "only big tests") class TestSpeed(unittest.TestCase): - def test_sub(self): - def f(a, b): return a-b - helper_test_generic_square('sub', 4096, f, f) + def test_sub(self): + def f(a, b): + return a - b - @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "breaking on webgpu CI") - def test_pow(self): - def f(a, b): return a.pow(b) - helper_test_generic_square('pow', 2048, f, f) + helper_test_generic_square("sub", 4096, f, f) - def test_sum(self): - def f(a, b): return a.sum() - helper_test_generic_square('sum', 2048, f, f, onearg=True) - helper_test_generic_square('sum', 4096, f, f, onearg=True) + @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "breaking on webgpu CI") + def test_pow(self): + def f(a, b): + return a.pow(b) - def test_partial_sum(self): - R = 256 - def f(a, b): return a.reshape(int(4096//R), int(4096*R)).sum(axis=1) - helper_test_generic_square('partial_sum', 4096, f, f, onearg=True) + helper_test_generic_square("pow", 2048, f, f) - @unittest.skip("not really used in models") - def test_cumsum(self): - def f0(a, b): return a.cumsum(axis=0) - def f1(a, b): return a.cumsum(axis=1) - helper_test_generic_square('cumsum_0', 256, f0, f0, onearg=True) - helper_test_generic_square('cumsum_1', 256, f1, f1, onearg=True) + def test_sum(self): + def f(a, b): + return a.sum() - def test_cat(self): - helper_test_generic_square('cat_0', 256, lambda x,y: torch.cat((x,y),dim=0), lambda x,y: x.cat(y,dim=0)) - helper_test_generic_square('cat_1', 256, lambda x,y: torch.cat((x,y),dim=1), lambda x,y: x.cat(y,dim=1)) + helper_test_generic_square("sum", 2048, f, f, onearg=True) + helper_test_generic_square("sum", 4096, f, f, onearg=True) - def test_array_packing(self): - N = 2048 - def f(a, b): return a.reshape(N, N // 32, 32).permute(1,0,2).contiguous() - helper_test_generic_square('array_packing', N, f, f, onearg=True) + def test_partial_sum(self): + R = 256 - def test_permute(self): - for N in [1024, 4096]: - # this is a 64MB tensor, M1 L1 cache is 128kB - # to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size - def f(a, b): return a.permute(1,0).contiguous() - helper_test_generic_square('permute', N, f, f, onearg=True) + def f(a, b): + return a.reshape(int(4096 // R), int(4096 * R)).sum(axis=1) - def test_double_permute(self): - N = 64 - torch.manual_seed(0) - torch_a = (torch.rand(N, N, N, N, dtype=torch_dt) - 0.5).to(torch_device) - tiny_a = Tensor(torch_a.cpu().numpy()) - def f(a): return a.permute(1,0,3,2).contiguous() - helper_test_generic(f"double_permute {tiny_a.shape}", f, (torch_a,), TinyJit(lambda a: f(a).realize()), (tiny_a,)) + helper_test_generic_square("partial_sum", 4096, f, f, onearg=True) - def test_neg(self): - def f(a, b): return -a - helper_test_generic_square('neg', 4096, f, f, onearg=True) + @unittest.skip("not really used in models") + def test_cumsum(self): + def f0(a, b): + return a.cumsum(axis=0) - def test_exp(self): - def f(a, b): return a.exp() - helper_test_generic_square('exp', 2048, f, f, onearg=True) + def f1(a, b): + return a.cumsum(axis=1) - def test_relu(self): - def f(a, b): return a.relu() - helper_test_generic_square('relu', 4096, f, f, onearg=True) + helper_test_generic_square("cumsum_0", 256, f0, f0, onearg=True) + helper_test_generic_square("cumsum_1", 256, f1, f1, onearg=True) - def test_max(self): - def f(a, b): return a.max() - helper_test_generic_square('max', 4096, f, f, onearg=True) + def test_cat(self): + helper_test_generic_square( + "cat_0", + 256, + lambda x, y: torch.cat((x, y), dim=0), + lambda x, y: x.cat(y, dim=0), + ) + helper_test_generic_square( + "cat_1", + 256, + lambda x, y: torch.cat((x, y), dim=1), + lambda x, y: x.cat(y, dim=1), + ) - def test_mul_sum(self): - def f(a, b): return (a*b).sum() - helper_test_generic_square('mul_sum', 4096, f, f) + def test_array_packing(self): + N = 2048 - def test_add(self): - for N in [1, 1024, 4096]: - def f(a, b): return a + b - helper_test_generic_square('add', N, f, f) + def f(a, b): + return a.reshape(N, N // 32, 32).permute(1, 0, 2).contiguous() - def test_add_constant(self): - def f(a, b): return a+2.0 - helper_test_generic_square('add_constant', 4096, f, f, onearg=True) + helper_test_generic_square("array_packing", N, f, f, onearg=True) - def test_add_sq(self): - def f(a, b): return a*a + b*b - helper_test_generic_square('add_sq', 4096, f, f) + def test_permute(self): + for N in [1024, 4096]: + # this is a 64MB tensor, M1 L1 cache is 128kB + # to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size + def f(a, b): + return a.permute(1, 0).contiguous() - def test_gemm(self): - def f(a, b): return a @ b - helper_test_generic_square('gemm', 1024, f, f) + helper_test_generic_square("permute", N, f, f, onearg=True) - def test_gemm_small(self): - def f(a, b): return a @ b - helper_test_generic_square('gemm', 256, f, f) + def test_double_permute(self): + N = 64 + torch.manual_seed(0) + torch_a = (torch.rand(N, N, N, N, dtype=torch_dt) - 0.5).to(torch_device) + tiny_a = Tensor(torch_a.cpu().numpy()) - def test_gemm_unrolled(self): - N = 512 - def f1(a, b): return a@b.T - def f2(a, b): return (a.reshape(N, 1, N).expand(N, N, N) * b.reshape(1, N, N).expand(N, N, N)).sum(axis=2) - helper_test_generic_square('gemm_unrolled', N, f1, f2) + def f(a): + return a.permute(1, 0, 3, 2).contiguous() - def test_gemm_unrolled_permute_l(self): - N = 512 - def f1(a, b): return a.T@b.T - def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.reshape(1, N, N).expand(N, N, N)).sum(axis=2) - helper_test_generic_square('gemm_unrolled_permute_l', N, f1, f2) + helper_test_generic( + f"double_permute {tiny_a.shape}", + f, + (torch_a,), + TinyJit(lambda a: f(a).realize()), + (tiny_a,), + ) - def test_gemm_unrolled_permute_r(self): - N = 512 - def f1(a, b): return a@b - def f2(a, b): return (a.reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2) - helper_test_generic_square('gemm_unrolled_permute_r', N, f1, f2) + def test_neg(self): + def f(a, b): + return -a - def test_gemm_unrolled_permute_lr(self): - N = 512 - def f1(a, b): return a.T@b - def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2) - helper_test_generic_square('gemm_unrolled_permute_lr', N, f1, f2) + helper_test_generic_square("neg", 4096, f, f, onearg=True) - def test_matvec_1024_1024(self): helper_test_matvec('matvec_1024_1024', 1024, 1024) - def test_matvec_1024_4096(self): helper_test_matvec('matvec_1024_4096', 1024, 4096) - def test_matvec_4096_1024(self): helper_test_matvec('matvec_4096_1024', 4096, 1024) - def test_matvec_4096_4096(self): helper_test_matvec('matvec_4096_4096', 4096, 4096) + def test_exp(self): + def f(a, b): + return a.exp() - def test_openpilot_conv2d(self): - bs, in_chans, out_chans = 1,12,32 - torch.manual_seed(0) - torch_dat = torch.rand(bs, 64, 128, 12, dtype=torch_dt).to(torch_device) - torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1, dtype=torch_dt).to(torch_device) + helper_test_generic_square("exp", 2048, f, f, onearg=True) - tiny_dat = Tensor(torch_dat.cpu().numpy()) - tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1) - tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy()) + def test_relu(self): + def f(a, b): + return a.relu() - def f1(torch_dat): return torch_conv(torch_dat.permute(0,3,1,2)) - def f2(tiny_dat): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize() - helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:3", f1, (torch_dat,), TinyJit(f2), (tiny_dat,)) + helper_test_generic_square("relu", 4096, f, f, onearg=True) - def test_conv2d(self): - for bs in [32]: - for in_chans in IN_CHANS: - for out_chans in [32]: - helper_test_conv(bs, in_chans, out_chans, 3, 34, 34) + def test_max(self): + def f(a, b): + return a.max() -if __name__ == '__main__': - unittest.main() + helper_test_generic_square("max", 4096, f, f, onearg=True) + + def test_mul_sum(self): + def f(a, b): + return (a * b).sum() + + helper_test_generic_square("mul_sum", 4096, f, f) + + def test_add(self): + for N in [1, 1024, 4096]: + + def f(a, b): + return a + b + + helper_test_generic_square("add", N, f, f) + + def test_add_constant(self): + def f(a, b): + return a + 2.0 + + helper_test_generic_square("add_constant", 4096, f, f, onearg=True) + + def test_add_sq(self): + def f(a, b): + return a * a + b * b + + helper_test_generic_square("add_sq", 4096, f, f) + + def test_gemm(self): + def f(a, b): + return a @ b + + helper_test_generic_square("gemm", 1024, f, f) + + def test_gemm_small(self): + def f(a, b): + return a @ b + + helper_test_generic_square("gemm", 256, f, f) + + def test_gemm_unrolled(self): + N = 512 + + def f1(a, b): + return a @ b.T + + def f2(a, b): + return ( + a.reshape(N, 1, N).expand(N, N, N) * b.reshape(1, N, N).expand(N, N, N) + ).sum(axis=2) + + helper_test_generic_square("gemm_unrolled", N, f1, f2) + + def test_gemm_unrolled_permute_l(self): + N = 512 + + def f1(a, b): + return a.T @ b.T + + def f2(a, b): + return ( + a.permute(1, 0).reshape(N, 1, N).expand(N, N, N) + * b.reshape(1, N, N).expand(N, N, N) + ).sum(axis=2) + + helper_test_generic_square("gemm_unrolled_permute_l", N, f1, f2) + + def test_gemm_unrolled_permute_r(self): + N = 512 + + def f1(a, b): + return a @ b + + def f2(a, b): + return ( + a.reshape(N, 1, N).expand(N, N, N) + * b.permute(1, 0).reshape(1, N, N).expand(N, N, N) + ).sum(axis=2) + + helper_test_generic_square("gemm_unrolled_permute_r", N, f1, f2) + + def test_gemm_unrolled_permute_lr(self): + N = 512 + + def f1(a, b): + return a.T @ b + + def f2(a, b): + return ( + a.permute(1, 0).reshape(N, 1, N).expand(N, N, N) + * b.permute(1, 0).reshape(1, N, N).expand(N, N, N) + ).sum(axis=2) + + helper_test_generic_square("gemm_unrolled_permute_lr", N, f1, f2) + + def test_matvec_1024_1024(self): + helper_test_matvec("matvec_1024_1024", 1024, 1024) + + def test_matvec_1024_4096(self): + helper_test_matvec("matvec_1024_4096", 1024, 4096) + + def test_matvec_4096_1024(self): + helper_test_matvec("matvec_4096_1024", 4096, 1024) + + def test_matvec_4096_4096(self): + helper_test_matvec("matvec_4096_4096", 4096, 4096) + + def test_openpilot_conv2d(self): + bs, in_chans, out_chans = 1, 12, 32 + torch.manual_seed(0) + torch_dat = torch.rand(bs, 64, 128, 12, dtype=torch_dt).to(torch_device) + torch_conv = torch.nn.Conv2d( + in_chans, out_chans, 3, bias=None, padding=1, dtype=torch_dt + ).to(torch_device) + + tiny_dat = Tensor(torch_dat.cpu().numpy()) + tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1) + tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy()) + + def f1(torch_dat): + return torch_conv(torch_dat.permute(0, 3, 1, 2)) + + def f2(tiny_dat): + return tiny_conv(tiny_dat.permute(0, 3, 1, 2)).realize() + + helper_test_generic( + f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:3", + f1, + (torch_dat,), + TinyJit(f2), + (tiny_dat,), + ) + + def test_conv2d(self): + for bs in [32]: + for in_chans in IN_CHANS: + for out_chans in [32]: + helper_test_conv(bs, in_chans, out_chans, 3, 34, 34) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py index 481d72c8a..aee497f0b 100644 --- a/test/test_symbolic_jit.py +++ b/test/test_symbolic_jit.py @@ -7,177 +7,214 @@ from tinygrad.shape.symbolic import Variable from tinygrad.tensor import Tensor import numpy as np + @unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported") class TestSymbolicJit(unittest.TestCase): - def test_plus1(self): - def f(a): return (a+1).realize() - jf = TinyJit(f) - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy() - expected = f(a).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 1) + def test_plus1(self): + def f(a): + return (a + 1).realize() - def test_add(self): - def f(a, b): return (a+b).realize() - jf = TinyJit(f) - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(3, i) - symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 1) + jf = TinyJit(f) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(3, i) + symbolic = jf(a.reshape(3, vi)).reshape(3, i).numpy() + expected = f(a).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 1) - def test_matmul(self): - def f(a, b): return (a@b).realize() - jf = TinyJit(f) - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(i, 5) - symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 1) + def test_add(self): + def f(a, b): + return (a + b).realize() - def test_mixed_with_no_symbol_kernel(self): - def f(a, b): - s = (a@b).realize() - s = (s+s).realize() # this one does not have symbols in input - return s - jf = TinyJit(f) - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(i, 5) - symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 2) + jf = TinyJit(f) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(3, i) + b = Tensor.rand(3, i) + symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 1) - def test_attention(self): - def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize() - jf = TinyJit(f) - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - q = Tensor.rand(2, 1, 4, 8) - k = Tensor.rand(2, i, 4, 8) - v = Tensor.rand(2, i, 4, 8) - symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy() - expected = f(q, k, v).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 6) + def test_matmul(self): + def f(a, b): + return (a @ b).realize() - def test_cat_dim0(self): - def f(a, b): return a.cat(b, dim=0).realize() - jf = TinyJit(f) - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(i, 3) - b = Tensor.rand(2, 3) - symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 1) + jf = TinyJit(f) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(3, i) + b = Tensor.rand(i, 5) + symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 1) - def test_cat_dim1(self): - def f(a, b): return a.cat(b, dim=1).realize() - jf = TinyJit(f) - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(3, 2) - symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 1) + def test_mixed_with_no_symbol_kernel(self): + def f(a, b): + s = (a @ b).realize() + s = (s + s).realize() # this one does not have symbols in input + return s - def test_cat_dim0_two_vars(self): - def f(a, b): return a.cat(b, dim=0).realize() - jf = TinyJit(f) - for i in range(1, 5): - for j in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(i, 3) - b = Tensor.rand(j, 3) - symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 1) + jf = TinyJit(f) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(3, i) + b = Tensor.rand(i, 5) + symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 2) - def test_cat_dim1_two_vars(self): - def f(a, b): return a.cat(b, dim=1).realize() - jf = TinyJit(f) - for i in range(1, 5): - for j in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(3, i) - b = Tensor.rand(3, j) - symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 1) + def test_attention(self): + def f(q, k, v): + return Tensor.scaled_dot_product_attention( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + ).realize() - def test_two_vars_plus1_ij(self): - def f(a, b): return (a@b+1).realize() - jf = TinyJit(f) - for i in range(1, 5): - for j in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(i, 3) - b = Tensor.rand(3, j) - symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 1) + jf = TinyJit(f) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + q = Tensor.rand(2, 1, 4, 8) + k = Tensor.rand(2, i, 4, 8) + v = Tensor.rand(2, i, 4, 8) + symbolic = ( + jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)) + .reshape(2, 4, 1, 8) + .numpy() + ) + expected = f(q, k, v).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 6) - def test_two_vars_plus1_ji(self): - def f(a, b): return (a@b+1).realize() - jf = TinyJit(f) - for i in range(1, 5): - for j in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(j, 3) - b = Tensor.rand(3, i) - symbolic = jf(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 1) + def test_cat_dim0(self): + def f(a, b): + return a.cat(b, dim=0).realize() - def test_jit_symbolic_shape_mismatch(self): - @TinyJit - def add(a, b): return (a+b).realize() - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i).reshape(3, vi) - b = Tensor.rand(3, i).reshape(3, vi) - add(a, b) - vi2 = Variable("i", 1, 10).bind(7) - a = Tensor.rand(3, 7).reshape(3, vi2) - bad = Tensor.rand(4, 7).reshape(4, vi2) - with self.assertRaises(AssertionError): - add(a, bad) + jf = TinyJit(f) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(i, 3) + b = Tensor.rand(2, 3) + symbolic = jf(a.reshape(vi, 3), b).reshape(i + 2, 3).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 1) - def test_shrink(self): - # shrink is a movement, so we pair it with a simple function to test the JIT interaction - def f(a): return (a+1).realize() - jf = TinyJit(f) - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(7, 11) - symbolic = a.shrink(((3,5),(vi,vi+2))) - symbolic = jf(symbolic).numpy() - expected = f(a.shrink(((3,5),(i,i+2)))).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - assert_jit_cache_len(jf, 1) + def test_cat_dim1(self): + def f(a, b): + return a.cat(b, dim=1).realize() -if __name__ == '__main__': - unittest.main() \ No newline at end of file + jf = TinyJit(f) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(3, i) + b = Tensor.rand(3, 2) + symbolic = jf(a.reshape(3, vi), b).reshape(3, i + 2).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 1) + + def test_cat_dim0_two_vars(self): + def f(a, b): + return a.cat(b, dim=0).realize() + + jf = TinyJit(f) + for i in range(1, 5): + for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) + a = Tensor.rand(i, 3) + b = Tensor.rand(j, 3) + symbolic = ( + jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i + j, 3).numpy() + ) + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 1) + + def test_cat_dim1_two_vars(self): + def f(a, b): + return a.cat(b, dim=1).realize() + + jf = TinyJit(f) + for i in range(1, 5): + for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) + a = Tensor.rand(3, i) + b = Tensor.rand(3, j) + symbolic = ( + jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i + j).numpy() + ) + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 1) + + def test_two_vars_plus1_ij(self): + def f(a, b): + return (a @ b + 1).realize() + + jf = TinyJit(f) + for i in range(1, 5): + for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) + a = Tensor.rand(i, 3) + b = Tensor.rand(3, j) + symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 1) + + def test_two_vars_plus1_ji(self): + def f(a, b): + return (a @ b + 1).realize() + + jf = TinyJit(f) + for i in range(1, 5): + for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) + a = Tensor.rand(j, 3) + b = Tensor.rand(3, i) + symbolic = jf(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 1) + + def test_jit_symbolic_shape_mismatch(self): + @TinyJit + def add(a, b): + return (a + b).realize() + + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(3, i).reshape(3, vi) + b = Tensor.rand(3, i).reshape(3, vi) + add(a, b) + vi2 = Variable("i", 1, 10).bind(7) + a = Tensor.rand(3, 7).reshape(3, vi2) + bad = Tensor.rand(4, 7).reshape(4, vi2) + with self.assertRaises(AssertionError): + add(a, bad) + + def test_shrink(self): + # shrink is a movement, so we pair it with a simple function to test the JIT interaction + def f(a): + return (a + 1).realize() + + jf = TinyJit(f) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(7, 11) + symbolic = a.shrink(((3, 5), (vi, vi + 2))) + symbolic = jf(symbolic).numpy() + expected = f(a.shrink(((3, 5), (i, i + 2)))).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + assert_jit_cache_len(jf, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index 367e9daea..798a414b0 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -4,133 +4,168 @@ from tinygrad.helpers import getenv from tinygrad.tensor import Tensor, Device import numpy as np + @unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported") @unittest.skipIf(Device.DEFAULT in ["WEBGPU"], f"{Device.DEFAULT} is not supported") class TestSymbolicOps(unittest.TestCase): - def test_plus1(self): - def f(a): return (a+1).realize() - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - symbolic = f(a.reshape(3, vi)).reshape(3, i).numpy() - expected = f(a).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_plus1(self): + def f(a): + return (a + 1).realize() - def test_add(self): - def f(a, b): return (a+b).realize() - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(3, i) - symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(3, i) + symbolic = f(a.reshape(3, vi)).reshape(3, i).numpy() + expected = f(a).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - def test_matmul(self): - def f(a, b): return (a@b).realize() - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(i, 5) - symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_add(self): + def f(a, b): + return (a + b).realize() - def test_attention(self, dropout_p=0.0): - def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p).realize() - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - q = Tensor.rand(2, 1, 4, 8) - k = Tensor.rand(2, i, 4, 8) - v = Tensor.rand(2, i, 4, 8) - symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy() - expected = f(q, k, v).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(3, i) + b = Tensor.rand(3, i) + symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - def test_attention_training(self): - with Tensor.train(): - self.test_attention(dropout_p=0.0) - with self.assertRaises(AssertionError): - # symbolic shape dropout is not supported - self.test_attention(dropout_p=0.5) + def test_matmul(self): + def f(a, b): + return (a @ b).realize() - def test_cat_dim0(self): - def f(a, b): return a.cat(b, dim=0).realize() - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(i, 3) - b = Tensor.rand(2, 3) - symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(3, i) + b = Tensor.rand(i, 5) + symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - def test_cat_dim1(self): - def f(a, b): return a.cat(b, dim=1).realize() - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(3, i) - b = Tensor.rand(3, 2) - symbolic = f(a.reshape(3, vi), b).reshape(3, i+2).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_attention(self, dropout_p=0.0): + def f(q, k, v): + return Tensor.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + dropout_p=dropout_p, + ).realize() - def test_cat_dim0_two_vars(self): - def f(a, b): return a.cat(b, dim=0).realize() - for i in range(1, 5): - for j in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(i, 3) - b = Tensor.rand(j, 3) - symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + q = Tensor.rand(2, 1, 4, 8) + k = Tensor.rand(2, i, 4, 8) + v = Tensor.rand(2, i, 4, 8) + symbolic = ( + f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)) + .reshape(2, 4, 1, 8) + .numpy() + ) + expected = f(q, k, v).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - def test_cat_dim1_two_vars(self): - def f(a, b): return a.cat(b, dim=1).realize() - for i in range(1, 5): - for j in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(3, i) - b = Tensor.rand(3, j) - symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_attention_training(self): + with Tensor.train(): + self.test_attention(dropout_p=0.0) + with self.assertRaises(AssertionError): + # symbolic shape dropout is not supported + self.test_attention(dropout_p=0.5) - def test_two_vars_plus1_ij(self): - def f(a, b): return (a@b+1).realize() - for i in range(1, 5): - for j in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(i, 3) - b = Tensor.rand(3, j) - symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_cat_dim0(self): + def f(a, b): + return a.cat(b, dim=0).realize() - def test_two_vars_plus1_ji(self): - # reverse the order of variables - def f(a, b): return (a@b+1).realize() - for i in range(1, 5): - for j in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - vj = Variable("j", 1, 10).bind(j) - a = Tensor.rand(j, 3) - b = Tensor.rand(3, i) - symbolic = f(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy() - expected = f(a, b).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(i, 3) + b = Tensor.rand(2, 3) + symbolic = f(a.reshape(vi, 3), b).reshape(i + 2, 3).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) - def test_shrink(self): - for i in range(1, 5): - vi = Variable("i", 1, 10).bind(i) - a = Tensor.rand(7, 11) - symbolic = a.shrink(((3,5),(vi,vi+2))) - symbolic = symbolic.numpy() - expected = a.shrink(((3,5),(i,i+2))).numpy() - np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + def test_cat_dim1(self): + def f(a, b): + return a.cat(b, dim=1).realize() -if __name__ == '__main__': - unittest.main() \ No newline at end of file + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(3, i) + b = Tensor.rand(3, 2) + symbolic = f(a.reshape(3, vi), b).reshape(3, i + 2).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_cat_dim0_two_vars(self): + def f(a, b): + return a.cat(b, dim=0).realize() + + for i in range(1, 5): + for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) + a = Tensor.rand(i, 3) + b = Tensor.rand(j, 3) + symbolic = ( + f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i + j, 3).numpy() + ) + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_cat_dim1_two_vars(self): + def f(a, b): + return a.cat(b, dim=1).realize() + + for i in range(1, 5): + for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) + a = Tensor.rand(3, i) + b = Tensor.rand(3, j) + symbolic = ( + f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i + j).numpy() + ) + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_two_vars_plus1_ij(self): + def f(a, b): + return (a @ b + 1).realize() + + for i in range(1, 5): + for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) + a = Tensor.rand(i, 3) + b = Tensor.rand(3, j) + symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_two_vars_plus1_ji(self): + # reverse the order of variables + def f(a, b): + return (a @ b + 1).realize() + + for i in range(1, 5): + for j in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + vj = Variable("j", 1, 10).bind(j) + a = Tensor.rand(j, 3) + b = Tensor.rand(3, i) + symbolic = f(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy() + expected = f(a, b).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + def test_shrink(self): + for i in range(1, 5): + vi = Variable("i", 1, 10).bind(i) + a = Tensor.rand(7, 11) + symbolic = a.shrink(((3, 5), (vi, vi + 2))) + symbolic = symbolic.numpy() + expected = a.shrink(((3, 5), (i, i + 2))).numpy() + np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py index 678a5c8a9..24765551f 100644 --- a/test/test_symbolic_shapetracker.py +++ b/test/test_symbolic_shapetracker.py @@ -3,170 +3,203 @@ from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad.shape.symbolic import Variable, NumNode from tinygrad.tensor import Tensor + class TestSymbolic(unittest.TestCase): - def test_symbolic_st(self): - x = Variable("x", 1, 100) - st = ShapeTracker.from_shape((x, 3)) - assert st.shape == (x, 3) - assert st.real_strides() == (3, 1) + def test_symbolic_st(self): + x = Variable("x", 1, 100) + st = ShapeTracker.from_shape((x, 3)) + assert st.shape == (x, 3) + assert st.real_strides() == (3, 1) - def test_expr_idxs(self): - x = Variable("x", 1, 100) - st = ShapeTracker.from_shape((x, 3)) - idxs = [Variable("x", 0, 100), Variable("y", 0, 100)] - e1, e2 = st.expr_idxs(idxs) - assert e1.render() == "((x*3)+y)" - assert e2.render() == "1" - st = st.permute((1, 0)) - e1, e2 = st.expr_idxs(idxs) - assert e1.render() == "((y*3)+x)" - assert e2.render() == "1" + def test_expr_idxs(self): + x = Variable("x", 1, 100) + st = ShapeTracker.from_shape((x, 3)) + idxs = [Variable("x", 0, 100), Variable("y", 0, 100)] + e1, e2 = st.expr_idxs(idxs) + assert e1.render() == "((x*3)+y)" + assert e2.render() == "1" + st = st.permute((1, 0)) + e1, e2 = st.expr_idxs(idxs) + assert e1.render() == "((y*3)+x)" + assert e2.render() == "1" - def test_cat_dim0_strides(self): - i = Variable("i", 1, 5).bind(3) - j = Variable("j", 1, 5).bind(3) - k = Variable("k", 1, 5).bind(3) - t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0) - st = t.lazydata.st - assert st.shape == (i+j+k, 4) - assert st.real_strides() == (4, 1) - t = Tensor.rand(3, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0) - st = t.lazydata.st - assert st.shape == (2*i+3, 3) - assert st.real_strides() == (3, 1) + def test_cat_dim0_strides(self): + i = Variable("i", 1, 5).bind(3) + j = Variable("j", 1, 5).bind(3) + k = Variable("k", 1, 5).bind(3) + t = ( + Tensor.rand(3, 4) + .reshape(i, 4) + .cat(Tensor.rand(3, 4).reshape(j, 4), dim=0) + .cat(Tensor.rand(3, 4).reshape(k, 4), dim=0) + ) + st = t.lazydata.st + assert st.shape == (i + j + k, 4) + assert st.real_strides() == (4, 1) + t = ( + Tensor.rand(3, 3) + .reshape(i, 3) + .cat(Tensor.rand(3, 3).reshape(i, 3), dim=0) + .cat(Tensor.rand(3, 3), dim=0) + ) + st = t.lazydata.st + assert st.shape == (2 * i + 3, 3) + assert st.real_strides() == (3, 1) + + def test_cat_dim1_strides(self): + i = Variable("i", 1, 5).bind(4) + j = Variable("j", 1, 5).bind(4) + k = Variable("k", 1, 5).bind(4) + t = ( + Tensor.rand(3, 4) + .reshape(3, i) + .cat(Tensor.rand(3, 4).reshape(3, j), dim=1) + .cat(Tensor.rand(3, 4).reshape(3, k), dim=1) + ) + st = t.lazydata.st + assert st.shape == (3, i + j + k) + assert st.real_strides() == (i + j + k, 1) - def test_cat_dim1_strides(self): - i = Variable("i", 1, 5).bind(4) - j = Variable("j", 1, 5).bind(4) - k = Variable("k", 1, 5).bind(4) - t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1) - st = t.lazydata.st - assert st.shape == (3, i+j+k) - assert st.real_strides() == (i+j+k, 1) class TestSymbolicVarVals(unittest.TestCase): - def test_var_vals_empty(self): - assert ShapeTracker.from_shape((3, 4, 5)).var_vals == {} + def test_var_vals_empty(self): + assert ShapeTracker.from_shape((3, 4, 5)).var_vals == {} - def test_var_vals_shape(self): - x = Variable("x", 1, 100).bind(3) - assert ShapeTracker.from_shape((x, 3)).var_vals == {Variable("x", 1, 100): 3} + def test_var_vals_shape(self): + x = Variable("x", 1, 100).bind(3) + assert ShapeTracker.from_shape((x, 3)).var_vals == {Variable("x", 1, 100): 3} - def test_var_vals_offset(self): - x = Variable("x", 1, 100).bind(3) - st = ShapeTracker.from_shape((4, 3)).shrink(((x, x+1), (0, 3))) - assert st.views[-1].offset == x * 3 - assert st.var_vals == {Variable("x", 1, 100): 3} + def test_var_vals_offset(self): + x = Variable("x", 1, 100).bind(3) + st = ShapeTracker.from_shape((4, 3)).shrink(((x, x + 1), (0, 3))) + assert st.views[-1].offset == x * 3 + assert st.var_vals == {Variable("x", 1, 100): 3} - def test_var_vals_mask(self): - x = Variable("x", 1, 100).bind(3) - view = View.create(shape=(3,4), strides=(4,1), offset=0, mask=((0, x), (0, 4))) - st = ShapeTracker(views=(view,)) - assert st.var_vals == {Variable("x", 1, 100): 3} + def test_var_vals_mask(self): + x = Variable("x", 1, 100).bind(3) + view = View.create( + shape=(3, 4), strides=(4, 1), offset=0, mask=((0, x), (0, 4)) + ) + st = ShapeTracker(views=(view,)) + assert st.var_vals == {Variable("x", 1, 100): 3} - def test_var_vals_complex(self): - x = Variable("x", 1, 100).bind(3) - y = Variable("y", 1, 100).bind(4) - z = Variable("z", 1, 100).bind(5) - st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z+1), (0, 3))) - assert st.views[-1].offset == y * z - assert st.var_vals == {Variable("x", 1, 100): 3, Variable("y", 1, 100):4, Variable("z", 1, 100): 5} + def test_var_vals_complex(self): + x = Variable("x", 1, 100).bind(3) + y = Variable("y", 1, 100).bind(4) + z = Variable("z", 1, 100).bind(5) + st = ShapeTracker.from_shape((x, 5, y)).shrink(((0, x), (z, z + 1), (0, 3))) + assert st.views[-1].offset == y * z + assert st.var_vals == { + Variable("x", 1, 100): 3, + Variable("y", 1, 100): 4, + Variable("z", 1, 100): 5, + } + + def test_shrink_reshape(self): + x = Variable("x", 1, 100).bind(3) + st = ShapeTracker.from_shape((10, 10, 10)).shrink(((x, x + 3), (3, 7), (2, 5))) + st = st.reshape((3 * 4 * 3,)) + assert st.var_vals == {Variable("x", 1, 100): 3} - def test_shrink_reshape(self): - x = Variable("x", 1, 100).bind(3) - st = ShapeTracker.from_shape((10, 10, 10)).shrink(((x, x+3), (3, 7), (2, 5))) - st = st.reshape((3*4*3,)) - assert st.var_vals == {Variable("x", 1, 100): 3} class TestShapeTrackerUnbind(unittest.TestCase): - def test_view_unbind(self): - v = Variable("v", 1, 100) - bv = Variable("v", 1, 100).bind(3) - assert View.create(shape=(bv, 4)).unbind() == View.create(shape=(v, 4)) + def test_view_unbind(self): + v = Variable("v", 1, 100) + bv = Variable("v", 1, 100).bind(3) + assert View.create(shape=(bv, 4)).unbind() == View.create(shape=(v, 4)) - def test_reshape_unbind(self): - v = Variable("v", 1, 100) - bv = Variable("v", 1, 100).bind(3) - t = Tensor.rand(3, 4).reshape(bv, 4) - assert t.lazydata.st.unbind() == ShapeTracker((View.create(shape=(v, 4)),)) + def test_reshape_unbind(self): + v = Variable("v", 1, 100) + bv = Variable("v", 1, 100).bind(3) + t = Tensor.rand(3, 4).reshape(bv, 4) + assert t.lazydata.st.unbind() == ShapeTracker((View.create(shape=(v, 4)),)) + + def test_shrink_unbind(self): + v = Variable("v", 1, 100) + bv = Variable("v", 1, 100).bind(2) + t = Tensor.rand(3, 4).shrink(((bv, bv + 1), (0, 4))) + assert t.lazydata.st.unbind() == ShapeTracker( + (View.create(shape=(1, 4), offset=4 * v),) + ) - def test_shrink_unbind(self): - v = Variable("v", 1, 100) - bv = Variable("v", 1, 100).bind(2) - t = Tensor.rand(3, 4).shrink(((bv, bv+1), (0, 4))) - assert t.lazydata.st.unbind() == ShapeTracker((View.create(shape=(1, 4), offset=4*v),)) class TestSymbolicReshape(unittest.TestCase): - def test_reshape_into_symbols_simple(self): - for i in range(1, 6): - vi = Variable("i", 1, 5).bind(i) - t = Tensor.rand(i, 4).reshape(vi, 4) - assert t.shape == (vi, 4) - t = Tensor.rand(i, 6).reshape(vi, 2, 3) - assert t.shape == (vi, 2, 3) + def test_reshape_into_symbols_simple(self): + for i in range(1, 6): + vi = Variable("i", 1, 5).bind(i) + t = Tensor.rand(i, 4).reshape(vi, 4) + assert t.shape == (vi, 4) + t = Tensor.rand(i, 6).reshape(vi, 2, 3) + assert t.shape == (vi, 2, 3) - def test_reshape_symbols_reshape_ints(self): - for i in range(1, 6): - vi = Variable("i", 1, 5).bind(i) - t = Tensor.rand(i, 4).reshape(vi, 4) - assert t.shape == (vi, 4) - t = t.reshape(i, 4) - assert t.shape == (i, 4) + def test_reshape_symbols_reshape_ints(self): + for i in range(1, 6): + vi = Variable("i", 1, 5).bind(i) + t = Tensor.rand(i, 4).reshape(vi, 4) + assert t.shape == (vi, 4) + t = t.reshape(i, 4) + assert t.shape == (i, 4) - def test_reshape_into_symbols_bad_shape(self): - vi = Variable("i", 1, 10).bind(4) - with self.assertRaises(ValueError): - Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape - with self.assertRaises(AssertionError): - Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node + def test_reshape_into_symbols_bad_shape(self): + vi = Variable("i", 1, 10).bind(4) + with self.assertRaises(ValueError): + Tensor.rand(4, 6).reshape(vi, 6).reshape( + 1, 77 + ) # reshape to a different size new shape through symbolic shape + with self.assertRaises(AssertionError): + Tensor.rand(3, 4).reshape(3, (vi + 1)) # reshape into non-Variable Node + + def test_two_symbol_reshape(self): + for i in range(1, 6): + for j in range(1, 6): + vi = Variable("i", 1, 5).bind(i) + vj = Variable("j", 1, 5).bind(j) + t = Tensor.rand(i, j).reshape(vi, vj) + assert t.shape == (vi, vj) + # NOTE: this is currently not allowed + # t = t.reshape(1, vi*vj) + # assert t.shape == (1, vi*vj) + t = t.reshape(vj, vi) + assert t.shape == (vj, vi) - def test_two_symbol_reshape(self): - for i in range(1, 6): - for j in range(1, 6): - vi = Variable("i", 1, 5).bind(i) - vj = Variable("j", 1, 5).bind(j) - t = Tensor.rand(i, j).reshape(vi, vj) - assert t.shape == (vi, vj) - # NOTE: this is currently not allowed - # t = t.reshape(1, vi*vj) - # assert t.shape == (1, vi*vj) - t = t.reshape(vj, vi) - assert t.shape == (vj, vi) class TestSymbolicExpand(unittest.TestCase): - def test_expand_into_symbols(self): - vi = Variable("i", 1, 5).bind(3) - vj = Variable("j", 1, 5).bind(3) - a = Tensor([[1], [2], [3]]).expand((3, vi)) - assert a.shape == (3, vi) - a = a.reshape(3, vi, 1).expand((3, vi, vj)) - assert a.shape == (3, vi, vj) + def test_expand_into_symbols(self): + vi = Variable("i", 1, 5).bind(3) + vj = Variable("j", 1, 5).bind(3) + a = Tensor([[1], [2], [3]]).expand((3, vi)) + assert a.shape == (3, vi) + a = a.reshape(3, vi, 1).expand((3, vi, vj)) + assert a.shape == (3, vi, vj) + + def test_plus_expands_constant(self): + for i in range(1, 6): + vi = Variable("i", 1, 5).bind(i) + a = Tensor.rand(3, i).reshape(3, vi) + a = a + 1 + assert a.shape == (3, vi) - def test_plus_expands_constant(self): - for i in range(1, 6): - vi = Variable("i", 1, 5).bind(i) - a = Tensor.rand(3, i).reshape(3, vi) - a = a + 1 - assert a.shape == (3, vi) class TestSymbolicShrink(unittest.TestCase): - def test_shrink_symbols(self): - vi = Variable("i", 1, 5) - t = Tensor.rand(3, 5).shrink(((0, 2), (vi, vi+1))) - assert t.shape == (2, 1) + def test_shrink_symbols(self): + vi = Variable("i", 1, 5) + t = Tensor.rand(3, 5).shrink(((0, 2), (vi, vi + 1))) + assert t.shape == (2, 1) + class TestSymbolicShapeExpr(unittest.TestCase): - def test_symbolic_expr_idxs(self): - # taken from symbolic shape llama - i = Variable("i", 1, 120) - gidx0 = Variable("gidx0", 0, i) - lidx1 = Variable("lidx1", 0, 7) - idx = (gidx0, lidx1, NumNode(1)) - shape = (i+1, 8, 4) - strides = (1, (i*4)+4, i+1) - st = ShapeTracker((View.create(shape, strides), )) - idx, _valid = st.expr_idxs(idx) - assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)" + def test_symbolic_expr_idxs(self): + # taken from symbolic shape llama + i = Variable("i", 1, 120) + gidx0 = Variable("gidx0", 0, i) + lidx1 = Variable("lidx1", 0, 7) + idx = (gidx0, lidx1, NumNode(1)) + shape = (i + 1, 8, 4) + strides = (1, (i * 4) + 4, i + 1) + st = ShapeTracker((View.create(shape, strides),)) + idx, _valid = st.expr_idxs(idx) + assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)" -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_tensor.py b/test/test_tensor.py index 6e8b4da61..bdacdd9d5 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -6,388 +6,448 @@ from tinygrad.tensor import Tensor, Device from tinygrad.helpers import dtypes, temp from extra.gradcheck import numerical_jacobian, jacobian, gradcheck -x_init = np.random.randn(1,3).astype(np.float32) -U_init = np.random.randn(3,3).astype(np.float32) -V_init = np.random.randn(3,3).astype(np.float32) -W_init = np.random.randn(3,3).astype(np.float32) -m_init = np.random.randn(1,3).astype(np.float32) +x_init = np.random.randn(1, 3).astype(np.float32) +U_init = np.random.randn(3, 3).astype(np.float32) +V_init = np.random.randn(3, 3).astype(np.float32) +W_init = np.random.randn(3, 3).astype(np.float32) +m_init = np.random.randn(1, 3).astype(np.float32) + class TestTinygrad(unittest.TestCase): - def test_zerodim_initialization(self): - a = Tensor(55) - b = Tensor(3.14) + def test_zerodim_initialization(self): + a = Tensor(55) + b = Tensor(3.14) - self.assertEqual(a.shape, ()) - self.assertEqual(b.shape, ()) + self.assertEqual(a.shape, ()) + self.assertEqual(b.shape, ()) - def test_plus_equals(self): - a = Tensor.randn(10,10) - b = Tensor.randn(10,10) - c = a + b - val1 = c.numpy() - a += b - val2 = a.numpy() - np.testing.assert_allclose(val1, val2) + def test_plus_equals(self): + a = Tensor.randn(10, 10) + b = Tensor.randn(10, 10) + c = a + b + val1 = c.numpy() + a += b + val2 = a.numpy() + np.testing.assert_allclose(val1, val2) - def test_backward_pass(self): - def test_tinygrad(): - x = Tensor(x_init, requires_grad=True) - W = Tensor(W_init, requires_grad=True) - m = Tensor(m_init) - out = x.dot(W).relu() - out = out.log_softmax() - out = out.mul(m).add(m).sum() - out.backward() - return out.numpy(), x.grad.numpy(), W.grad.numpy() + def test_backward_pass(self): + def test_tinygrad(): + x = Tensor(x_init, requires_grad=True) + W = Tensor(W_init, requires_grad=True) + m = Tensor(m_init) + out = x.dot(W).relu() + out = out.log_softmax() + out = out.mul(m).add(m).sum() + out.backward() + return out.numpy(), x.grad.numpy(), W.grad.numpy() - def test_pytorch(): - x = torch.tensor(x_init, requires_grad=True) - W = torch.tensor(W_init, requires_grad=True) - m = torch.tensor(m_init) - out = x.matmul(W).relu() - out = torch.nn.functional.log_softmax(out, dim=1) - out = out.mul(m).add(m).sum() - out.backward() - return out.detach().numpy(), x.grad, W.grad + def test_pytorch(): + x = torch.tensor(x_init, requires_grad=True) + W = torch.tensor(W_init, requires_grad=True) + m = torch.tensor(m_init) + out = x.matmul(W).relu() + out = torch.nn.functional.log_softmax(out, dim=1) + out = out.mul(m).add(m).sum() + out.backward() + return out.detach().numpy(), x.grad, W.grad - for x,y in zip(test_tinygrad(), test_pytorch()): - np.testing.assert_allclose(x, y, atol=1e-5) + for x, y in zip(test_tinygrad(), test_pytorch()): + np.testing.assert_allclose(x, y, atol=1e-5) - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "this test uses more than 8 bufs which breaks webgpu") #TODO: remove after #1461 - def test_backward_pass_diamond_model(self): - def test_tinygrad(): - u = Tensor(U_init, requires_grad=True) - v = Tensor(V_init, requires_grad=True) - w = Tensor(W_init, requires_grad=True) - x = u.mul(v).relu() - y = u.mul(w).relu() - out = x.add(y).mul(y).relu() - out = out.log_softmax() - out = out.sum() - out.backward() - return out.numpy(), u.grad.numpy(), v.grad.numpy(), w.grad.numpy() + @unittest.skipIf( + Device.DEFAULT == "WEBGPU", + "this test uses more than 8 bufs which breaks webgpu", + ) # TODO: remove after #1461 + def test_backward_pass_diamond_model(self): + def test_tinygrad(): + u = Tensor(U_init, requires_grad=True) + v = Tensor(V_init, requires_grad=True) + w = Tensor(W_init, requires_grad=True) + x = u.mul(v).relu() + y = u.mul(w).relu() + out = x.add(y).mul(y).relu() + out = out.log_softmax() + out = out.sum() + out.backward() + return out.numpy(), u.grad.numpy(), v.grad.numpy(), w.grad.numpy() - def test_pytorch(): - u = torch.tensor(U_init, requires_grad=True) - v = torch.tensor(V_init, requires_grad=True) - w = torch.tensor(W_init, requires_grad=True) - x = u.mul(v).relu() - y = u.mul(w).relu() - out = x.add(y).mul(y).relu() - out = torch.nn.functional.log_softmax(out, dim=1) - out = out.sum() - out.backward() - return out.detach().numpy(), u.grad, v.grad, w.grad + def test_pytorch(): + u = torch.tensor(U_init, requires_grad=True) + v = torch.tensor(V_init, requires_grad=True) + w = torch.tensor(W_init, requires_grad=True) + x = u.mul(v).relu() + y = u.mul(w).relu() + out = x.add(y).mul(y).relu() + out = torch.nn.functional.log_softmax(out, dim=1) + out = out.sum() + out.backward() + return out.detach().numpy(), u.grad, v.grad, w.grad - for x,y in zip(test_tinygrad(), test_pytorch()): - np.testing.assert_allclose(x, y, atol=1e-5) + for x, y in zip(test_tinygrad(), test_pytorch()): + np.testing.assert_allclose(x, y, atol=1e-5) - def test_nograd(self): - x = Tensor(x_init, requires_grad=False) - m = Tensor(m_init, requires_grad=False) - W = Tensor(W_init, requires_grad=True) - tmp = x.mul(m) - mm = tmp.matmul(W) - out = mm.relu() - out = out.sum() - out.backward() - assert x.grad is None - assert m.grad is None - assert tmp.grad is None - assert mm.grad is not None - assert W.grad is not None + def test_nograd(self): + x = Tensor(x_init, requires_grad=False) + m = Tensor(m_init, requires_grad=False) + W = Tensor(W_init, requires_grad=True) + tmp = x.mul(m) + mm = tmp.matmul(W) + out = mm.relu() + out = out.sum() + out.backward() + assert x.grad is None + assert m.grad is None + assert tmp.grad is None + assert mm.grad is not None + assert W.grad is not None - def test_dropout(self): - with Tensor.train(): - n, rate = 1_000_000, 0.1 - w = Tensor.ones(n).dropout(rate) - non_zeros = np.count_nonzero(w.numpy()) - expected = n * (1 - rate) - np.testing.assert_allclose(non_zeros, expected, rtol=2e-3) + def test_dropout(self): + with Tensor.train(): + n, rate = 1_000_000, 0.1 + w = Tensor.ones(n).dropout(rate) + non_zeros = np.count_nonzero(w.numpy()) + expected = n * (1 - rate) + np.testing.assert_allclose(non_zeros, expected, rtol=2e-3) - def test_jacobian(self): - W = np.random.RandomState(42069).random((10, 5)).astype(np.float32) - x = np.random.RandomState(69420).random((1, 10)).astype(np.float32) + def test_jacobian(self): + W = np.random.RandomState(42069).random((10, 5)).astype(np.float32) + x = np.random.RandomState(69420).random((1, 10)).astype(np.float32) - torch_x = torch.tensor(x, requires_grad=True) - torch_W = torch.tensor(W, requires_grad=True) - def torch_func(x): return torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1) - PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy() + torch_x = torch.tensor(x, requires_grad=True) + torch_W = torch.tensor(W, requires_grad=True) - tiny_x = Tensor(x, requires_grad=True) - tiny_W = Tensor(W, requires_grad=True) - def tiny_func(x): return x.dot(tiny_W).relu().log_softmax() - J = jacobian(tiny_func, tiny_x) - NJ = numerical_jacobian(tiny_func, tiny_x) + def torch_func(x): + return torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1) - np.testing.assert_allclose(PJ, J, atol = 1e-5) - np.testing.assert_allclose(PJ, NJ, atol = 1e-3) + PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy() - def test_gradcheck(self): - W = np.random.RandomState(1337).random((10, 5)).astype(np.float32) - x = np.random.RandomState(7331).random((1, 10)).astype(np.float32) + tiny_x = Tensor(x, requires_grad=True) + tiny_W = Tensor(W, requires_grad=True) - tiny_x = Tensor(x, requires_grad=True) - tiny_W = Tensor(W, requires_grad=True) - def tiny_func(x): return x.dot(tiny_W).relu().log_softmax() + def tiny_func(x): + return x.dot(tiny_W).relu().log_softmax() - self.assertTrue(gradcheck(tiny_func, tiny_x, eps = 1e-3)) + J = jacobian(tiny_func, tiny_x) + NJ = numerical_jacobian(tiny_func, tiny_x) - # coarse approx. since a "big" eps and the non-linearities of the model - self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 1e-5)) + np.testing.assert_allclose(PJ, J, atol=1e-5) + np.testing.assert_allclose(PJ, NJ, atol=1e-3) - def test_random_fns_are_deterministic_with_seed(self): - for random_fn in [Tensor.randn, Tensor.normal, Tensor.uniform, Tensor.scaled_uniform, Tensor.glorot_uniform, Tensor.kaiming_normal]: - with self.subTest(msg=f"Tensor.{random_fn.__name__}"): - Tensor.manual_seed(1337) - a = random_fn(10,10).realize() - Tensor.manual_seed(1337) - b = random_fn(10,10).realize() - np.testing.assert_allclose(a.numpy(), b.numpy()) + def test_gradcheck(self): + W = np.random.RandomState(1337).random((10, 5)).astype(np.float32) + x = np.random.RandomState(7331).random((1, 10)).astype(np.float32) - def test_randn_isnt_inf_on_zero(self): - # simulate failure case of rand handing a zero to randn - original_rand, Tensor.rand = Tensor.rand, Tensor.zeros - try: self.assertNotIn(np.inf, Tensor.randn(16).numpy()) - except: raise - finally: Tensor.rand = original_rand + tiny_x = Tensor(x, requires_grad=True) + tiny_W = Tensor(W, requires_grad=True) - def test_zeros_like_has_same_dtype(self): - for datatype in [dtypes.float16, dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64, dtypes.uint8]: - a = Tensor([1, 2, 3], dtype=datatype) - b = Tensor.zeros_like(a) - assert a.dtype == b.dtype, f"a.dtype and b.dtype should be {datatype}" - assert a.shape == b.shape, f"shape mismatch (Tensor.zeros_like){a.shape} != (torch){b.shape}" + def tiny_func(x): + return x.dot(tiny_W).relu().log_softmax() - a = Tensor([1, 2, 3]) - b = Tensor.zeros_like(a, dtype=dtypes.int8) - assert a.dtype != b.dtype and a.dtype == dtypes.float32 and b.dtype == dtypes.int8, "a.dtype should be float and b.dtype should be char" - assert a.shape == b.shape, f"shape mismatch (Tensor.zeros_like){a.shape} != (torch){b.shape}" + self.assertTrue(gradcheck(tiny_func, tiny_x, eps=1e-3)) - def test_ones_like_has_same_dtype_and_shape(self): - for datatype in [dtypes.float16, dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64, dtypes.uint8]: - a = Tensor([1, 2, 3], dtype=datatype) - b = Tensor.ones_like(a) - assert a.dtype == b.dtype, f"a.dtype and b.dtype should be {datatype}" - assert a.shape == b.shape, f"shape mismatch (Tensor.ones_like){a.shape} != (torch){b.shape}" + # coarse approx. since a "big" eps and the non-linearities of the model + self.assertFalse(gradcheck(tiny_func, tiny_x, eps=1e-5)) - a = Tensor([1, 2, 3]) - b = Tensor.ones_like(a, dtype=dtypes.int8) - assert a.dtype != b.dtype and a.dtype == dtypes.float32 and b.dtype == dtypes.int8, "a.dtype should be float and b.dtype should be char" - assert a.shape == b.shape, f"shape mismatch (Tensor.ones_like){a.shape} != (torch){b.shape}" + def test_random_fns_are_deterministic_with_seed(self): + for random_fn in [ + Tensor.randn, + Tensor.normal, + Tensor.uniform, + Tensor.scaled_uniform, + Tensor.glorot_uniform, + Tensor.kaiming_normal, + ]: + with self.subTest(msg=f"Tensor.{random_fn.__name__}"): + Tensor.manual_seed(1337) + a = random_fn(10, 10).realize() + Tensor.manual_seed(1337) + b = random_fn(10, 10).realize() + np.testing.assert_allclose(a.numpy(), b.numpy()) - def test_ndim(self): - assert Tensor(1).ndim == 0 - assert Tensor.randn(1).ndim == 1 - assert Tensor.randn(2,2,2).ndim == 3 - assert Tensor.randn(1,1,1,1,1,1).ndim == 6 + def test_randn_isnt_inf_on_zero(self): + # simulate failure case of rand handing a zero to randn + original_rand, Tensor.rand = Tensor.rand, Tensor.zeros + try: + self.assertNotIn(np.inf, Tensor.randn(16).numpy()) + except: + raise + finally: + Tensor.rand = original_rand - def test_argfix(self): - self.assertEqual(Tensor.zeros().shape, ()) - self.assertEqual(Tensor.ones().shape, ()) + def test_zeros_like_has_same_dtype(self): + for datatype in [ + dtypes.float16, + dtypes.float32, + dtypes.int8, + dtypes.int32, + dtypes.int64, + dtypes.uint8, + ]: + a = Tensor([1, 2, 3], dtype=datatype) + b = Tensor.zeros_like(a) + assert a.dtype == b.dtype, f"a.dtype and b.dtype should be {datatype}" + assert ( + a.shape == b.shape + ), f"shape mismatch (Tensor.zeros_like){a.shape} != (torch){b.shape}" - self.assertEqual(Tensor.zeros([]).shape, ()) - self.assertEqual(Tensor.ones([]).shape, ()) + a = Tensor([1, 2, 3]) + b = Tensor.zeros_like(a, dtype=dtypes.int8) + assert ( + a.dtype != b.dtype and a.dtype == dtypes.float32 and b.dtype == dtypes.int8 + ), "a.dtype should be float and b.dtype should be char" + assert ( + a.shape == b.shape + ), f"shape mismatch (Tensor.zeros_like){a.shape} != (torch){b.shape}" - self.assertEqual(Tensor.zeros(tuple()).shape, ()) - self.assertEqual(Tensor.ones(tuple()).shape, ()) + def test_ones_like_has_same_dtype_and_shape(self): + for datatype in [ + dtypes.float16, + dtypes.float32, + dtypes.int8, + dtypes.int32, + dtypes.int64, + dtypes.uint8, + ]: + a = Tensor([1, 2, 3], dtype=datatype) + b = Tensor.ones_like(a) + assert a.dtype == b.dtype, f"a.dtype and b.dtype should be {datatype}" + assert ( + a.shape == b.shape + ), f"shape mismatch (Tensor.ones_like){a.shape} != (torch){b.shape}" - self.assertEqual(Tensor.zeros(1).shape, (1,)) - self.assertEqual(Tensor.ones(1).shape, (1,)) + a = Tensor([1, 2, 3]) + b = Tensor.ones_like(a, dtype=dtypes.int8) + assert ( + a.dtype != b.dtype and a.dtype == dtypes.float32 and b.dtype == dtypes.int8 + ), "a.dtype should be float and b.dtype should be char" + assert ( + a.shape == b.shape + ), f"shape mismatch (Tensor.ones_like){a.shape} != (torch){b.shape}" - self.assertEqual(Tensor.zeros(1,10,20).shape, (1,10,20)) - self.assertEqual(Tensor.ones(1,10,20).shape, (1,10,20)) + def test_ndim(self): + assert Tensor(1).ndim == 0 + assert Tensor.randn(1).ndim == 1 + assert Tensor.randn(2, 2, 2).ndim == 3 + assert Tensor.randn(1, 1, 1, 1, 1, 1).ndim == 6 - self.assertEqual(Tensor.zeros([1]).shape, (1,)) - self.assertEqual(Tensor.ones([1]).shape, (1,)) + def test_argfix(self): + self.assertEqual(Tensor.zeros().shape, ()) + self.assertEqual(Tensor.ones().shape, ()) - self.assertEqual(Tensor.zeros([10,20,40]).shape, (10,20,40)) - self.assertEqual(Tensor.ones([10,20,40]).shape, (10,20,40)) + self.assertEqual(Tensor.zeros([]).shape, ()) + self.assertEqual(Tensor.ones([]).shape, ()) - self.assertEqual(Tensor.rand(1,10,20).shape, (1,10,20)) - self.assertEqual(Tensor.rand((10,20,40)).shape, (10,20,40)) + self.assertEqual(Tensor.zeros(tuple()).shape, ()) + self.assertEqual(Tensor.ones(tuple()).shape, ()) - self.assertEqual(Tensor.empty(1,10,20).shape, (1,10,20)) - self.assertEqual(Tensor.empty((10,20,40)).shape, (10,20,40)) + self.assertEqual(Tensor.zeros(1).shape, (1,)) + self.assertEqual(Tensor.ones(1).shape, (1,)) - def test_numel(self): - assert Tensor.randn(10, 10).numel() == 100 - assert Tensor.randn(1,2,5).numel() == 10 - assert Tensor.randn(1,1,1,1,1,1).numel() == 1 - assert Tensor([]).numel() == 0 - assert Tensor.randn(1,0,2,5).numel() == 0 + self.assertEqual(Tensor.zeros(1, 10, 20).shape, (1, 10, 20)) + self.assertEqual(Tensor.ones(1, 10, 20).shape, (1, 10, 20)) - def test_element_size(self): - for _, dtype in dtypes.fields().items(): - assert dtype.itemsize == Tensor.randn(3, dtype=dtype).element_size(), f"Tensor.element_size() not matching Tensor.dtype.itemsize for {dtype}" + self.assertEqual(Tensor.zeros([1]).shape, (1,)) + self.assertEqual(Tensor.ones([1]).shape, (1,)) - def test_deepwalk_ctx_check(self): - layer = Tensor.uniform(1, 1, requires_grad=True) - x = Tensor.randn(1, 1, 1) - x.dot(layer).mean().backward() - x = Tensor.randn(1, 1, 1) - x.dot(layer).mean().backward() + self.assertEqual(Tensor.zeros([10, 20, 40]).shape, (10, 20, 40)) + self.assertEqual(Tensor.ones([10, 20, 40]).shape, (10, 20, 40)) - def test_zerosized_tensors(self): - np.testing.assert_equal(Tensor([]).numpy(), np.array([])) - np.testing.assert_equal(Tensor(None).numpy(), np.array([])) + self.assertEqual(Tensor.rand(1, 10, 20).shape, (1, 10, 20)) + self.assertEqual(Tensor.rand((10, 20, 40)).shape, (10, 20, 40)) - def test_tensor_ndarray_dtype(self): - arr = np.array([1]) # where dtype is implicitly int64 - assert Tensor(arr).dtype == dtypes.int64 - assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 # check if ndarray correctly casts to Tensor dtype - assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 # check that it works for something else + self.assertEqual(Tensor.empty(1, 10, 20).shape, (1, 10, 20)) + self.assertEqual(Tensor.empty((10, 20, 40)).shape, (10, 20, 40)) - def test_tensor_list_dtype(self): - arr = [1] - assert Tensor(arr).dtype == Tensor.default_type - assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 - assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 + def test_numel(self): + assert Tensor.randn(10, 10).numel() == 100 + assert Tensor.randn(1, 2, 5).numel() == 10 + assert Tensor.randn(1, 1, 1, 1, 1, 1).numel() == 1 + assert Tensor([]).numel() == 0 + assert Tensor.randn(1, 0, 2, 5).numel() == 0 - def test_tensor_copy(self): - x = copy.deepcopy(Tensor.ones((3,3,3))) - np.testing.assert_allclose(x.numpy(), np.ones((3,3,3))) + def test_element_size(self): + for _, dtype in dtypes.fields().items(): + assert ( + dtype.itemsize == Tensor.randn(3, dtype=dtype).element_size() + ), f"Tensor.element_size() not matching Tensor.dtype.itemsize for {dtype}" - def test_copy_from_disk(self): - t = Tensor.randn(30, device="CPU").to(f"disk:{temp('test_copy_from_disk')}") - a = t[10:20] - dev = a.to(Device.DEFAULT) - np.testing.assert_allclose(a.numpy(), dev.numpy()) + def test_deepwalk_ctx_check(self): + layer = Tensor.uniform(1, 1, requires_grad=True) + x = Tensor.randn(1, 1, 1) + x.dot(layer).mean().backward() + x = Tensor.randn(1, 1, 1) + x.dot(layer).mean().backward() + + def test_zerosized_tensors(self): + np.testing.assert_equal(Tensor([]).numpy(), np.array([])) + np.testing.assert_equal(Tensor(None).numpy(), np.array([])) + + def test_tensor_ndarray_dtype(self): + arr = np.array([1]) # where dtype is implicitly int64 + assert Tensor(arr).dtype == dtypes.int64 + assert ( + Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 + ) # check if ndarray correctly casts to Tensor dtype + assert ( + Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 + ) # check that it works for something else + + def test_tensor_list_dtype(self): + arr = [1] + assert Tensor(arr).dtype == Tensor.default_type + assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 + assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 + + def test_tensor_copy(self): + x = copy.deepcopy(Tensor.ones((3, 3, 3))) + np.testing.assert_allclose(x.numpy(), np.ones((3, 3, 3))) + + def test_copy_from_disk(self): + t = Tensor.randn(30, device="CPU").to(f"disk:{temp('test_copy_from_disk')}") + a = t[10:20] + dev = a.to(Device.DEFAULT) + np.testing.assert_allclose(a.numpy(), dev.numpy()) + + # Regression test for https://github.com/tinygrad/tinygrad/issues/1751 + def test_copy_from_numpy_unaligned(self): + # 2**15 is the minimum for repro + arr = np.random.randn(2**15).astype(dtypes.float.np) + fn = temp("test_copy_from_numpy_unaligned") + with open(fn, "wb") as f: + f.write(b"t" + arr.tobytes()) + with open(fn, "a+b") as f: + memview = memoryview(mmap.mmap(f.fileno(), arr.nbytes + 1)) + ua_arr = np.frombuffer(memview[1:], dtype=arr.dtype, count=arr.shape[0]) + np.testing.assert_allclose(arr, ua_arr) + assert not ua_arr.flags.aligned + # force device copy - to() is opt'd away - Tensor(dev)/1 is ignored + np.testing.assert_allclose(ua_arr, (Tensor(ua_arr) / Tensor(1)).numpy()) - # Regression test for https://github.com/tinygrad/tinygrad/issues/1751 - def test_copy_from_numpy_unaligned(self): - # 2**15 is the minimum for repro - arr = np.random.randn(2**15).astype(dtypes.float.np) - fn = temp('test_copy_from_numpy_unaligned') - with open(fn, 'wb') as f: f.write(b't' + arr.tobytes()) - with open(fn, "a+b") as f: memview = memoryview(mmap.mmap(f.fileno(), arr.nbytes + 1)) - ua_arr = np.frombuffer(memview[1:], dtype=arr.dtype, count=arr.shape[0]) - np.testing.assert_allclose(arr, ua_arr) - assert not ua_arr.flags.aligned - # force device copy - to() is opt'd away - Tensor(dev)/1 is ignored - np.testing.assert_allclose(ua_arr, (Tensor(ua_arr)/Tensor(1)).numpy()) class TestZeroShapeTensor(unittest.TestCase): - def test_shape_stride(self): - t = Tensor.rand(3, 2, 0) - assert t.shape == (3, 2, 0) - # numpy has stride 0, 0, 0; torch has stride 2, 1, 1 - assert t.lazydata.st.real_strides() == (0, 0, 1) + def test_shape_stride(self): + t = Tensor.rand(3, 2, 0) + assert t.shape == (3, 2, 0) + # numpy has stride 0, 0, 0; torch has stride 2, 1, 1 + assert t.lazydata.st.real_strides() == (0, 0, 1) - t = Tensor.rand(3, 0, 2) - assert t.shape == (3, 0, 2) - # numpy has stride 0, 0, 0; torch has stride 2, 2, 1 - assert t.lazydata.st.real_strides() == (0, 2, 1) + t = Tensor.rand(3, 0, 2) + assert t.shape == (3, 0, 2) + # numpy has stride 0, 0, 0; torch has stride 2, 2, 1 + assert t.lazydata.st.real_strides() == (0, 2, 1) - t = Tensor.rand(0, 0, 0) - assert t.shape == (0, 0, 0) - # numpy has stride 0, 0, 0; torch has stride 1, 1, 1 - assert t.lazydata.st.real_strides() == (0, 0, 1) + t = Tensor.rand(0, 0, 0) + assert t.shape == (0, 0, 0) + # numpy has stride 0, 0, 0; torch has stride 1, 1, 1 + assert t.lazydata.st.real_strides() == (0, 0, 1) - def test_rand(self): - t = Tensor.rand(3, 2, 0) - assert t.shape == (3, 2, 0) - np.testing.assert_equal(t.numpy(), np.zeros((3, 2, 0))) - t = Tensor.rand(0) - assert t.shape == (0,) - np.testing.assert_equal(t.numpy(), np.zeros((0,))) - t = Tensor.rand(0, 0, 0) - assert t.shape == (0, 0, 0) - np.testing.assert_equal(t.numpy(), np.zeros((0, 0, 0))) + def test_rand(self): + t = Tensor.rand(3, 2, 0) + assert t.shape == (3, 2, 0) + np.testing.assert_equal(t.numpy(), np.zeros((3, 2, 0))) + t = Tensor.rand(0) + assert t.shape == (0,) + np.testing.assert_equal(t.numpy(), np.zeros((0,))) + t = Tensor.rand(0, 0, 0) + assert t.shape == (0, 0, 0) + np.testing.assert_equal(t.numpy(), np.zeros((0, 0, 0))) - def test_full(self): - t = Tensor.zeros(3, 2, 0) - assert t.shape == (3, 2, 0) - np.testing.assert_equal(t.numpy(), np.zeros((3, 2, 0))) - t = Tensor.full((3, 2, 0), 12) - assert t.shape == (3, 2, 0) - np.testing.assert_equal(t.numpy(), np.full((3, 2, 0), 12)) + def test_full(self): + t = Tensor.zeros(3, 2, 0) + assert t.shape == (3, 2, 0) + np.testing.assert_equal(t.numpy(), np.zeros((3, 2, 0))) + t = Tensor.full((3, 2, 0), 12) + assert t.shape == (3, 2, 0) + np.testing.assert_equal(t.numpy(), np.full((3, 2, 0), 12)) - def test_reshape(self): - t = Tensor.zeros(3, 2, 0) - a = t.reshape(7, 0) - assert a.shape == (7, 0) - np.testing.assert_equal(a.numpy(), np.zeros((7, 0))) - with self.assertRaises(AssertionError): - # cannot reshape from size 0 to size 1 - a = t.reshape(()) + def test_reshape(self): + t = Tensor.zeros(3, 2, 0) + a = t.reshape(7, 0) + assert a.shape == (7, 0) + np.testing.assert_equal(a.numpy(), np.zeros((7, 0))) + with self.assertRaises(AssertionError): + # cannot reshape from size 0 to size 1 + a = t.reshape(()) - def test_expand(self): - t = Tensor.full((3, 2, 0), 12).expand((6, 2, 0)) - assert t.shape == (6, 2, 0) - np.testing.assert_equal(t.numpy(), np.full((6, 2, 0), 12)) + def test_expand(self): + t = Tensor.full((3, 2, 0), 12).expand((6, 2, 0)) + assert t.shape == (6, 2, 0) + np.testing.assert_equal(t.numpy(), np.full((6, 2, 0), 12)) - def test_pad(self): - t = Tensor.rand(3, 2, 0).pad((None, None, (1, 1)), 1) - assert t.shape == (3, 2, 2) - np.testing.assert_equal(t.numpy(), np.ones((3, 2, 2))) + def test_pad(self): + t = Tensor.rand(3, 2, 0).pad((None, None, (1, 1)), 1) + assert t.shape == (3, 2, 2) + np.testing.assert_equal(t.numpy(), np.ones((3, 2, 2))) - if Device.DEFAULT != "TORCH": - # torch does not support padding non-zero dim with 0-size. torch.nn.functional.pad(torch.zeros(3,2,0), [0,0,0,4,0,0]) - t = Tensor.rand(3, 2, 0).pad((None, (1, 1), None), 1) - assert t.shape == (3, 4, 0) - np.testing.assert_equal(t.numpy(), np.ones((3, 4, 0))) + if Device.DEFAULT != "TORCH": + # torch does not support padding non-zero dim with 0-size. torch.nn.functional.pad(torch.zeros(3,2,0), [0,0,0,4,0,0]) + t = Tensor.rand(3, 2, 0).pad((None, (1, 1), None), 1) + assert t.shape == (3, 4, 0) + np.testing.assert_equal(t.numpy(), np.ones((3, 4, 0))) - t = Tensor.rand(3, 2, 0).pad(((1, 1), None, None), 1) - assert t.shape == (5, 2, 0) - np.testing.assert_equal(t.numpy(), np.ones((5, 2, 0))) + t = Tensor.rand(3, 2, 0).pad(((1, 1), None, None), 1) + assert t.shape == (5, 2, 0) + np.testing.assert_equal(t.numpy(), np.ones((5, 2, 0))) - def test_shrink_into_zero(self): - t = Tensor.rand(3, 4).realize() - assert t.shrink((None, (2, 2))).realize().shape == (3, 0) - assert t.shrink(((2, 2), None)).realize().shape == (0, 4) - assert t.shrink(((2, 2), (2, 2))).realize().shape == (0, 0) + def test_shrink_into_zero(self): + t = Tensor.rand(3, 4).realize() + assert t.shrink((None, (2, 2))).realize().shape == (3, 0) + assert t.shrink(((2, 2), None)).realize().shape == (0, 4) + assert t.shrink(((2, 2), (2, 2))).realize().shape == (0, 0) - def test_cat(self): - s = Tensor.rand(3, 2, 2) - t = Tensor.rand(3, 2, 0).cat(s, dim=2) - assert t.shape == (3, 2, 2) - np.testing.assert_equal(t.numpy(), s.numpy()) + def test_cat(self): + s = Tensor.rand(3, 2, 2) + t = Tensor.rand(3, 2, 0).cat(s, dim=2) + assert t.shape == (3, 2, 2) + np.testing.assert_equal(t.numpy(), s.numpy()) - if Device.DEFAULT != "TORCH": - # torch does not support padding non-zero dim with 0-size. torch.nn.functional.pad(torch.zeros(3,2,0), [0,0,0,4,0,0]) - s = Tensor.rand(3, 4, 0) - t = Tensor.rand(3, 2, 0).cat(s, dim=1) - assert t.shape == (3, 6, 0) - np.testing.assert_equal(t.numpy(), np.zeros((3, 6, 0))) + if Device.DEFAULT != "TORCH": + # torch does not support padding non-zero dim with 0-size. torch.nn.functional.pad(torch.zeros(3,2,0), [0,0,0,4,0,0]) + s = Tensor.rand(3, 4, 0) + t = Tensor.rand(3, 2, 0).cat(s, dim=1) + assert t.shape == (3, 6, 0) + np.testing.assert_equal(t.numpy(), np.zeros((3, 6, 0))) - def test_elementwise(self): - a = Tensor.rand(3, 2, 0) - a_exp = a.exp() - assert a_exp.shape == (3, 2, 0) - np.testing.assert_equal(a_exp.numpy(), np.exp(a.numpy())) + def test_elementwise(self): + a = Tensor.rand(3, 2, 0) + a_exp = a.exp() + assert a_exp.shape == (3, 2, 0) + np.testing.assert_equal(a_exp.numpy(), np.exp(a.numpy())) - b = Tensor.rand(3, 2, 0) - assert b.shape == (3, 2, 0) - ab = a * b - assert ab.shape == (3, 2, 0) - np.testing.assert_equal(ab.numpy(), a.numpy() * b.numpy()) + b = Tensor.rand(3, 2, 0) + assert b.shape == (3, 2, 0) + ab = a * b + assert ab.shape == (3, 2, 0) + np.testing.assert_equal(ab.numpy(), a.numpy() * b.numpy()) - mask = (Tensor.rand(3, 2, 0) > 0.5) - assert mask.shape == (3, 2, 0) - c = mask.where(a, b) - assert c.shape == (3, 2, 0) - np.testing.assert_equal(c.numpy(), np.where(mask.numpy(), a.numpy(), b.numpy())) + mask = Tensor.rand(3, 2, 0) > 0.5 + assert mask.shape == (3, 2, 0) + c = mask.where(a, b) + assert c.shape == (3, 2, 0) + np.testing.assert_equal(c.numpy(), np.where(mask.numpy(), a.numpy(), b.numpy())) - def test_reduce_over_non_zero(self): - a = Tensor.ones(3, 2, 0).sum(axis=1) - assert a.shape == (3, 0) - np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=1)) + def test_reduce_over_non_zero(self): + a = Tensor.ones(3, 2, 0).sum(axis=1) + assert a.shape == (3, 0) + np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=1)) - def test_reduce_over_zero(self): - a = Tensor.ones(3, 2, 0).sum(axis=2) - assert a.shape == (3, 2) - np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=2)) + def test_reduce_over_zero(self): + a = Tensor.ones(3, 2, 0).sum(axis=2) + assert a.shape == (3, 2) + np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=2)) - a = Tensor.ones(3, 2, 0).sum(axis=2, keepdim=True) - assert a.shape == (3, 2, 1) - np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=2, keepdims=True)) + a = Tensor.ones(3, 2, 0).sum(axis=2, keepdim=True) + assert a.shape == (3, 2, 1) + np.testing.assert_equal( + a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=2, keepdims=True) + ) - def test_reduce_default(self): - np.testing.assert_equal(Tensor([]).max().numpy(), -float("inf")) - np.testing.assert_equal(Tensor([]).min().numpy(), float("inf")) - np.testing.assert_equal(Tensor([]).sum().numpy(), 0) - np.testing.assert_equal(Tensor([]).mean().numpy(), 0) + def test_reduce_default(self): + np.testing.assert_equal(Tensor([]).max().numpy(), -float("inf")) + np.testing.assert_equal(Tensor([]).min().numpy(), float("inf")) + np.testing.assert_equal(Tensor([]).sum().numpy(), 0) + np.testing.assert_equal(Tensor([]).mean().numpy(), 0) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_to_numpy.py b/test/test_to_numpy.py index deaa9fc9f..193780e7a 100644 --- a/test/test_to_numpy.py +++ b/test/test_to_numpy.py @@ -3,15 +3,17 @@ import numpy as np import pickle import unittest -class TestToNumpy(unittest.TestCase): - def test_numpy_is_numpy(self): - output = Tensor.ones((1, 3, 4096)).realize().numpy() - new = np.copy(output) - print(type(new)) - serialized = pickle.dumps(new) - out = pickle.loads(serialized) - assert out.shape == (1,3,4096) - assert (out==1).all() -if __name__ == '__main__': - unittest.main() \ No newline at end of file +class TestToNumpy(unittest.TestCase): + def test_numpy_is_numpy(self): + output = Tensor.ones((1, 3, 4096)).realize().numpy() + new = np.copy(output) + print(type(new)) + serialized = pickle.dumps(new) + out = pickle.loads(serialized) + assert out.shape == (1, 3, 4096) + assert (out == 1).all() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_uops.py b/test/test_uops.py index 4d055106e..e2da1d514 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -7,94 +7,198 @@ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps from tinygrad.device import CompiledASTRunner, Compiled from tinygrad.codegen.linearizer import UOps, UOp -def _uops_to_prg(uops): - src, runtime_args = Device[Device.DEFAULT].renderer("test", uops) - return CompiledASTRunner(None, "test", src, - [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, - runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime) -def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp: - uops.append(UOp(uop, dtype if arg != BinaryOps.CMPLT else dtypes.bool, tuple(vin), arg)) - return uops[-1] +def _uops_to_prg(uops): + src, runtime_args = Device[Device.DEFAULT].renderer("test", uops) + return CompiledASTRunner( + None, + "test", + src, + [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, + [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, + runtime_args=runtime_args, + ).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime) + + +def uop( + uops: List[UOp], + uop: UOps, + dtype: Optional[DType], + vin: Tuple[UOp, ...], + arg: Any = None, +) -> UOp: + uops.append( + UOp(uop, dtype if arg != BinaryOps.CMPLT else dtypes.bool, tuple(vin), arg) + ) + return uops[-1] + def _test_single_value(vals, op, dtype): - uops = [] - buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), ('data0', dtype)) - buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (f'data{i+1}', dtype)) for i in range(len(vals))] - loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i in range(len(vals))) - alu = uop(uops, UOps.ALU, dtype, loads, op) - uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) - buf = Buffer(Device.DEFAULT, 1, dtype) - buf2 = [Buffer.fromCPU(Device.DEFAULT, np.array([a], dtype=dtype.np)) for a in vals] - prg = _uops_to_prg(uops) - prg.exec([buf]+buf2) - return buf.toCPU()[0] + uops = [] + buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), ("data0", dtype)) + buf_loads = [ + uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (f"data{i+1}", dtype)) + for i in range(len(vals)) + ] + loads = ( + uop( + uops, + UOps.LOAD, + dtype, + [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)], + ) + for i in range(len(vals)) + ) + alu = uop(uops, UOps.ALU, dtype, loads, op) + uop( + uops, + UOps.STORE, + None, + (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu), + ) + buf = Buffer(Device.DEFAULT, 1, dtype) + buf2 = [Buffer.fromCPU(Device.DEFAULT, np.array([a], dtype=dtype.np)) for a in vals] + prg = _uops_to_prg(uops) + prg.exec([buf] + buf2) + return buf.toCPU()[0] + def _test_single_value_const(vals, op, dtype): - uops = [] - buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), ('data0', dtype)) - loads = (uop(uops, UOps.CONST, dtype, [], a) for a in vals) - alu = uop(uops, UOps.ALU, dtype, loads, op) - uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) - buf = Buffer(Device.DEFAULT, 1, dtype) - prg = _uops_to_prg(uops) - prg.exec([buf]) - return buf.toCPU()[0] + uops = [] + buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), ("data0", dtype)) + loads = (uop(uops, UOps.CONST, dtype, [], a) for a in vals) + alu = uop(uops, UOps.ALU, dtype, loads, op) + uop( + uops, + UOps.STORE, + None, + (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu), + ) + buf = Buffer(Device.DEFAULT, 1, dtype) + prg = _uops_to_prg(uops) + prg.exec([buf]) + return buf.toCPU()[0] + class TestUOps(unittest.TestCase): - def _equal(self, v1, v2): - if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5) if v1.dtype != np.bool_ else self.assertEqual(v1, v2) + def _equal(self, v1, v2): + if not (math.isnan(v1) and math.isnan(v2)): + self.assertAlmostEqual( + v1, v2, places=5 + ) if v1.dtype != np.bool_ else self.assertEqual(v1, v2) - def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32): - for f in [_test_single_value, _test_single_value_const]: - for a in [-2.0, 0.0, 1.0]: - self._equal(f([a], bop, dt), fxn(a)) + def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32): + for f in [_test_single_value, _test_single_value_const]: + for a in [-2.0, 0.0, 1.0]: + self._equal(f([a], bop, dt), fxn(a)) - def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32, no_b_zero=False): - for f in [_test_single_value, _test_single_value_const]: - for a in [-2.0, 0.0, 1.0]: - for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]): - self._equal(f([a,b], bop, dt), fxn(a,b)) + def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32, no_b_zero=False): + for f in [_test_single_value, _test_single_value_const]: + for a in [-2.0, 0.0, 1.0]: + for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]): + self._equal(f([a, b], bop, dt), fxn(a, b)) - def _test_top_fxn(self, bop, fxn, dt=dtypes.float32): - for f in [_test_single_value, _test_single_value_const]: - for a in [-2.0, 0, 1]: - for b in [-3.0, 3.0]: - for c in [-4.0, 4.0]: - self._equal(f([a,b,c], bop, dt), fxn(a,b,c)) + def _test_top_fxn(self, bop, fxn, dt=dtypes.float32): + for f in [_test_single_value, _test_single_value_const]: + for a in [-2.0, 0, 1]: + for b in [-3.0, 3.0]: + for c in [-4.0, 4.0]: + self._equal(f([a, b, c], bop, dt), fxn(a, b, c)) -@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends") + +@unittest.skipIf( + not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends" +) class TestFloatUOps(TestUOps): - def test_neg(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a) - def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a)) - def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan')) - def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a)) - def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan')) - # this is not on most backends - #def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a if a != 0 else float('inf')) + def test_neg(self): + self._test_uop_fxn(UnaryOps.NEG, lambda a: -a) - def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b) - def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b) - def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b) - def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b if b != 0 else a*float('inf')) - def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b)) - def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a 0 else float("-inf" if a == 0 else "nan"), + ) + + def test_sin(self): + self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a)) + + def test_sqrt(self): + self._test_uop_fxn( + UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float("nan") + ) + + # this is not on most backends + # def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a if a != 0 else float('inf')) + + def test_add(self): + self._test_bop_fxn(BinaryOps.ADD, lambda a, b: a + b) + + def test_sub(self): + self._test_bop_fxn(BinaryOps.SUB, lambda a, b: a - b) + + def test_mul(self): + self._test_bop_fxn(BinaryOps.MUL, lambda a, b: a * b) + + def test_div(self): + self._test_bop_fxn( + BinaryOps.DIV, lambda a, b: a / b if b != 0 else a * float("inf") + ) + + def test_max(self): + self._test_bop_fxn(BinaryOps.MAX, lambda a, b: max(a, b)) + + def test_cmplt(self): + self._test_bop_fxn(BinaryOps.CMPLT, lambda a, b: a < b) + + # MOD isn't tested on floats + + def test_mulacc(self): + self._test_top_fxn(TernaryOps.MULACC, lambda a, b, c: (a * b) + c) + + def test_where(self): + self._test_top_fxn(TernaryOps.WHERE, lambda a, b, c: b if a != 0 else c) - def test_mulacc(self): self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: (a*b)+c) - def test_where(self): self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c) # TODO: fix this on all the backends -@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some") +@unittest.skipIf( + not isinstance(Device[Device.DEFAULT], Compiled) or getenv("ARM64", False), + "only test for compiled backends, broken on some", +) class TestNonFloatUOps(TestUOps): - def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, dtypes.int32) - def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), dtypes.int32) - def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), dtypes.int32) - def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), dtypes.int32) - def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), dtypes.int32, no_b_zero=True) - def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], dtypes.int32, no_b_zero=True) - def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a /proc/sys/vm/drop_caches' && python3 test/unit/test_disk_tensor.py TestRawDiskBuffer.test_readinto_read_speed @unittest.skipIf(not test_fn.exists(), "download LLaMA weights for read in speed tests") class TestRawDiskBuffer(unittest.TestCase): - def test_readinto_read_speed(self): - tst = np.empty(test_size, np.uint8) - with open(test_fn, "rb") as f: - with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"): - f.readinto(tst) + def test_readinto_read_speed(self): + tst = np.empty(test_size, np.uint8) + with open(test_fn, "rb") as f: + with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"): + f.readinto(tst) + @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype") class TestSafetensors(unittest.TestCase): - def test_real_safetensors(self): - import torch - from safetensors.torch import save_file - torch.manual_seed(1337) - tensors = { - "weight1": torch.randn((16, 16)), - "weight2": torch.arange(0, 17, dtype=torch.uint8), - "weight3": torch.arange(0, 17, dtype=torch.int32).reshape(17,1,1), - "weight4": torch.arange(0, 2, dtype=torch.uint8), - } - save_file(tensors, temp("model.safetensors")) + def test_real_safetensors(self): + import torch + from safetensors.torch import save_file - ret = safe_load(temp("model.safetensors")) - for k,v in tensors.items(): np.testing.assert_array_equal(ret[k].numpy(), v.numpy()) - safe_save(ret, temp("model.safetensors_alt")) - with open(temp("model.safetensors"), "rb") as f: - with open(temp("model.safetensors_alt"), "rb") as g: - assert f.read() == g.read() - ret2 = safe_load(temp("model.safetensors_alt")) - for k,v in tensors.items(): np.testing.assert_array_equal(ret2[k].numpy(), v.numpy()) + torch.manual_seed(1337) + tensors = { + "weight1": torch.randn((16, 16)), + "weight2": torch.arange(0, 17, dtype=torch.uint8), + "weight3": torch.arange(0, 17, dtype=torch.int32).reshape(17, 1, 1), + "weight4": torch.arange(0, 2, dtype=torch.uint8), + } + save_file(tensors, temp("model.safetensors")) - def test_efficientnet_safetensors(self): - from extra.models.efficientnet import EfficientNet - model = EfficientNet(0) - state_dict = get_state_dict(model) - safe_save(state_dict, temp("eff0")) - state_dict_loaded = safe_load(temp("eff0")) - assert sorted(list(state_dict_loaded.keys())) == sorted(list(state_dict.keys())) - for k,v in state_dict.items(): - np.testing.assert_array_equal(v.numpy(), state_dict_loaded[k].numpy()) + ret = safe_load(temp("model.safetensors")) + for k, v in tensors.items(): + np.testing.assert_array_equal(ret[k].numpy(), v.numpy()) + safe_save(ret, temp("model.safetensors_alt")) + with open(temp("model.safetensors"), "rb") as f: + with open(temp("model.safetensors_alt"), "rb") as g: + assert f.read() == g.read() + ret2 = safe_load(temp("model.safetensors_alt")) + for k, v in tensors.items(): + np.testing.assert_array_equal(ret2[k].numpy(), v.numpy()) - # load with the real safetensors - from safetensors import safe_open - with safe_open(temp("eff0"), framework="pt", device="cpu") as f: - assert sorted(list(f.keys())) == sorted(list(state_dict.keys())) - for k in f.keys(): - np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy()) + def test_efficientnet_safetensors(self): + from extra.models.efficientnet import EfficientNet - def test_huggingface_enet_safetensors(self): - # test a real file - fn = fetch("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors") - state_dict = safe_load(fn) - assert len(state_dict.keys()) == 244 - assert 'blocks.2.2.se.conv_reduce.weight' in state_dict - assert state_dict['blocks.0.0.bn1.num_batches_tracked'].numpy() == 276570 - assert state_dict['blocks.2.0.bn2.num_batches_tracked'].numpy() == 276570 + model = EfficientNet(0) + state_dict = get_state_dict(model) + safe_save(state_dict, temp("eff0")) + state_dict_loaded = safe_load(temp("eff0")) + assert sorted(list(state_dict_loaded.keys())) == sorted(list(state_dict.keys())) + for k, v in state_dict.items(): + np.testing.assert_array_equal(v.numpy(), state_dict_loaded[k].numpy()) + + # load with the real safetensors + from safetensors import safe_open + + with safe_open(temp("eff0"), framework="pt", device="cpu") as f: + assert sorted(list(f.keys())) == sorted(list(state_dict.keys())) + for k in f.keys(): + np.testing.assert_array_equal( + f.get_tensor(k).numpy(), state_dict[k].numpy() + ) + + def test_huggingface_enet_safetensors(self): + # test a real file + fn = fetch( + "https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors" + ) + state_dict = safe_load(fn) + assert len(state_dict.keys()) == 244 + assert "blocks.2.2.se.conv_reduce.weight" in state_dict + assert state_dict["blocks.0.0.bn1.num_batches_tracked"].numpy() == 276570 + assert state_dict["blocks.2.0.bn2.num_batches_tracked"].numpy() == 276570 + + def test_metadata(self): + metadata = {"hello": "world"} + safe_save({}, temp("metadata.safetensors"), metadata) + import struct + + with open(temp("metadata.safetensors"), "rb") as f: + dat = f.read() + sz = struct.unpack(">Q", dat[0:8])[0] + import json + + assert json.loads(dat[8 : 8 + sz])["__metadata__"]["hello"] == "world" - def test_metadata(self): - metadata = {"hello": "world"} - safe_save({}, temp('metadata.safetensors'), metadata) - import struct - with open(temp('metadata.safetensors'), 'rb') as f: - dat = f.read() - sz = struct.unpack(">Q", dat[0:8])[0] - import json - assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world' def helper_test_disk_tensor(fn, data, np_fxn, tinygrad_fxn=None): - if tinygrad_fxn is None: tinygrad_fxn = np_fxn - pathlib.Path(temp(fn)).unlink(missing_ok=True) - tinygrad_tensor = Tensor(data, device="CPU").to(f"disk:{temp(fn)}") - numpy_arr = np.array(data) - tinygrad_fxn(tinygrad_tensor) - np_fxn(numpy_arr) - np.testing.assert_allclose(tinygrad_tensor.numpy(), numpy_arr) + if tinygrad_fxn is None: + tinygrad_fxn = np_fxn + pathlib.Path(temp(fn)).unlink(missing_ok=True) + tinygrad_tensor = Tensor(data, device="CPU").to(f"disk:{temp(fn)}") + numpy_arr = np.array(data) + tinygrad_fxn(tinygrad_tensor) + np_fxn(numpy_arr) + np.testing.assert_allclose(tinygrad_tensor.numpy(), numpy_arr) + class TestDiskTensor(unittest.TestCase): - def test_empty(self): - pathlib.Path(temp("dt1")).unlink(missing_ok=True) - Tensor.empty(100, 100, device=f"disk:{temp('dt1')}") + def test_empty(self): + pathlib.Path(temp("dt1")).unlink(missing_ok=True) + Tensor.empty(100, 100, device=f"disk:{temp('dt1')}") - def test_write_ones(self): - pathlib.Path(temp("dt2")).unlink(missing_ok=True) + def test_write_ones(self): + pathlib.Path(temp("dt2")).unlink(missing_ok=True) - out = Tensor.ones(10, 10, device="CPU") - outdisk = out.to(f"disk:{temp('dt2')}") - print(outdisk) - outdisk.realize() - del out, outdisk + out = Tensor.ones(10, 10, device="CPU") + outdisk = out.to(f"disk:{temp('dt2')}") + print(outdisk) + outdisk.realize() + del out, outdisk - # test file - with open(temp("dt2"), "rb") as f: - assert f.read() == b"\x00\x00\x80\x3F" * 100 + # test file + with open(temp("dt2"), "rb") as f: + assert f.read() == b"\x00\x00\x80\x3F" * 100 - # test load alt - reloaded = Tensor.empty(10, 10, device=f"disk:{temp('dt2')}") - out = reloaded.numpy() - assert np.all(out == 1.) + # test load alt + reloaded = Tensor.empty(10, 10, device=f"disk:{temp('dt2')}") + out = reloaded.numpy() + assert np.all(out == 1.0) - def test_assign_slice(self): - def assign(x,s,y): x[s] = y - helper_test_disk_tensor("dt3", [0,1,2,3], lambda x: assign(x, slice(0,2), [13, 12])) - helper_test_disk_tensor("dt4", [[0,1,2,3],[4,5,6,7]], lambda x: assign(x, slice(0,1), [[13, 12, 11, 10]])) + def test_assign_slice(self): + def assign(x, s, y): + x[s] = y + + helper_test_disk_tensor( + "dt3", [0, 1, 2, 3], lambda x: assign(x, slice(0, 2), [13, 12]) + ) + helper_test_disk_tensor( + "dt4", + [[0, 1, 2, 3], [4, 5, 6, 7]], + lambda x: assign(x, slice(0, 1), [[13, 12, 11, 10]]), + ) + + def test_reshape(self): + helper_test_disk_tensor("dt5", [1, 2, 3, 4, 5], lambda x: x.reshape((1, 5))) + helper_test_disk_tensor("dt6", [1, 2, 3, 4], lambda x: x.reshape((2, 2))) - def test_reshape(self): - helper_test_disk_tensor("dt5", [1,2,3,4,5], lambda x: x.reshape((1,5))) - helper_test_disk_tensor("dt6", [1,2,3,4], lambda x: x.reshape((2,2))) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/test/unit/test_flopcounter.py b/test/unit/test_flopcounter.py index ac5bdab95..b854ffe0d 100644 --- a/test/unit/test_flopcounter.py +++ b/test/unit/test_flopcounter.py @@ -1,44 +1,131 @@ #!/usr/bin/env python import unittest -from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer +from tinygrad.ops import ( + LazyOp, + BinaryOps, + ReduceOps, + get_lazyop_info, + BufferOps, + MemBuffer, +) from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.helpers import dtypes + class TestFlopCounter(unittest.TestCase): - def setUp(self): - self.buf0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,)))) - self.buf1 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,)))) + def setUp(self): + self.buf0 = LazyOp( + BufferOps.LOAD, + (), + MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))), + ) + self.buf1 = LazyOp( + BufferOps.LOAD, + (), + MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))), + ) - def test_flops_add(self): - op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None) - info = get_lazyop_info(op0) - self.assertEqual(info.flops, 4) + def test_flops_add(self): + op0 = LazyOp( + BinaryOps.ADD, + ( + self.buf0, + self.buf1, + ), + None, + ) + info = get_lazyop_info(op0) + self.assertEqual(info.flops, 4) - def test_flops_add_twice(self): - op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None) - op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None) - info = get_lazyop_info(op1) - self.assertEqual(info.flops, 8) + def test_flops_add_twice(self): + op0 = LazyOp( + BinaryOps.ADD, + ( + self.buf0, + self.buf1, + ), + None, + ) + op1 = LazyOp( + BinaryOps.ADD, + ( + op0, + self.buf1, + ), + None, + ) + info = get_lazyop_info(op1) + self.assertEqual(info.flops, 8) - def test_flops_add_self(self): - op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None) - op1 = LazyOp(BinaryOps.ADD, (op0,op0,), None) - info = get_lazyop_info(op1) - self.assertEqual(info.flops, 8) + def test_flops_add_self(self): + op0 = LazyOp( + BinaryOps.ADD, + ( + self.buf0, + self.buf1, + ), + None, + ) + op1 = LazyOp( + BinaryOps.ADD, + ( + op0, + op0, + ), + None, + ) + info = get_lazyop_info(op1) + self.assertEqual(info.flops, 8) - def test_flops_add_roundabout_self(self): - op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None) - op1 = LazyOp(BinaryOps.ADD, (op0,self.buf1,), None) - op2 = LazyOp(BinaryOps.ADD, (op0,op1,), None) - info = get_lazyop_info(op2) - self.assertEqual(info.flops, 12) + def test_flops_add_roundabout_self(self): + op0 = LazyOp( + BinaryOps.ADD, + ( + self.buf0, + self.buf1, + ), + None, + ) + op1 = LazyOp( + BinaryOps.ADD, + ( + op0, + self.buf1, + ), + None, + ) + op2 = LazyOp( + BinaryOps.ADD, + ( + op0, + op1, + ), + None, + ) + info = get_lazyop_info(op2) + self.assertEqual(info.flops, 12) - def test_flops_red(self): - op0 = LazyOp(BinaryOps.MUL, (self.buf0,self.buf1,), None) - op1 = LazyOp(ReduceOps.SUM, (op0,), (1,)) - op2 = LazyOp(BinaryOps.ADD, (op1, op1,), None) - info = get_lazyop_info(op2) - self.assertEqual(info.flops, 9) + def test_flops_red(self): + op0 = LazyOp( + BinaryOps.MUL, + ( + self.buf0, + self.buf1, + ), + None, + ) + op1 = LazyOp(ReduceOps.SUM, (op0,), (1,)) + op2 = LazyOp( + BinaryOps.ADD, + ( + op1, + op1, + ), + None, + ) + info = get_lazyop_info(op2) + self.assertEqual(info.flops, 9) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py index 3ddd59565..432e6cb43 100644 --- a/test/unit/test_helpers.py +++ b/test/unit/test_helpers.py @@ -1,164 +1,222 @@ import unittest import numpy as np from PIL import Image -from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod, round_up, fetch +from tinygrad.helpers import ( + Context, + ContextVar, + DType, + dtypes, + merge_dicts, + strip_parens, + prod, + round_up, + fetch, +) from tinygrad.shape.symbolic import Variable, NumNode VARIABLE = ContextVar("VARIABLE", 0) + class TestContextVars(unittest.TestCase): - # Ensuring that the test does not modify variables outside the tests. - ctx = Context() - def setUp(self): TestContextVars.ctx.__enter__() - def tearDown(self): TestContextVars.ctx.__exit__() + # Ensuring that the test does not modify variables outside the tests. + ctx = Context() - def test_initial_value_is_set(self): - _TMP = ContextVar("_TMP", 5) - self.assertEqual(_TMP.value, 5) + def setUp(self): + TestContextVars.ctx.__enter__() - def test_multiple_creation_ignored(self): - _TMP2 = ContextVar("_TMP2", 1) - _TMP2 = ContextVar("_TMP2", 2) - self.assertEqual(_TMP2.value, 1) + def tearDown(self): + TestContextVars.ctx.__exit__() - def test_new_var_inside_context(self): - # Creating a _new_ variable inside a context should not have any effect on its scope (?) - with Context(VARIABLE=1): - _TMP3 = ContextVar("_TMP3", 1) - _TMP3 = ContextVar("_TMP3", 2) - self.assertEqual(_TMP3.value, 1) + def test_initial_value_is_set(self): + _TMP = ContextVar("_TMP", 5) + self.assertEqual(_TMP.value, 5) - def test_value_accross_modules(self): - # Mocking module import by invoking the code but not in our globals(). - exec('from tinygrad.helpers import ContextVar;C = ContextVar("C", 13)', {}) # pylint:disable=exec-used - # It should not matter that the first creation was in another module. - C = ContextVar("C", 0) - self.assertEqual(C.value, 13) + def test_multiple_creation_ignored(self): + _TMP2 = ContextVar("_TMP2", 1) + _TMP2 = ContextVar("_TMP2", 2) + self.assertEqual(_TMP2.value, 1) - def test_assignment_across_modules(self): - B = ContextVar("B", 1) - # local assignment - B.value = 2 - self.assertEqual(B.value, 2) - # Assignment in another module. - exec('from tinygrad.helpers import ContextVar;B = ContextVar("B", 0);B.value = 3;', {}) # pylint:disable=exec-used - # Assignment in another module should affect this one as well. - self.assertEqual(B.value, 3) + def test_new_var_inside_context(self): + # Creating a _new_ variable inside a context should not have any effect on its scope (?) + with Context(VARIABLE=1): + _TMP3 = ContextVar("_TMP3", 1) + _TMP3 = ContextVar("_TMP3", 2) + self.assertEqual(_TMP3.value, 1) - def test_context_assignment(self): - with Context(VARIABLE=1): - self.assertEqual(VARIABLE.value, 1) - self.assertEqual(VARIABLE.value, 0) + def test_value_accross_modules(self): + # Mocking module import by invoking the code but not in our globals(). + exec( + 'from tinygrad.helpers import ContextVar;C = ContextVar("C", 13)', {} + ) # pylint:disable=exec-used + # It should not matter that the first creation was in another module. + C = ContextVar("C", 0) + self.assertEqual(C.value, 13) - def test_unknown_param_to_context(self): - with self.assertRaises(KeyError): - with Context(SOMETHING_ELSE=1): - pass + def test_assignment_across_modules(self): + B = ContextVar("B", 1) + # local assignment + B.value = 2 + self.assertEqual(B.value, 2) + # Assignment in another module. + exec( + 'from tinygrad.helpers import ContextVar;B = ContextVar("B", 0);B.value = 3;', + {}, + ) # pylint:disable=exec-used + # Assignment in another module should affect this one as well. + self.assertEqual(B.value, 3) - def test_inside_context_assignment(self): - with Context(VARIABLE=4): - # What you can and cannot do inside a context. - # 1. This type of statement has no effect. - VARIABLE = ContextVar("VARIABLE", 0) - self.assertTrue(VARIABLE >= 4, "ContextVars inside contextmanager may not set a new value") + def test_context_assignment(self): + with Context(VARIABLE=1): + self.assertEqual(VARIABLE.value, 1) + self.assertEqual(VARIABLE.value, 0) - # 2. The call syntax however has a local effect. - VARIABLE.value = 13 - self.assertTrue(VARIABLE.value == 13, "Call syntax however works inside a contextmanager.") + def test_unknown_param_to_context(self): + with self.assertRaises(KeyError): + with Context(SOMETHING_ELSE=1): + pass - # Related to 2. above. Note that VARIABLE is back to 0 again as expected. - self.assertEqual(VARIABLE.value, 0) + def test_inside_context_assignment(self): + with Context(VARIABLE=4): + # What you can and cannot do inside a context. + # 1. This type of statement has no effect. + VARIABLE = ContextVar("VARIABLE", 0) + self.assertTrue( + VARIABLE >= 4, + "ContextVars inside contextmanager may not set a new value", + ) - def test_new_var_inside_context_other_module(self): - with Context(VARIABLE=1): - _NEW2 = ContextVar("_NEW2", 0) - _NEW2 = ContextVar("_NEW2", 1) - self.assertEqual(_NEW2.value, 0) + # 2. The call syntax however has a local effect. + VARIABLE.value = 13 + self.assertTrue( + VARIABLE.value == 13, + "Call syntax however works inside a contextmanager.", + ) - code = """\ + # Related to 2. above. Note that VARIABLE is back to 0 again as expected. + self.assertEqual(VARIABLE.value, 0) + + def test_new_var_inside_context_other_module(self): + with Context(VARIABLE=1): + _NEW2 = ContextVar("_NEW2", 0) + _NEW2 = ContextVar("_NEW2", 1) + self.assertEqual(_NEW2.value, 0) + + code = """\ from tinygrad.helpers import Context, ContextVar with Context(VARIABLE=1): _NEW3 = ContextVar("_NEW3", 0)""" - exec(code, {}) # pylint:disable=exec-used - # While _NEW3 was created in an outside scope it should still work the same as above. - _NEW3 = ContextVar("_NEW3", 1) - self.assertEqual(_NEW3.value, 0) + exec(code, {}) # pylint:disable=exec-used + # While _NEW3 was created in an outside scope it should still work the same as above. + _NEW3 = ContextVar("_NEW3", 1) + self.assertEqual(_NEW3.value, 0) - def test_nested_context(self): - with Context(VARIABLE=1): - with Context(VARIABLE=2): - with Context(VARIABLE=3): - self.assertEqual(VARIABLE.value, 3) - self.assertEqual(VARIABLE.value, 2) - self.assertEqual(VARIABLE.value, 1) - self.assertEqual(VARIABLE.value, 0) + def test_nested_context(self): + with Context(VARIABLE=1): + with Context(VARIABLE=2): + with Context(VARIABLE=3): + self.assertEqual(VARIABLE.value, 3) + self.assertEqual(VARIABLE.value, 2) + self.assertEqual(VARIABLE.value, 1) + self.assertEqual(VARIABLE.value, 0) - def test_decorator(self): - @Context(VARIABLE=1, DEBUG=4) - def test(): - self.assertEqual(VARIABLE.value, 1) + def test_decorator(self): + @Context(VARIABLE=1, DEBUG=4) + def test(): + self.assertEqual(VARIABLE.value, 1) - self.assertEqual(VARIABLE.value, 0) - test() - self.assertEqual(VARIABLE.value, 0) + self.assertEqual(VARIABLE.value, 0) + test() + self.assertEqual(VARIABLE.value, 0) + + def test_context_exit_reverts_updated_values(self): + D = ContextVar("D", 1) + D.value = 2 + with Context(D=3): + ... + assert ( + D.value == 2 + ), f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value." - def test_context_exit_reverts_updated_values(self): - D = ContextVar("D", 1) - D.value = 2 - with Context(D=3): - ... - assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value." class TestMergeDicts(unittest.TestCase): - def test_merge_dicts(self): - a = {"a": 1, "b": 2} - b = {"a": 1, "c": 3} - c = {} - d = {"a": 2, "b": 2} - assert merge_dicts([a, b]) == {"a": 1, "b": 2, "c": 3} - assert merge_dicts([a, c]) == a - assert merge_dicts([a, b, c]) == {"a": 1, "b": 2, "c": 3} - with self.assertRaises(AssertionError): - merge_dicts([a, d]) + def test_merge_dicts(self): + a = {"a": 1, "b": 2} + b = {"a": 1, "c": 3} + c = {} + d = {"a": 2, "b": 2} + assert merge_dicts([a, b]) == {"a": 1, "b": 2, "c": 3} + assert merge_dicts([a, c]) == a + assert merge_dicts([a, b, c]) == {"a": 1, "b": 2, "c": 3} + with self.assertRaises(AssertionError): + merge_dicts([a, d]) + class TestDtypes(unittest.TestCase): - def test_dtypes_fields(self): - fields = dtypes.fields() - self.assertTrue(all(isinstance(value, DType) for value in fields.values())) - self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None)) + def test_dtypes_fields(self): + fields = dtypes.fields() + self.assertTrue(all(isinstance(value, DType) for value in fields.values())) + self.assertTrue( + all( + issubclass(value.np, np.generic) + for value in fields.values() + if value.np is not None + ) + ) + class TestStripParens(unittest.TestCase): - def test_simple(self): self.assertEqual("1+2", strip_parens("(1+2)")) - def test_nested(self): self.assertEqual("1+(2+3)", strip_parens("(1+(2+3))")) - def test_casted_no_strip(self): self.assertEqual("(int)(1+2)", strip_parens("(int)(1+2)")) + def test_simple(self): + self.assertEqual("1+2", strip_parens("(1+2)")) + + def test_nested(self): + self.assertEqual("1+(2+3)", strip_parens("(1+(2+3))")) + + def test_casted_no_strip(self): + self.assertEqual("(int)(1+2)", strip_parens("(int)(1+2)")) + class TestProd(unittest.TestCase): - def test_empty(self): self.assertEqual(1, prod(tuple())) - def test_ints(self): self.assertEqual(30, prod((2, 3, 5))) - def test_variable(self): self.assertEqual("(a*12)", prod((Variable("a", 1, 5), 3, 4)).render()) - def test_variable_order(self): self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render()) - def test_num_nodes(self): self.assertEqual(NumNode(6), prod((NumNode(2), NumNode(3)))) + def test_empty(self): + self.assertEqual(1, prod(tuple())) + + def test_ints(self): + self.assertEqual(30, prod((2, 3, 5))) + + def test_variable(self): + self.assertEqual("(a*12)", prod((Variable("a", 1, 5), 3, 4)).render()) + + def test_variable_order(self): + self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render()) + + def test_num_nodes(self): + self.assertEqual(NumNode(6), prod((NumNode(2), NumNode(3)))) + class TestRoundUp(unittest.TestCase): - def test_round_up(self): - self.assertEqual(round_up(-3,4), 0) - self.assertEqual(round_up(-4,4), -4) - self.assertEqual(round_up(6,4), 8) - self.assertEqual(round_up(8,4), 8) - self.assertEqual(round_up(232, 24984), 24984) - self.assertEqual(round_up(24984, 232), 25056) + def test_round_up(self): + self.assertEqual(round_up(-3, 4), 0) + self.assertEqual(round_up(-4, 4), -4) + self.assertEqual(round_up(6, 4), 8) + self.assertEqual(round_up(8, 4), 8) + self.assertEqual(round_up(232, 24984), 24984) + self.assertEqual(round_up(24984, 232), 25056) + class TestFetch(unittest.TestCase): - def test_fetch_bad_http(self): - self.assertRaises(Exception, fetch, 'http://www.google.com/404') + def test_fetch_bad_http(self): + self.assertRaises(Exception, fetch, "http://www.google.com/404") - def test_fetch_small(self): - assert(len(fetch('https://google.com', allow_caching=False).read_bytes())>0) + def test_fetch_small(self): + assert len(fetch("https://google.com", allow_caching=False).read_bytes()) > 0 - def test_fetch_img(self): - img = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190", allow_caching=False) - with Image.open(img) as pimg: - assert pimg.size == (705, 1024) + def test_fetch_img(self): + img = fetch( + "https://media.istockphoto.com/photos/hen-picture-id831791190", + allow_caching=False, + ) + with Image.open(img) as pimg: + assert pimg.size == (705, 1024) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index e9965ad07..71b8b672a 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -6,803 +6,1001 @@ from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction from tinygrad.shape.symbolic import Variable, NumNode from itertools import product + def shapetracker_getitem(st, val): - locals = {"idx": val, "valid": 1} - idx, valid = st.expr_node() - exec(f"valid={valid.render()};idx={idx.render()}", None, locals) - return locals["idx"] if locals["valid"] else -1 + locals = {"idx": val, "valid": 1} + idx, valid = st.expr_node() + exec(f"valid={valid.render()};idx={idx.render()}", None, locals) + return locals["idx"] if locals["valid"] else -1 + class CheckingShapeTracker: - def __init__(self, shape): - self.st = ShapeTracker.from_shape(shape) - self.t = np.arange(prod(shape), dtype=np.int32).reshape(shape) + def __init__(self, shape): + self.st = ShapeTracker.from_shape(shape) + self.t = np.arange(prod(shape), dtype=np.int32).reshape(shape) - @property - def shape(self): - return self.t.shape + @property + def shape(self): + return self.t.shape - def simplify(self): - self.st = self.st.simplify() - return self + def simplify(self): + self.st = self.st.simplify() + return self - def reshape(self, new_shape): - self.st = self.st.reshape(new_shape) - self.t = self.t.reshape(new_shape) - return self + def reshape(self, new_shape): + self.st = self.st.reshape(new_shape) + self.t = self.t.reshape(new_shape) + return self - def permute(self, axis): - self.st = self.st.permute(axis) - self.t = np.transpose(self.t, axis) - return self + def permute(self, axis): + self.st = self.st.permute(axis) + self.t = np.transpose(self.t, axis) + return self - def expand(self, new_shape): - self.st = self.st.expand(new_shape) - self.t = np.broadcast_to(self.t, new_shape) - return self + def expand(self, new_shape): + self.st = self.st.expand(new_shape) + self.t = np.broadcast_to(self.t, new_shape) + return self - def flip(self, axis): - self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape)))) - self.t = np.flip(self.t, axis) - return self + def flip(self, axis): + self.st = self.st.stride( + tuple(-1 if i in axis else 1 for i in range(len(self.shape))) + ) + self.t = np.flip(self.t, axis) + return self - def shrink(self, arg): - self.st = self.st.shrink(arg) - self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])] - return self + def shrink(self, arg): + self.st = self.st.shrink(arg) + self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])] + return self - def pad(self, arg): - self.st = self.st.pad(arg) - self.t = np.pad(self.t, arg, constant_values=-1) - return self + def pad(self, arg): + self.st = self.st.pad(arg) + self.t = np.pad(self.t, arg, constant_values=-1) + return self - def stride(self, arg): - self.st = self.st.stride(arg) - self.t = self.t[tuple([slice(None, None, x) for x in arg])] - return self + def stride(self, arg): + self.st = self.st.stride(arg) + self.t = self.t[tuple([slice(None, None, x) for x in arg])] + return self - def __getitem__(self, val): - return self.t.flatten()[val] + def __getitem__(self, val): + return self.t.flatten()[val] - @property - def views(self): return self.st.views + @property + def views(self): + return self.st.views - @property - def contiguous(self): return self.st.contiguous + @property + def contiguous(self): + return self.st.contiguous + + def assert_same(self): + x = [shapetracker_getitem(self.st, i) for i in range(prod(self.st.shape))] + y = [self[i] for i in range(prod(self.shape))] + idx, valid = self.st.expr_node() + if DEBUG >= 1: + print( + x, y, self.st.shape, self.shape, idx.render(), valid.render(), self.st + ) + assert self.st.shape == self.shape + assert x == y, f"mismatch shapetracker:{x} real:{y}" - def assert_same(self): - x = [shapetracker_getitem(self.st, i) for i in range(prod(self.st.shape))] - y = [self[i] for i in range(prod(self.shape))] - idx, valid = self.st.expr_node() - if DEBUG >= 1: print(x, y, self.st.shape, self.shape, idx.render(), valid.render(), self.st) - assert self.st.shape == self.shape - assert x == y, f"mismatch shapetracker:{x} real:{y}" class TestRealIssues(unittest.TestCase): - def test_reshape_doesnt_multiview(self): - self.st = ShapeTracker((View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None),)) - self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2)) - assert len(self.st.views) == 1 + def test_reshape_doesnt_multiview(self): + self.st = ShapeTracker( + ( + View.create( + (256, 256, 2, 2, 2, 2, 2, 256, 8, 2), + (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), + 0, + None, + ), + ) + ) + self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2)) + assert len(self.st.views) == 1 + class TestRealDoesntSimplify(unittest.TestCase): - def tearDown(self): - st = self.st.real_strides() - print(st) - self.st = self.st.simplify() - assert len(self.st.views) != 1 - assert None in st + def tearDown(self): + st = self.st.real_strides() + print(st) + self.st = self.st.simplify() + assert len(self.st.views) != 1 + assert None in st - def test_1(self): - self.st = ShapeTracker(( - View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None), - View.create((8, 6, 11), (66, 11, 1), 0, None))) - assert self.st.real_strides() == (33, None, 1) + def test_1(self): + self.st = ShapeTracker( + ( + View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None), + View.create((8, 6, 11), (66, 11, 1), 0, None), + ) + ) + assert self.st.real_strides() == (33, None, 1) + + def test_2(self): + self.st = ShapeTracker( + ( + View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None), + View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None), + ) + ) + assert self.st.real_strides() == (None, 18, -3, -1) - def test_2(self): - self.st = ShapeTracker(( - View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None), - View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None))) - assert self.st.real_strides() == (None, 18, -3, -1) class TestRealStrides(unittest.TestCase): - def test_1(self): - self.st = ShapeTracker(( - View.create((2048,), (1,), 0, ((0, 512),)), - View.create((16, 32, 4), (128, 4, 1), 0, None))) - st = self.st.real_strides() - print(self.st, st) - assert st == (None, 4, 1) + def test_1(self): + self.st = ShapeTracker( + ( + View.create((2048,), (1,), 0, ((0, 512),)), + View.create((16, 32, 4), (128, 4, 1), 0, None), + ) + ) + st = self.st.real_strides() + print(self.st, st) + assert st == (None, 4, 1) + class TestRealSimplifies(unittest.TestCase): - def tearDown(self): - st = self.st.real_strides() - self.st = self.st.simplify() - assert len(self.st.views) == 1 - print(self.st.views[-1].strides, st) - assert self.st.views[-1].strides == st + def tearDown(self): + st = self.st.real_strides() + self.st = self.st.simplify() + assert len(self.st.views) == 1 + print(self.st.views[-1].strides, st) + assert self.st.views[-1].strides == st - def test_1(self): - self.st = ShapeTracker(( - View.create((1, 3, 2, 11, 4, 28), (0, 308, 0, 28, 0, 1), 0, None), - View.create((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None))) + def test_1(self): + self.st = ShapeTracker( + ( + View.create((1, 3, 2, 11, 4, 28), (0, 308, 0, 28, 0, 1), 0, None), + View.create( + (1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None + ), + ) + ) + + def test_2(self): + self.st = ShapeTracker( + ( + View.create((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None), + View.create( + (8, 1, 6, 10, 28, 3, 2, 1), + (5544, 0, 0, 56, 1, 1848, 672, 0), + 0, + None, + ), + ) + ) - def test_2(self): - self.st = ShapeTracker(( - View.create((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None), - View.create((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None))) class TestIndexExpressions2d(unittest.TestCase): + def setUp(self): + shapes = [ + (30, 5), + (15, 10), + (15, 1), + (5, 10), + (5, 1), + ] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5 + offsets = [0, 1, 15, 28, 10000] + self.sts = [ + ShapeTracker((View.create(base_shape, offset=offset),)) + for base_shape in shapes + for offset in offsets + ] + self.offset = [NumNode(offset) for base_shape in shapes for offset in offsets] + self.shapes = [shape for shape in shapes for offset in offsets] + self.node_exprs = [] + self.idxs_exprs = [] - def setUp(self): - shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5 - offsets = [0, 1, 15, 28, 10000] - self.sts = [ShapeTracker((View.create(base_shape, offset=offset),)) for base_shape in shapes for offset in offsets] - self.offset = [NumNode(offset) for base_shape in shapes for offset in offsets] - self.shapes = [shape for shape in shapes for offset in offsets] - self.node_exprs = [] - self.idxs_exprs = [] + def tearDown(self): + for st, offset, shape, node_expr, idxs_expr in zip( + self.sts, self.offset, self.shapes, self.node_exprs, self.idxs_exprs + ): + numel = prod(shape) + assert node_expr(self.default_idx(st.shape)) == st.expr_node()[0] + assert node_expr(self.default_idx(st.shape)) == st.expr_node(None)[0] + assert node_expr(self.default_idx(st.shape)) == st.expr_node("idx")[0] + self.check_bounds(node_expr(self.default_idx(st.shape)), offset, numel) + for idx in [ + (0, numel - 1), + (7, 203), + (2, 5), + (0, 0), + (numel, numel), + (0, numel), + (0, numel + 1), + (numel + 100, numel + 100), + ]: + idx = Variable("idx", idx[0], idx[1]) + assert node_expr(idx) == st.expr_node(idx)[0] + self.check_bounds(node_expr(idx), offset, numel) - def tearDown(self): - for st, offset, shape, node_expr, idxs_expr in zip(self.sts, self.offset, self.shapes, self.node_exprs, self.idxs_exprs): - numel = prod(shape) - assert node_expr(self.default_idx(st.shape)) == st.expr_node()[0] - assert node_expr(self.default_idx(st.shape)) == st.expr_node(None)[0] - assert node_expr(self.default_idx(st.shape)) == st.expr_node('idx')[0] - self.check_bounds(node_expr(self.default_idx(st.shape)), offset, numel) - for idx in [(0, numel-1), (7, 203), (2, 5), (0, 0), (numel, numel), (0, numel), (0, numel+1), (numel+100, numel+100)]: - idx = Variable("idx", idx[0], idx[1]) - assert node_expr(idx) == st.expr_node(idx)[0] - self.check_bounds(node_expr(idx), offset, numel) + assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs()[0] + assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs(None)[0] + self.check_bounds(idxs_expr(self.default_idxs(st.shape)), offset, numel) + idx0s = [ + (0, 0), + (0, min(1, st.shape[0] - 1)), + (0, st.shape[0] - 1), + (min(3, st.shape[0] - 1), min(6, st.shape[0] - 1)), + (st.shape[0] - 1, st.shape[0] - 1), + ] + idx1s = [ + (0, 0), + (0, min(1, st.shape[1] - 1)), + (0, st.shape[1] - 1), + (min(3, st.shape[1] - 1), min(6, st.shape[1] - 1)), + (st.shape[1] - 1, st.shape[1] - 1), + ] + idx2s = ( + [ + (0, 0), + (0, min(1, st.shape[2] - 1)), + (0, st.shape[2] - 1), + (min(3, st.shape[2] - 1), min(6, st.shape[2] - 1)), + (st.shape[2] - 1, st.shape[2] - 1), + ] + if len(st.shape) == 3 + else [None for _ in idx0s] + ) + for idx0, idx1, idx2 in product(idx0s, idx1s, idx2s): + idxs = [ + Variable(f"idx{i}", idx[0], idx[1]) + for i, idx in enumerate((idx0, idx1, idx2)) + if idx is not None + ] + assert idxs_expr(idxs) == st.expr_idxs(idxs)[0] + self.check_bounds(idxs_expr(idxs), offset, numel) - assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs()[0] - assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs(None)[0] - self.check_bounds(idxs_expr(self.default_idxs(st.shape)), offset, numel) - idx0s = [(0,0), (0, min(1, st.shape[0]-1)), (0, st.shape[0]-1), (min(3, st.shape[0]-1), min(6, st.shape[0]-1)), (st.shape[0]-1, st.shape[0]-1)] - idx1s = [(0,0), (0, min(1, st.shape[1]-1)), (0, st.shape[1]-1), (min(3, st.shape[1]-1), min(6, st.shape[1]-1)), (st.shape[1]-1, st.shape[1]-1)] - idx2s = [(0,0), (0, min(1, st.shape[2]-1)), (0, st.shape[2]-1), (min(3, st.shape[2]-1), min(6, st.shape[2]-1)), (st.shape[2]-1, st.shape[2]-1)] if len(st.shape) == 3 else [None for _ in idx0s] - for idx0, idx1, idx2 in product(idx0s, idx1s, idx2s): - idxs = [Variable(f"idx{i}", idx[0], idx[1]) for i, idx in enumerate((idx0, idx1, idx2)) if idx is not None] - assert idxs_expr(idxs) == st.expr_idxs(idxs)[0] - self.check_bounds(idxs_expr(idxs), offset, numel) + def default_idx(self, shape): + return Variable("idx", 0, prod(shape) - 1) - def default_idx(self, shape): - return Variable("idx", 0, prod(shape)-1) + def default_idxs(self, shape): + return [Variable(f"idx{i}", 0, d - 1) for i, d in enumerate(shape)] - def default_idxs(self, shape): - return [Variable(f"idx{i}", 0, d-1) for i,d in enumerate(shape)] + def check_bounds(self, expr, offset, numel): + assert expr.min >= offset + assert expr.max <= offset + numel - 1 - def check_bounds(self, expr, offset, numel): - assert expr.min >= offset - assert expr.max <= offset + numel - 1 + def test_noop(self): + for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): + self.node_exprs.append( + lambda idx, base_shape=base_shape, offset=offset: idx % prod(base_shape) + + offset + ) + self.idxs_exprs.append( + lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + * base_shape[1] + + idxs[1] + + offset + ) - def test_noop(self): - for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): - self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset) - self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[1] + offset) + def test_permute(self): + new_st = [] + for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): + st = st.permute((1, 0)) + self.node_exprs.append( + lambda idx, base_shape=base_shape, offset=offset: idx + % base_shape[0] + * base_shape[1] + + idx // base_shape[0] % base_shape[1] + + offset + ) + self.idxs_exprs.append( + lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + + idxs[1] * base_shape[1] + + offset + ) + new_st.append(st) + self.sts = new_st - def test_permute(self): - new_st = [] - for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): - st = st.permute((1, 0)) - self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset) - self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + idxs[1]*base_shape[1] + offset) - new_st.append(st) - self.sts = new_st + def test_reshape(self): + new_st = [] + for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): + st = st.reshape((base_shape[0], 1, base_shape[1])) + self.node_exprs.append( + lambda idx, base_shape=base_shape, offset=offset: idx % prod(base_shape) + + offset + ) + self.idxs_exprs.append( + lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + * base_shape[1] + + idxs[2] + + offset + ) + new_st.append(st) + self.sts = new_st - def test_reshape(self): - new_st = [] - for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): - st = st.reshape((base_shape[0], 1, base_shape[1])) - self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape) + offset) - self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset) - new_st.append(st) - self.sts = new_st + def test_reshape_expand(self): + new_st = [] + for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): + st = st.reshape((base_shape[0], 1, base_shape[1])) + st = st.expand((base_shape[0], base_shape[1], base_shape[1])) + self.node_exprs.append( + lambda idx, base_shape=base_shape, offset=offset: idx + // (base_shape[1] * base_shape[1]) + % base_shape[0] + * base_shape[1] + + idx % base_shape[1] + + offset + ) + self.idxs_exprs.append( + lambda idxs, base_shape=base_shape, offset=offset: idxs[0] + * base_shape[1] + + idxs[2] + + offset + ) + new_st.append(st) + self.sts = new_st - def test_reshape_expand(self): - new_st = [] - for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): - st = st.reshape((base_shape[0], 1, base_shape[1])) - st = st.expand((base_shape[0], base_shape[1], base_shape[1])) - self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset) - self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset) - new_st.append(st) - self.sts = new_st + def test_permute_reshape_1(self): # This tests multiple views + new_st = [] + for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): + st = st.permute((1, 0)) + st = st.reshape((base_shape[0] // 5, 1, base_shape[1] * 5)) + self.node_exprs.append( + lambda idx, base_shape=base_shape, offset=offset: idx + % prod(base_shape) + % base_shape[0] + * base_shape[1] + + idx // base_shape[0] % base_shape[1] + + offset + ) + self.idxs_exprs.append( + lambda idxs, base_shape=base_shape, offset=offset: ( + idxs[0] * (base_shape[1] * 5) + idxs[2] + ) + % base_shape[0] + * base_shape[1] + + (idxs[0] * (base_shape[1] * 5) + idxs[2]) // base_shape[0] + + offset + ) + new_st.append(st) + self.sts = new_st - def test_permute_reshape_1(self): # This tests multiple views - new_st = [] - for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): - st = st.permute((1, 0)) - st = st.reshape((base_shape[0]//5, 1, base_shape[1]*5)) - self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset) - self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[0]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[0]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset) - new_st.append(st) - self.sts = new_st + def test_permute_reshape_2(self): + new_st = [] + for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): + st = st.permute((1, 0)) + st = st.reshape((1, base_shape[0] // 5, base_shape[1] * 5)) + self.node_exprs.append( + lambda idx, base_shape=base_shape, offset=offset: idx + % prod(base_shape) + % base_shape[0] + * base_shape[1] + + idx // base_shape[0] % base_shape[1] + + offset + ) + self.idxs_exprs.append( + lambda idxs, base_shape=base_shape, offset=offset: ( + idxs[1] * (base_shape[1] * 5) + idxs[2] + ) + % base_shape[0] + * base_shape[1] + + (idxs[1] * (base_shape[1] * 5) + idxs[2]) // base_shape[0] + + offset + ) + new_st.append(st) + self.sts = new_st - def test_permute_reshape_2(self): - new_st = [] - for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): - st = st.permute((1, 0)) - st = st.reshape((1, base_shape[0]//5, base_shape[1]*5)) - self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx%prod(base_shape)%base_shape[0]*base_shape[1] + idx//base_shape[0]%base_shape[1] + offset) - self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: (idxs[1]*(base_shape[1]*5)+idxs[2])%base_shape[0]*base_shape[1] + (idxs[1]*(base_shape[1]*5)+idxs[2])//base_shape[0] + offset) - new_st.append(st) - self.sts = new_st + def test_reshaping_splitting(self): + self.st = CheckingShapeTracker((5, 10, 5, 10)) + self.st.permute((1, 0, 3, 2)) + self.st.pad(((0, 0), (0, 5), (0, 0), (0, 5))) + self.st.reshape((10, 2, 5, 10, 2, 5)) + assert len(self.st.views) == 1 + self.st.assert_same() - def test_reshaping_splitting(self): - self.st = CheckingShapeTracker((5,10,5,10)) - self.st.permute((1, 0, 3, 2)) - self.st.pad(((0,0), (0,5), (0,0), (0,5))) - self.st.reshape((10,2,5,10,2,5)) - assert len(self.st.views) == 1 - self.st.assert_same() + def test_reshape_splitting_1(self): + self.st = CheckingShapeTracker((1, 10, 1)) + self.st.pad(((0, 4), (0, 0), (1, 0))) + self.st.reshape((5, 5, 2, 2)) + assert len(self.st.views) == 1 + self.st.assert_same() - def test_reshape_splitting_1(self): - self.st = CheckingShapeTracker((1,10,1)) - self.st.pad(((0,4),(0,0),(1,0))) - self.st.reshape((5,5,2,2)) - assert len(self.st.views) == 1 - self.st.assert_same() + def test_reshape_combining_1(self): + self.st = CheckingShapeTracker((2, 1, 10)) + self.st.pad(((2, 6), (0, 0), (0, 0))) + self.st.reshape((100,)) + assert len(self.st.views) == 1 + self.st.assert_same() - def test_reshape_combining_1(self): - self.st = CheckingShapeTracker((2,1,10)) - self.st.pad(((2,6), (0,0), (0,0))) - self.st.reshape((100,)) - assert len(self.st.views) == 1 - self.st.assert_same() + def test_reshape_combining_2(self): + self.st = CheckingShapeTracker((1, 1, 5)) + self.st.pad(((3, 6), (0, 0), (0, 5))) + self.st.reshape((100,)) + assert len(self.st.views) == 1 + self.st.assert_same() - def test_reshape_combining_2(self): - self.st = CheckingShapeTracker((1,1,5)) - self.st.pad(((3,6), (0,0), (0,5))) - self.st.reshape((100,)) - assert len(self.st.views) == 1 - self.st.assert_same() + def test_reshape_combining_3(self): + self.st = CheckingShapeTracker((1, 1, 4)) + self.st.pad(((3, 6), (0, 0), (1, 5))) + self.st.reshape((100,)) + assert len(self.st.views) == 1 + assert self.st.views[0].mask[0] == (31, 35) + self.st.assert_same() - def test_reshape_combining_3(self): - self.st = CheckingShapeTracker((1,1,4)) - self.st.pad(((3,6), (0,0), (1,5))) - self.st.reshape((100,)) - assert len(self.st.views) == 1 - assert self.st.views[0].mask[0] == (31, 35) - self.st.assert_same() + def test_reshape_combining_4(self): + self.st = CheckingShapeTracker((1, 1, 5, 5, 1, 1, 5)) + self.st.pad(((3, 6), (0, 0), (0, 5), (0, 0), (3, 6), (0, 0), (0, 5))) + self.st.reshape((100, 5, 100)) + assert len(self.st.views) == 1 + self.st.assert_same() - def test_reshape_combining_4(self): - self.st = CheckingShapeTracker((1,1,5,5,1,1,5)) - self.st.pad(((3,6), (0,0), (0,5), (0,0), (3,6), (0,0), (0,5))) - self.st.reshape((100,5,100)) - assert len(self.st.views) == 1 - self.st.assert_same() + def test_reshape_splitting_combining(self): + self.st = CheckingShapeTracker((1, 5, 5)) + self.st.pad(((0, 4), (0, 5), (0, 0))) + self.st.reshape((10, 25)) + assert len(self.st.views) == 1 + self.st.assert_same() - def test_reshape_splitting_combining(self): - self.st = CheckingShapeTracker((1,5,5)) - self.st.pad(((0,4), (0,5), (0,0))) - self.st.reshape((10,25)) - assert len(self.st.views) == 1 - self.st.assert_same() + def test_reshape_only_1s(self): + self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1)) + self.st.pad(((0, 4), (0, 0), (0, 0), (1, 1), (0, 0), (0, 0), (0, 0), (0, 0))) + self.st.reshape((5, 6, 3, 5)) + assert len(self.st.views) == 1 + self.st.assert_same() + self.st.reshape((1, 1, 5, 6, 3, 5, 1, 1)) + assert len(self.st.views) == 1 + self.st.assert_same() + self.st.reshape((1, 5, 6, 1, 3, 1, 5, 1)) + assert len(self.st.views) == 1 + self.st.assert_same() - def test_reshape_only_1s(self): - self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1)) - self.st.pad(((0,4), (0,0), (0,0), (1,1), (0,0), (0,0), (0,0), (0,0))) - self.st.reshape((5, 6, 3, 5)) - assert len(self.st.views) == 1 - self.st.assert_same() - self.st.reshape((1, 1, 5, 6, 3, 5, 1, 1)) - assert len(self.st.views) == 1 - self.st.assert_same() - self.st.reshape((1, 5, 6, 1, 3, 1, 5, 1)) - assert len(self.st.views) == 1 - self.st.assert_same() + def test_zero_mask_1(self): + self.st = CheckingShapeTracker((1, 3, 2)) + self.st.pad(((0, 0), (0, 3), (0, 0))) + self.st.shrink(((0, 1), (3, 6), (0, 2))) + self.st.reshape((3, 2)) + assert len(self.st.views) == 1 + self.st.assert_same() + self.st.reshape((1, 3, 1, 2, 1)) + assert len(self.st.views) == 1 + self.st.assert_same() - def test_zero_mask_1(self): - self.st = CheckingShapeTracker((1, 3, 2)) - self.st.pad(((0,0), (0,3), (0,0))) - self.st.shrink(((0,1), (3,6), (0,2))) - self.st.reshape((3,2)) - assert len(self.st.views) == 1 - self.st.assert_same() - self.st.reshape((1, 3, 1, 2, 1)) - assert len(self.st.views) == 1 - self.st.assert_same() + def test_zero_mask_2(self): + self.st = CheckingShapeTracker((1, 3, 2)) + self.st.pad(((0, 2), (0, 3), (0, 0))) + self.st.shrink(((2, 3), (3, 6), (0, 2))) + self.st.reshape((3, 2)) + assert len(self.st.views) == 1 + self.st.assert_same() + self.st.reshape((1, 3, 1, 2, 1)) + assert len(self.st.views) == 1 + self.st.assert_same() - def test_zero_mask_2(self): - self.st = CheckingShapeTracker((1, 3, 2)) - self.st.pad(((0,2), (0,3), (0,0))) - self.st.shrink(((2,3), (3,6), (0,2))) - self.st.reshape((3,2)) - assert len(self.st.views) == 1 - self.st.assert_same() - self.st.reshape((1, 3, 1, 2, 1)) - assert len(self.st.views) == 1 - self.st.assert_same() + def test_expanded_reshaped(self): + self.st = CheckingShapeTracker((1, 3, 2, 1)) + self.st.expand((5, 3, 2, 2)) + self.st.pad(((0, 0), (0, 3), (0, 0), (0, 0))) + self.st.reshape((5, 2, 3, 2, 2)) + assert len(self.st.views) == 1 + self.st.assert_same() - def test_expanded_reshaped(self): - self.st = CheckingShapeTracker((1, 3, 2, 1)) - self.st.expand((5, 3, 2, 2)) - self.st.pad(((0,0), (0,3), (0,0), (0, 0))) - self.st.reshape((5, 2, 3, 2, 2)) - assert len(self.st.views) == 1 - self.st.assert_same() + def test_splitting_big(self): + self.st = CheckingShapeTracker((1, 5, 1, 15, 1)) + self.st.pad(((0, 0), (0, 5), (0, 0), (0, 15), (0, 0))) + self.st.reshape((10, 1, 30)) + self.st.permute((2, 1, 0)) + self.st.reshape((2, 3, 5, 2, 5)) + assert len(self.st.views) == 1 + v = self.st.views[-1] + assert v.strides == (15, 5, 1, 75, 15) and v.mask == ( + (0, 1), + (0, 3), + (0, 5), + (0, 1), + (0, 5), + ) + self.st.assert_same() + + def test_combining_big(self): + self.st = CheckingShapeTracker((1, 3, 1, 5, 3, 1)) + self.st.pad(((0, 0), (2, 2), (0, 0), (0, 0), (0, 0), (0, 0))) + self.st.reshape((1, 1, 1, 105, 1, 1)) + assert len(self.st.views) == 1 + v = self.st.views[-1] + assert ( + v.strides == (0, 0, 0, 1, 0, 0) + and v.mask == ((0, 1), (0, 1), (0, 1), (30, 75), (0, 1), (0, 1)) + and v.offset == -30 + ) + self.st.assert_same() - def test_splitting_big(self): - self.st = CheckingShapeTracker((1, 5, 1, 15, 1)) - self.st.pad(((0,0), (0,5), (0,0), (0,15), (0,0))) - self.st.reshape((10, 1, 30)) - self.st.permute((2,1,0)) - self.st.reshape((2,3,5,2,5)) - assert len(self.st.views) == 1 - v = self.st.views[-1] - assert v.strides == (15, 5, 1, 75, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5)) - self.st.assert_same() - def test_combining_big(self): - self.st = CheckingShapeTracker((1,3,1,5,3,1)) - self.st.pad(((0,0),(2,2),(0,0),(0,0),(0,0),(0,0))) - self.st.reshape((1,1,1,105,1,1)) - assert len(self.st.views) == 1 - v = self.st.views[-1] - assert v.strides == (0, 0, 0, 1, 0, 0) and v.mask == ((0, 1), (0, 1), (0, 1), (30, 75), (0, 1), (0, 1)) and v.offset == -30 - self.st.assert_same() class TestSimplifyingShapeTracker(unittest.TestCase): - def setUp(self): - self.st = CheckingShapeTracker((1, 10)) + def setUp(self): + self.st = CheckingShapeTracker((1, 10)) - def tearDown(self): - self.st.assert_same() + def tearDown(self): + self.st.assert_same() - # multiview simplify - def test_expand_contract_simple(self): - self.st = self.st.expand((10, 10)) - self.st = self.st.reshape((100,)) - print(self.st.views) - assert(len(self.st.views) == 2) - self.st = self.st.reshape((10, 10)) - print(self.st.views) + # multiview simplify + def test_expand_contract_simple(self): + self.st = self.st.expand((10, 10)) + self.st = self.st.reshape((100,)) + print(self.st.views) + assert len(self.st.views) == 2 + self.st = self.st.reshape((10, 10)) + print(self.st.views) - self.st = self.st.simplify() - print(self.st.views) - assert(len(self.st.views) == 1) + self.st = self.st.simplify() + print(self.st.views) + assert len(self.st.views) == 1 - # multiview simplify - def test_expand_contract_different_shape(self): - self.st.expand((10, 10)) - self.st.reshape((100,)) - print(self.st.views) - assert(len(self.st.views) == 2) - self.st.reshape((2, 5, 2, 5)) - print(self.st.views) + # multiview simplify + def test_expand_contract_different_shape(self): + self.st.expand((10, 10)) + self.st.reshape((100,)) + print(self.st.views) + assert len(self.st.views) == 2 + self.st.reshape((2, 5, 2, 5)) + print(self.st.views) - self.st = self.st.simplify() - print(self.st.views) - assert(len(self.st.views) == 1) + self.st = self.st.simplify() + print(self.st.views) + assert len(self.st.views) == 1 - # multiview simplify - def test_expand_contract_still_complex(self): - self.st.expand((10, 10)) - self.st.reshape((100,)) - print(self.st.views) - assert(len(self.st.views) == 2) - self.st.reshape((5, 20)) + # multiview simplify + def test_expand_contract_still_complex(self): + self.st.expand((10, 10)) + self.st.reshape((100,)) + print(self.st.views) + assert len(self.st.views) == 2 + self.st.reshape((5, 20)) + + self.st = self.st.simplify() + print(self.st.views) + assert len(self.st.views) == 2 - self.st = self.st.simplify() - print(self.st.views) - assert(len(self.st.views) == 2) # Tensor.zeros(2, 4).permute(1,0).reshape(2, 4) # (d1*4 + d0%4), d1=x//4, d0=x%4 = ((x//4)*4) + (x%4)%4 + class TestComplexShapeTracker(unittest.TestCase): - def test_add_1s(self): - self.st = CheckingShapeTracker((4, 4)) - self.st.permute((1,0)) - self.st.reshape((1,4,1,4,1)) - assert not self.st.contiguous - self.st.permute((0,3,2,1,4)) - assert self.st.contiguous + def test_add_1s(self): + self.st = CheckingShapeTracker((4, 4)) + self.st.permute((1, 0)) + self.st.reshape((1, 4, 1, 4, 1)) + assert not self.st.contiguous + self.st.permute((0, 3, 2, 1, 4)) + assert self.st.contiguous - def test_permute_1s_simple(self): - self.st = CheckingShapeTracker((1, 16, 9,9)) - self.st.permute((1,0,2,3)) - assert self.st.contiguous - self.st = CheckingShapeTracker((2, 16, 9,9)) - self.st.permute((1,0,2,3)) - assert not self.st.contiguous + def test_permute_1s_simple(self): + self.st = CheckingShapeTracker((1, 16, 9, 9)) + self.st.permute((1, 0, 2, 3)) + assert self.st.contiguous + self.st = CheckingShapeTracker((2, 16, 9, 9)) + self.st.permute((1, 0, 2, 3)) + assert not self.st.contiguous - def test_remove_1s_simple(self): - self.st = CheckingShapeTracker((1, 16, 1, 1)) - self.st.reshape((16,)) - assert self.st.contiguous + def test_remove_1s_simple(self): + self.st = CheckingShapeTracker((1, 16, 1, 1)) + self.st.reshape((16,)) + assert self.st.contiguous - def test_remove_1s(self): - self.st = CheckingShapeTracker((1, 4, 1, 4, 1)) - self.st.permute((0,3,2,1,4)) - self.st.reshape((4,4)) - assert not self.st.contiguous - self.st.permute((1,0)) - assert self.st.contiguous + def test_remove_1s(self): + self.st = CheckingShapeTracker((1, 4, 1, 4, 1)) + self.st.permute((0, 3, 2, 1, 4)) + self.st.reshape((4, 4)) + assert not self.st.contiguous + self.st.permute((1, 0)) + assert self.st.contiguous - def test_permute_reshape(self): - self.st = CheckingShapeTracker((4, 4)) - self.st.permute((1,0)) - self.st.reshape((2, 2, 2, 2)) - # TODO: should also be tested by test_super_complex - assert len(self.st.views) == 1 + def test_permute_reshape(self): + self.st = CheckingShapeTracker((4, 4)) + self.st.permute((1, 0)) + self.st.reshape((2, 2, 2, 2)) + # TODO: should also be tested by test_super_complex + assert len(self.st.views) == 1 - def test_factorize_split(self): - self.st = CheckingShapeTracker((4, 4)) - self.st.permute((1,0)) - self.st.reshape((2, 2, 2, 2)) - self.st.permute((2,3,0,1)) - assert self.st.contiguous + def test_factorize_split(self): + self.st = CheckingShapeTracker((4, 4)) + self.st.permute((1, 0)) + self.st.reshape((2, 2, 2, 2)) + self.st.permute((2, 3, 0, 1)) + assert self.st.contiguous - def test_factorize_combine(self): - self.st = CheckingShapeTracker((4, 4, 4)) - self.st.permute((2, 0, 1)) - self.st.reshape((4, 16)) - self.st.permute((1, 0)) - assert self.st.contiguous + def test_factorize_combine(self): + self.st = CheckingShapeTracker((4, 4, 4)) + self.st.permute((2, 0, 1)) + self.st.reshape((4, 16)) + self.st.permute((1, 0)) + assert self.st.contiguous - def test_factorize_combine_add_ones(self): - self.st = CheckingShapeTracker((4, 4, 4)) - self.st.permute((2, 0, 1)) - self.st.reshape((4, 16, 1, 1)) - self.st.permute((1, 0, 2, 3)) - assert self.st.contiguous + def test_factorize_combine_add_ones(self): + self.st = CheckingShapeTracker((4, 4, 4)) + self.st.permute((2, 0, 1)) + self.st.reshape((4, 16, 1, 1)) + self.st.permute((1, 0, 2, 3)) + assert self.st.contiguous - def test_fancy_factorize(self): - self.st = CheckingShapeTracker((32, 3, 3, 1)) - self.st.reshape((8, 4, 3, 3)) - assert len(self.st.views) == 1 + def test_fancy_factorize(self): + self.st = CheckingShapeTracker((32, 3, 3, 1)) + self.st.reshape((8, 4, 3, 3)) + assert len(self.st.views) == 1 - def test_super_complex_2_fail(self): - self.st = CheckingShapeTracker((4, 4, 4)) - self.st.permute((2, 0, 1)) - self.st.reshape((16, 4)) - assert len(self.st.views) != 1 + def test_super_complex_2_fail(self): + self.st = CheckingShapeTracker((4, 4, 4)) + self.st.permute((2, 0, 1)) + self.st.reshape((16, 4)) + assert len(self.st.views) != 1 - def test_work(self): - self.st = CheckingShapeTracker((64, 1024, 4)) - self.st.reshape((1, 64, 128, 32)) - self.st.permute((0, 3, 1, 2)) - self.st.reshape((1, 32, 1, 64, 128)) - self.st.permute((0, 3, 4, 1, 2)) - assert self.st.contiguous + def test_work(self): + self.st = CheckingShapeTracker((64, 1024, 4)) + self.st.reshape((1, 64, 128, 32)) + self.st.permute((0, 3, 1, 2)) + self.st.reshape((1, 32, 1, 64, 128)) + self.st.permute((0, 3, 4, 1, 2)) + assert self.st.contiguous + + def test_work2(self): + self.st = CheckingShapeTracker((64, 1024, 4)) + self.st.reshape((1, 64, 128, 32)) + self.st.permute((0, 3, 1, 2)) + self.st.reshape((1, 1, 32, 64, 128)) + self.st.permute((0, 3, 4, 1, 2)) + self.st.reshape((64, 1024, 4)) + print(self.st.views) + assert self.st.contiguous - def test_work2(self): - self.st = CheckingShapeTracker((64, 1024, 4)) - self.st.reshape((1, 64, 128, 32)) - self.st.permute((0, 3, 1, 2)) - self.st.reshape((1, 1, 32, 64, 128)) - self.st.permute((0, 3, 4, 1, 2)) - self.st.reshape((64, 1024, 4)) - print(self.st.views) - assert self.st.contiguous class TestSingleShapeTracker(unittest.TestCase): - def setUp(self): - self.st = CheckingShapeTracker((7,4)) + def setUp(self): + self.st = CheckingShapeTracker((7, 4)) - def tearDown(self): - self.st.assert_same() + def tearDown(self): + self.st.assert_same() - def test_reshape(self): - self.st.reshape((7,1,4)) - assert self.st.contiguous + def test_reshape(self): + self.st.reshape((7, 1, 4)) + assert self.st.contiguous - def test_permute(self): - self.st.permute((1,0)) - assert not self.st.contiguous + def test_permute(self): + self.st.permute((1, 0)) + assert not self.st.contiguous - def test_shrink(self): - self.st.shrink(((1,2), (0,4))) - assert not self.st.contiguous + def test_shrink(self): + self.st.shrink(((1, 2), (0, 4))) + assert not self.st.contiguous - def test_double_permute(self): - self.st.permute((1,0)) - self.st.permute((1,0)) - assert self.st.contiguous + def test_double_permute(self): + self.st.permute((1, 0)) + self.st.permute((1, 0)) + assert self.st.contiguous - def test_reshape_permute(self): - self.st.reshape((7,1,4)) - self.st.permute((0,1,2)) - assert self.st.contiguous + def test_reshape_permute(self): + self.st.reshape((7, 1, 4)) + self.st.permute((0, 1, 2)) + assert self.st.contiguous - def test_reshape_permute_yes(self): - self.st.reshape((7,1,4)) - self.st.permute((0,2,1)) - assert self.st.contiguous + def test_reshape_permute_yes(self): + self.st.reshape((7, 1, 4)) + self.st.permute((0, 2, 1)) + assert self.st.contiguous + + def test_reshape_permute_no(self): + self.st.reshape((4, 7)) + self.st.permute((1, 0)) + assert not self.st.contiguous - def test_reshape_permute_no(self): - self.st.reshape((4,7)) - self.st.permute((1,0)) - assert not self.st.contiguous class TestShapeTrackerFuzzFailures(unittest.TestCase): - def setUp(self): - self.st = CheckingShapeTracker((3,3,3)) - def tearDown(self): - self.st.assert_same() - def test_case_1(self): - self.st.shrink(((1, 2), (1, 3), (1, 3))) - self.st.reshape((1, 4)) - self.st.shrink(((0, 1), (1, 3))) - print(self.st.st) - self.st = self.st.simplify() - print(self.st.st) - def test_case_2(self): - self.st.stride( (1, 1, -2) ) - self.st.reshape( (3, 6) ) - self.st.shrink( ((1, 2), (1, 5)) ) - self.st.stride( (1, -1) ) - def test_case_3(self): - self.st.shrink( ((0, 2), (0, 2), (0, 1)) ) - self.st.permute( (1, 0, 2) ) - self.st.reshape( (4,) ) - self.st.shrink( ((0, 3),) ) - self.st.stride( (-1,) ) - def test_case_4(self): - self.st.reshape( (3, 3, 3, 1) ) - self.st.pad( ((0, 0), (0, 0), (0, 0), (1, 1)) ) - self.st.shrink( ((0, 2), (1, 2), (0, 2), (0, 1)) ) - self.st.expand( (2, 1, 2, 3) ) + def setUp(self): + self.st = CheckingShapeTracker((3, 3, 3)) + + def tearDown(self): + self.st.assert_same() + + def test_case_1(self): + self.st.shrink(((1, 2), (1, 3), (1, 3))) + self.st.reshape((1, 4)) + self.st.shrink(((0, 1), (1, 3))) + print(self.st.st) + self.st = self.st.simplify() + print(self.st.st) + + def test_case_2(self): + self.st.stride((1, 1, -2)) + self.st.reshape((3, 6)) + self.st.shrink(((1, 2), (1, 5))) + self.st.stride((1, -1)) + + def test_case_3(self): + self.st.shrink(((0, 2), (0, 2), (0, 1))) + self.st.permute((1, 0, 2)) + self.st.reshape((4,)) + self.st.shrink(((0, 3),)) + self.st.stride((-1,)) + + def test_case_4(self): + self.st.reshape((3, 3, 3, 1)) + self.st.pad(((0, 0), (0, 0), (0, 0), (1, 1))) + self.st.shrink(((0, 2), (1, 2), (0, 2), (0, 1))) + self.st.expand((2, 1, 2, 3)) + class TestMaskedShapeTracker(unittest.TestCase): - def test_pad_1x1(self): - self.st = CheckingShapeTracker((1,1)) - self.st.pad(((1,1), (1,1))) - self.st.assert_same() + def test_pad_1x1(self): + self.st = CheckingShapeTracker((1, 1)) + self.st.pad(((1, 1), (1, 1))) + self.st.assert_same() + + def test_pad_2x2(self): + self.st = CheckingShapeTracker((2, 2)) + self.st.pad(((1, 1), (1, 1))) + self.st.assert_same() - def test_pad_2x2(self): - self.st = CheckingShapeTracker((2,2)) - self.st.pad(((1,1), (1,1))) - self.st.assert_same() class TestShapeTracker(unittest.TestCase): - def setUp(self): - self.st = CheckingShapeTracker((7,4)) - self.apply = lambda fxn: [fxn(x) for x in [self.st]] + def setUp(self): + self.st = CheckingShapeTracker((7, 4)) + self.apply = lambda fxn: [fxn(x) for x in [self.st]] - def tearDown(self): - self.st.assert_same() + def tearDown(self): + self.st.assert_same() - def test_noop(self): - pass + def test_noop(self): + pass - def test_simple_split(self): - self.test_permute() - self.apply(lambda x: x.reshape((prod(self.st.shape), ))) + def test_simple_split(self): + self.test_permute() + self.apply(lambda x: x.reshape((prod(self.st.shape),))) - def test_simple_pad(self): - self.st.pad(((1,1), (1,1))) + def test_simple_pad(self): + self.st.pad(((1, 1), (1, 1))) - def test_pad_shrink(self): - self.st.pad(((1,1), (1,1))) - self.st.shrink(((0,4), (0,4))) + def test_pad_shrink(self): + self.st.pad(((1, 1), (1, 1))) + self.st.shrink(((0, 4), (0, 4))) - def test_pad_one_sided(self): - self.st.pad(((0,1), (0,0))) + def test_pad_one_sided(self): + self.st.pad(((0, 1), (0, 0))) - def test_pad_reshape(self): - self.st.pad(((0,1), (0,0))) - self.st.reshape((8*4,)) + def test_pad_reshape(self): + self.st.pad(((0, 1), (0, 0))) + self.st.reshape((8 * 4,)) - def test_pad_pad(self): - self.st.pad(((1,1), (1,1))) - self.st.pad(((1,1), (1,1))) + def test_pad_pad(self): + self.st.pad(((1, 1), (1, 1))) + self.st.pad(((1, 1), (1, 1))) - def test_pad_permute(self): - self.st.pad(((1,1), (2,2))) - self.st.permute((1,0)) + def test_pad_permute(self): + self.st.pad(((1, 1), (2, 2))) + self.st.permute((1, 0)) - def test_pad_expand(self): - self.st.reshape((7,4,1)) - self.st.pad(((1,1), (1,1), (0,0))) - self.st.expand((9,6,4)) + def test_pad_expand(self): + self.st.reshape((7, 4, 1)) + self.st.pad(((1, 1), (1, 1), (0, 0))) + self.st.expand((9, 6, 4)) - def test_pad_expand_alt(self): - self.st.pad(((1,1), (1,1))) - self.st.reshape((9,6,1)) - self.st.expand((9,6,4)) + def test_pad_expand_alt(self): + self.st.pad(((1, 1), (1, 1))) + self.st.reshape((9, 6, 1)) + self.st.expand((9, 6, 4)) - def test_pad_stride(self): - self.st.pad(((1,4), (1,3))) - self.st.stride((2,2)) + def test_pad_stride(self): + self.st.pad(((1, 4), (1, 3))) + self.st.stride((2, 2)) - def test_pad_stride_neg(self): - self.st.pad(((1,2), (1,0))) - self.st.stride((-1,-1)) + def test_pad_stride_neg(self): + self.st.pad(((1, 2), (1, 0))) + self.st.stride((-1, -1)) - def test_pad_stride_both(self): - self.st.pad(((1,2), (1,0))) - self.st.stride((-2,-2)) + def test_pad_stride_both(self): + self.st.pad(((1, 2), (1, 0))) + self.st.stride((-2, -2)) - def test_shrink_pad(self): - self.st.shrink(((0,4), (0,4))) - self.st.pad(((1,1), (1,1))) + def test_shrink_pad(self): + self.st.shrink(((0, 4), (0, 4))) + self.st.pad(((1, 1), (1, 1))) - def test_reshape(self): - new_shape = self.st.shape[::-1] - self.apply(lambda x: x.reshape(new_shape)) + def test_reshape(self): + new_shape = self.st.shape[::-1] + self.apply(lambda x: x.reshape(new_shape)) - def test_permute(self): - if len(self.st.shape) == 2: self.apply(lambda x: x.permute((1,0))) - elif len(self.st.shape) == 3: self.apply(lambda x: x.permute((2,0,1))) + def test_permute(self): + if len(self.st.shape) == 2: + self.apply(lambda x: x.permute((1, 0))) + elif len(self.st.shape) == 3: + self.apply(lambda x: x.permute((2, 0, 1))) - def test_reshape_with_1(self): - new_shape = (self.st.shape[0], 1, self.st.shape[1]) - self.apply(lambda x: x.reshape(new_shape)) + def test_reshape_with_1(self): + new_shape = (self.st.shape[0], 1, self.st.shape[1]) + self.apply(lambda x: x.reshape(new_shape)) - def test_expand(self): - self.test_reshape_with_1() - new_shape = list(self.st.shape) - new_shape[1] = 2 - self.apply(lambda x: x.expand(tuple(new_shape))) + def test_expand(self): + self.test_reshape_with_1() + new_shape = list(self.st.shape) + new_shape[1] = 2 + self.apply(lambda x: x.expand(tuple(new_shape))) - def test_flip_0(self): - self.apply(lambda x: x.flip((0,))) + def test_flip_0(self): + self.apply(lambda x: x.flip((0,))) - def test_flip_1(self): - self.apply(lambda x: x.flip((1,))) + def test_flip_1(self): + self.apply(lambda x: x.flip((1,))) - def test_flip_01(self): - self.apply(lambda x: x.flip((0,1))) + def test_flip_01(self): + self.apply(lambda x: x.flip((0, 1))) - def test_slice_0(self): - self.apply(lambda x: x.shrink(((1, x.shape[0]), (0, x.shape[1])))) + def test_slice_0(self): + self.apply(lambda x: x.shrink(((1, x.shape[0]), (0, x.shape[1])))) - def test_slice_1(self): - self.apply(lambda x: x.shrink(((0, x.shape[0]), (1, x.shape[1])))) + def test_slice_1(self): + self.apply(lambda x: x.shrink(((0, x.shape[0]), (1, x.shape[1])))) - def test_slice_1c1(self): - self.apply(lambda x: x.shrink(((0, 1), (0, 1)))) + def test_slice_1c1(self): + self.apply(lambda x: x.shrink(((0, 1), (0, 1)))) - def test_slice_1c2(self): - self.apply(lambda x: x.shrink(((1, 2), (1, 2)))) + def test_slice_1c2(self): + self.apply(lambda x: x.shrink(((1, 2), (1, 2)))) - def test_double_permute(self): - self.apply(lambda x: x.permute((1, 0))) - self.apply(lambda x: x.permute((1, 0))) + def test_double_permute(self): + self.apply(lambda x: x.permute((1, 0))) + self.apply(lambda x: x.permute((1, 0))) - def test_slice_permute(self): - self.apply(lambda x: x.shrink(((0, 2), (2, 4)))) - self.apply(lambda x: x.permute((1, 0))) + def test_slice_permute(self): + self.apply(lambda x: x.shrink(((0, 2), (2, 4)))) + self.apply(lambda x: x.permute((1, 0))) - def test_slice_expand(self): - self.apply(lambda x: x.shrink(((0, 2), (3, 4)))) - self.apply(lambda x: x.expand((2, 10))) + def test_slice_expand(self): + self.apply(lambda x: x.shrink(((0, 2), (3, 4)))) + self.apply(lambda x: x.expand((2, 10))) - def test_double_stride(self): - self.apply(lambda x: x.stride((1, 2))) - self.apply(lambda x: x.stride((2, 1))) + def test_double_stride(self): + self.apply(lambda x: x.stride((1, 2))) + self.apply(lambda x: x.stride((2, 1))) - def test_stride(self): self.apply(lambda x: x.stride((2,1))) - def test_stride_int(self): self.apply(lambda x: x.stride((1,2))) - def test_stride_2(self): self.apply(lambda x: x.stride((2,2))) - def test_stride_n(self): self.apply(lambda x: x.stride((-2,1))) - def test_stride_int_n(self): self.apply(lambda x: x.stride((-1,2))) - def test_stride_2_n(self): self.apply(lambda x: x.stride((-2,-2))) + def test_stride(self): + self.apply(lambda x: x.stride((2, 1))) - def test_reshape_then_permute(self): - self.test_reshape() - self.test_permute() + def test_stride_int(self): + self.apply(lambda x: x.stride((1, 2))) - def test_reshape_then_expand(self): - self.test_reshape() - self.test_expand() + def test_stride_2(self): + self.apply(lambda x: x.stride((2, 2))) - def test_permute_then_reshape(self): - self.test_permute() - self.test_reshape() + def test_stride_n(self): + self.apply(lambda x: x.stride((-2, 1))) - def test_expand_then_reshape(self): - self.test_expand() - self.test_reshape() + def test_stride_int_n(self): + self.apply(lambda x: x.stride((-1, 2))) + + def test_stride_2_n(self): + self.apply(lambda x: x.stride((-2, -2))) + + def test_reshape_then_permute(self): + self.test_reshape() + self.test_permute() + + def test_reshape_then_expand(self): + self.test_reshape() + self.test_expand() + + def test_permute_then_reshape(self): + self.test_permute() + self.test_reshape() + + def test_expand_then_reshape(self): + self.test_expand() + self.test_reshape() + + def test_combo(self): + self.test_permute() + self.test_reshape() + self.test_slice_1() + self.test_expand() + self.test_permute() - def test_combo(self): - self.test_permute() - self.test_reshape() - self.test_slice_1() - self.test_expand() - self.test_permute() class TestGetContraction(unittest.TestCase): - def test_contraction(self): - r = get_contraction((1,2,3,4), (2,3,4)) - self.assertEqual(r, [[0, 1], [2], [3]]) + def test_contraction(self): + r = get_contraction((1, 2, 3, 4), (2, 3, 4)) + self.assertEqual(r, [[0, 1], [2], [3]]) - r = get_contraction((2,1,3,4), (2,3,4)) - self.assertEqual(r, [[0], [1, 2], [3]]) + r = get_contraction((2, 1, 3, 4), (2, 3, 4)) + self.assertEqual(r, [[0], [1, 2], [3]]) - r = get_contraction((1,2,3,1,4), (1,2,3,4)) - self.assertEqual(r, [[], [0, 1], [2], [3, 4]]) + r = get_contraction((1, 2, 3, 1, 4), (1, 2, 3, 4)) + self.assertEqual(r, [[], [0, 1], [2], [3, 4]]) - r = get_contraction((1,2,3,1,4,1,1), (2,3,4)) - self.assertEqual(r, [[0, 1], [2], [3, 4, 5, 6]]) + r = get_contraction((1, 2, 3, 1, 4, 1, 1), (2, 3, 4)) + self.assertEqual(r, [[0, 1], [2], [3, 4, 5, 6]]) - r = get_contraction((1,2,3,4), (1,2,3*4)) - self.assertEqual(r, [[], [0, 1], [2, 3]]) + r = get_contraction((1, 2, 3, 4), (1, 2, 3 * 4)) + self.assertEqual(r, [[], [0, 1], [2, 3]]) - r = get_contraction((1,2,3,4), (2,1,3,4)) - self.assertEqual(r, [[0, 1], [], [2], [3]]) + r = get_contraction((1, 2, 3, 4), (2, 1, 3, 4)) + self.assertEqual(r, [[0, 1], [], [2], [3]]) - r = get_contraction((1,2,3,4), (1,1,2*3*4,1)) - self.assertEqual(r, [[], [], [0,1,2,3], []]) + r = get_contraction((1, 2, 3, 4), (1, 1, 2 * 3 * 4, 1)) + self.assertEqual(r, [[], [], [0, 1, 2, 3], []]) - r = get_contraction((2,1,3,4), (1,2,3,4)) - self.assertEqual(r, [[], [0], [1, 2], [3]]) + r = get_contraction((2, 1, 3, 4), (1, 2, 3, 4)) + self.assertEqual(r, [[], [0], [1, 2], [3]]) - r = get_contraction((1,2,3,4), (2*3*4,1,1,1)) - self.assertEqual(r, [[0, 1, 2, 3], [], [], []]) + r = get_contraction((1, 2, 3, 4), (2 * 3 * 4, 1, 1, 1)) + self.assertEqual(r, [[0, 1, 2, 3], [], [], []]) - r = get_contraction((4,4,4,4), (16,1,16)) - self.assertEqual(r, [[0, 1], [], [2, 3]]) + r = get_contraction((4, 4, 4, 4), (16, 1, 16)) + self.assertEqual(r, [[0, 1], [], [2, 3]]) - r = get_contraction((1,2,3,4,1,1,1), (2,3,4)) - self.assertEqual(r, [[0, 1], [2], [3, 4, 5, 6]]) + r = get_contraction((1, 2, 3, 4, 1, 1, 1), (2, 3, 4)) + self.assertEqual(r, [[0, 1], [2], [3, 4, 5, 6]]) - r = get_contraction((1,2,3,4), (1,2,3,4,1)) - self.assertEqual(r, [[], [0, 1], [2], [3], []]) + r = get_contraction((1, 2, 3, 4), (1, 2, 3, 4, 1)) + self.assertEqual(r, [[], [0, 1], [2], [3], []]) - r = get_contraction((14,1,384,14,1,1,1,1), (1,14,384,14)) - self.assertEqual(r, [[], [0], [1,2], [3,4,5,6,7]]) + r = get_contraction((14, 1, 384, 14, 1, 1, 1, 1), (1, 14, 384, 14)) + self.assertEqual(r, [[], [0], [1, 2], [3, 4, 5, 6, 7]]) - r = get_contraction((14,1,384,1,14,1,1,1,1), (1,14,384,14)) - self.assertEqual(r, [[], [0], [1,2], [3,4,5,6,7,8]]) + r = get_contraction((14, 1, 384, 1, 14, 1, 1, 1, 1), (1, 14, 384, 14)) + self.assertEqual(r, [[], [0], [1, 2], [3, 4, 5, 6, 7, 8]]) - r = get_contraction((512, 512), (1, 1, 512, 1, 1, 1, 1, 512)) - self.assertEqual(r, [[], [], [0], [], [], [], [], [1]]) + r = get_contraction((512, 512), (1, 1, 512, 1, 1, 1, 1, 512)) + self.assertEqual(r, [[], [], [0], [], [], [], [], [1]]) - r = get_contraction((1,2,3,4), (1,2,6,2)) - self.assertEqual(r, None) + r = get_contraction((1, 2, 3, 4), (1, 2, 6, 2)) + self.assertEqual(r, None) - def test_contraction_ones(self): - r = get_contraction((1,), (1,1,1)) - self.assertEqual(r, [[], [], [0]]) + def test_contraction_ones(self): + r = get_contraction((1,), (1, 1, 1)) + self.assertEqual(r, [[], [], [0]]) - r = get_contraction((1,1), (1,1,1)) - self.assertEqual(r, [[], [], [0, 1]]) + r = get_contraction((1, 1), (1, 1, 1)) + self.assertEqual(r, [[], [], [0, 1]]) - r = get_contraction((1,1,1,1), (1,)) - self.assertEqual(r, [[0,1,2,3]]) + r = get_contraction((1, 1, 1, 1), (1,)) + self.assertEqual(r, [[0, 1, 2, 3]]) - r = get_contraction((1,1,1,1), (1,1)) - self.assertEqual(r, [[], [0,1,2,3]]) + r = get_contraction((1, 1, 1, 1), (1, 1)) + self.assertEqual(r, [[], [0, 1, 2, 3]]) - r = get_contraction((1,1,1,1), (1,1,1)) - self.assertEqual(r, [[], [], [0,1,2,3]]) + r = get_contraction((1, 1, 1, 1), (1, 1, 1)) + self.assertEqual(r, [[], [], [0, 1, 2, 3]]) + + r = get_contraction((1, 1, 1, 1), (1, 1, 1, 1)) + self.assertEqual(r, [[], [], [], [0, 1, 2, 3]]) - r = get_contraction((1,1,1,1), (1,1,1,1)) - self.assertEqual(r, [[], [], [], [0,1,2,3]]) class TestShapeTrackerSize(unittest.TestCase): - def test_simple_size(self): - st = ShapeTracker.from_shape((100, 100)) - self.assertEqual(st.size(), 100*100) + def test_simple_size(self): + st = ShapeTracker.from_shape((100, 100)) + self.assertEqual(st.size(), 100 * 100) - def test_expand_size(self): - st = ShapeTracker.from_shape((100, 100)) - st = st.reshape((100, 100, 1)) - st = st.expand((100, 100, 100)) - self.assertEqual(st.size(), 100*100) + def test_expand_size(self): + st = ShapeTracker.from_shape((100, 100)) + st = st.reshape((100, 100, 1)) + st = st.expand((100, 100, 100)) + self.assertEqual(st.size(), 100 * 100) - def test_expand_size_flatten(self): - st = ShapeTracker.from_shape((100, 100)) - st = st.reshape((100, 100, 1)) - st = st.expand((100, 100, 100)) - st = st.reshape((100*100*100,)) - self.assertEqual(st.size(), 100*100) + def test_expand_size_flatten(self): + st = ShapeTracker.from_shape((100, 100)) + st = st.reshape((100, 100, 1)) + st = st.expand((100, 100, 100)) + st = st.reshape((100 * 100 * 100,)) + self.assertEqual(st.size(), 100 * 100) - def test_shrink_size_axis_0(self): - st = ShapeTracker.from_shape((100, 100)) - st = st.shrink(((0, 50), (0, 100))) - self.assertEqual(st.size(), 50*100) + def test_shrink_size_axis_0(self): + st = ShapeTracker.from_shape((100, 100)) + st = st.shrink(((0, 50), (0, 100))) + self.assertEqual(st.size(), 50 * 100) - def test_shrink_size_axis_0_variable(self): - st = ShapeTracker.from_shape((100, 100)) - st = st.shrink(((0, Variable("a", 0, 50)), (0, 100))) - self.assertEqual(st.size(), 50*100) + def test_shrink_size_axis_0_variable(self): + st = ShapeTracker.from_shape((100, 100)) + st = st.shrink(((0, Variable("a", 0, 50)), (0, 100))) + self.assertEqual(st.size(), 50 * 100) - def test_shrink_size_axis_1(self): - st = ShapeTracker.from_shape((100, 100)) - st = st.shrink(((0, 100), (0, 50))) - self.assertEqual(st.size(), 9950) # careful here + def test_shrink_size_axis_1(self): + st = ShapeTracker.from_shape((100, 100)) + st = st.shrink(((0, 100), (0, 50))) + self.assertEqual(st.size(), 9950) # careful here -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit/test_shm_tensor.py b/test/unit/test_shm_tensor.py index 6c9ab2486..a1d4d96ec 100644 --- a/test/unit/test_shm_tensor.py +++ b/test/unit/test_shm_tensor.py @@ -4,35 +4,37 @@ from tinygrad.helpers import CI from tinygrad.tensor import Tensor, Device import numpy as np + class TestRawShmBuffer(unittest.TestCase): - def test_e2e(self): - t = Tensor.randn(2, 2, 2).realize() + def test_e2e(self): + t = Tensor.randn(2, 2, 2).realize() - # copy to shm - shm_name = (s := shared_memory.SharedMemory(create=True, size=t.nbytes())).name - s.close() - t_shm = t.to(f"disk:shm:{shm_name}").realize() + # copy to shm + shm_name = (s := shared_memory.SharedMemory(create=True, size=t.nbytes())).name + s.close() + t_shm = t.to(f"disk:shm:{shm_name}").realize() - # copy from shm - t2 = t_shm.to(Device.DEFAULT).realize() + # copy from shm + t2 = t_shm.to(Device.DEFAULT).realize() - assert np.allclose(t.numpy(), t2.numpy()) - s.unlink() + assert np.allclose(t.numpy(), t2.numpy()) + s.unlink() - @unittest.skipIf(CI, "CI doesn't like big shared memory") - def test_e2e_big(self): - t = Tensor.randn(2048, 2048, 8).realize() + @unittest.skipIf(CI, "CI doesn't like big shared memory") + def test_e2e_big(self): + t = Tensor.randn(2048, 2048, 8).realize() - # copy to shm - shm_name = (s := shared_memory.SharedMemory(create=True, size=t.nbytes())).name - s.close() - t_shm = t.to(f"disk:shm:{shm_name}").realize() + # copy to shm + shm_name = (s := shared_memory.SharedMemory(create=True, size=t.nbytes())).name + s.close() + t_shm = t.to(f"disk:shm:{shm_name}").realize() - # copy from shm - t2 = t_shm.to(Device.DEFAULT).realize() + # copy from shm + t2 = t_shm.to(Device.DEFAULT).realize() + + assert np.allclose(t.numpy(), t2.numpy()) + s.unlink() - assert np.allclose(t.numpy(), t2.numpy()) - s.unlink() if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 664a3c958..08c818a0c 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -1,481 +1,660 @@ #!/usr/bin/env python import unittest -from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, sym_render, sym_infer, create_rednode +from tinygrad.shape.symbolic import ( + MulNode, + SumNode, + Variable, + NumNode, + LtNode, + ModNode, + sym_render, + sym_infer, + create_rednode, +) + class TestSymbolic(unittest.TestCase): - def helper_test_variable(self, v, n, m, s): - self.assertEqual(v.render(), s) - self.assertEqual(v.min, n) - self.assertEqual(v.max, m) + def helper_test_variable(self, v, n, m, s): + self.assertEqual(v.render(), s) + self.assertEqual(v.min, n) + self.assertEqual(v.max, m) - def test_ge(self): - self.helper_test_variable(Variable("a", 3, 8)>=77, 0, 0, "0") - self.helper_test_variable(Variable("a", 3, 8)>=9, 0, 0, "0") - self.helper_test_variable(Variable("a", 3, 8)>=8, 0, 1, "((a*-1)<-7)") - self.helper_test_variable(Variable("a", 3, 8)>=4, 0, 1, "((a*-1)<-3)") - self.helper_test_variable(Variable("a", 3, 8)>=3, 1, 1, "1") - self.helper_test_variable(Variable("a", 3, 8)>=2, 1, 1, "1") + def test_ge(self): + self.helper_test_variable(Variable("a", 3, 8) >= 77, 0, 0, "0") + self.helper_test_variable(Variable("a", 3, 8) >= 9, 0, 0, "0") + self.helper_test_variable(Variable("a", 3, 8) >= 8, 0, 1, "((a*-1)<-7)") + self.helper_test_variable(Variable("a", 3, 8) >= 4, 0, 1, "((a*-1)<-3)") + self.helper_test_variable(Variable("a", 3, 8) >= 3, 1, 1, "1") + self.helper_test_variable(Variable("a", 3, 8) >= 2, 1, 1, "1") - def test_lt(self): - self.helper_test_variable(Variable("a", 3, 8)<77, 1, 1, "1") - self.helper_test_variable(Variable("a", 3, 8)<9, 1, 1, "1") - self.helper_test_variable(Variable("a", 3, 8)<8, 0, 1, "(a<8)") - self.helper_test_variable(Variable("a", 3, 8)<4, 0, 1, "(a<4)") - self.helper_test_variable(Variable("a", 3, 8)<3, 0, 0, "0") - self.helper_test_variable(Variable("a", 3, 8)<2, 0, 0, "0") + def test_lt(self): + self.helper_test_variable(Variable("a", 3, 8) < 77, 1, 1, "1") + self.helper_test_variable(Variable("a", 3, 8) < 9, 1, 1, "1") + self.helper_test_variable(Variable("a", 3, 8) < 8, 0, 1, "(a<8)") + self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)") + self.helper_test_variable(Variable("a", 3, 8) < 3, 0, 0, "0") + self.helper_test_variable(Variable("a", 3, 8) < 2, 0, 0, "0") - def test_ge_divides(self): - expr = (Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512 - self.helper_test_variable(expr, 0, 1, "(idx<128)") + def test_ge_divides(self): + expr = (Variable("idx", 0, 511) * 4 + Variable("FLOAT4_INDEX", 0, 3)) < 512 + self.helper_test_variable(expr, 0, 1, "(idx<128)") - def test_ge_divides_and(self): - expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, - (Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512]) - self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))") - expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, - (Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512]) - self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and ((idx1//4)<32))") + def test_ge_divides_and(self): + expr = Variable.ands( + [ + (Variable("idx1", 0, 511) * 4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, + (Variable("idx2", 0, 511) * 4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, + ] + ) + self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))") + expr = Variable.ands( + [ + (Variable("idx1", 0, 511) * 4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, + (Variable("idx2", 0, 511) * 4 + Variable("FLOAT8_INDEX", 0, 7)) < 512, + ] + ) + self.helper_test_variable( + expr // 4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and ((idx1//4)<32))" + ) - def test_lt_factors(self): - expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512]) - self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)") + def test_lt_factors(self): + expr = Variable.ands( + [(Variable("idx1", 0, 511) * 4 + Variable("FLOAT4_INDEX", 0, 256)) < 512] + ) + self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)") - def test_div_becomes_num(self): - assert isinstance(Variable("a", 2, 3)//2, NumNode) + def test_div_becomes_num(self): + assert isinstance(Variable("a", 2, 3) // 2, NumNode) - def test_var_becomes_num(self): - assert isinstance(Variable("a", 2, 2), NumNode) + def test_var_becomes_num(self): + assert isinstance(Variable("a", 2, 2), NumNode) - def test_equality(self): - idx1 = Variable("idx1", 0, 3) - idx2 = Variable("idx2", 0, 3) - assert idx1 == idx1 - assert idx1 != idx2 - assert idx1*4 == idx1*4 - assert idx1*4 != idx1*3 - assert idx1*4 != idx1+4 - assert idx1*4 != idx2*4 - assert idx1+idx2 == idx1+idx2 - assert idx1+idx2 == idx2+idx1 - assert idx1+idx2 != idx2 - assert idx1*idx2 == idx2*idx1 + def test_equality(self): + idx1 = Variable("idx1", 0, 3) + idx2 = Variable("idx2", 0, 3) + assert idx1 == idx1 + assert idx1 != idx2 + assert idx1 * 4 == idx1 * 4 + assert idx1 * 4 != idx1 * 3 + assert idx1 * 4 != idx1 + 4 + assert idx1 * 4 != idx2 * 4 + assert idx1 + idx2 == idx1 + idx2 + assert idx1 + idx2 == idx2 + idx1 + assert idx1 + idx2 != idx2 + assert idx1 * idx2 == idx2 * idx1 - def test_factorize(self): - a = Variable("a", 0, 8) - self.helper_test_variable(a*2+a*3, 0, 8*5, "(a*5)") + def test_factorize(self): + a = Variable("a", 0, 8) + self.helper_test_variable(a * 2 + a * 3, 0, 8 * 5, "(a*5)") - def test_factorize_no_mul(self): - a = Variable("a", 0, 8) - self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)") + def test_factorize_no_mul(self): + a = Variable("a", 0, 8) + self.helper_test_variable(a + a * 3, 0, 8 * 4, "(a*4)") - def test_neg(self): - self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)") + def test_neg(self): + self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)") - def test_add_1(self): - self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(1+a)") + def test_add_1(self): + self.helper_test_variable(Variable("a", 0, 8) + 1, 1, 9, "(1+a)") - def test_add_num_1(self): - self.helper_test_variable(Variable("a", 0, 8)+NumNode(1), 1, 9, "(1+a)") + def test_add_num_1(self): + self.helper_test_variable(Variable("a", 0, 8) + NumNode(1), 1, 9, "(1+a)") - def test_sub_1(self): - self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(-1+a)") + def test_sub_1(self): + self.helper_test_variable(Variable("a", 0, 8) - 1, -1, 7, "(-1+a)") - def test_sub_num_1(self): - self.helper_test_variable(Variable("a", 0, 8)-NumNode(1), -1, 7, "(-1+a)") + def test_sub_num_1(self): + self.helper_test_variable(Variable("a", 0, 8) - NumNode(1), -1, 7, "(-1+a)") - def test_mul_0(self): - self.helper_test_variable(Variable("a", 0, 8)*0, 0, 0, "0") + def test_mul_0(self): + self.helper_test_variable(Variable("a", 0, 8) * 0, 0, 0, "0") - def test_mul_1(self): - self.helper_test_variable(Variable("a", 0, 8)*1, 0, 8, "a") + def test_mul_1(self): + self.helper_test_variable(Variable("a", 0, 8) * 1, 0, 8, "a") - def test_mul_neg_1(self): - self.helper_test_variable((Variable("a", 0, 2)*-1)//3, -1, 0, "((((a*-1)+3)//3)+-1)") + def test_mul_neg_1(self): + self.helper_test_variable( + (Variable("a", 0, 2) * -1) // 3, -1, 0, "((((a*-1)+3)//3)+-1)" + ) - def test_mul_2(self): - self.helper_test_variable(Variable("a", 0, 8)*2, 0, 16, "(a*2)") + def test_mul_2(self): + self.helper_test_variable(Variable("a", 0, 8) * 2, 0, 16, "(a*2)") - def test_div_1(self): - self.helper_test_variable(Variable("a", 0, 8)//1, 0, 8, "a") + def test_div_1(self): + self.helper_test_variable(Variable("a", 0, 8) // 1, 0, 8, "a") - def test_mod_1(self): - self.helper_test_variable(Variable("a", 0, 8)%1, 0, 0, "0") + def test_mod_1(self): + self.helper_test_variable(Variable("a", 0, 8) % 1, 0, 0, "0") - def test_add_min_max(self): - self.helper_test_variable(Variable("a", 0, 8) * 2 + 12, 12, 16+12, "((a*2)+12)") + def test_add_min_max(self): + self.helper_test_variable( + Variable("a", 0, 8) * 2 + 12, 12, 16 + 12, "((a*2)+12)" + ) - def test_div_min_max(self): - self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)") + def test_div_min_max(self): + self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)") - def test_div_neg_min_max(self): - self.helper_test_variable(Variable("a", 0, 7) // -2, -3, 0, "((a//2)*-1)") + def test_div_neg_min_max(self): + self.helper_test_variable(Variable("a", 0, 7) // -2, -3, 0, "((a//2)*-1)") - def test_sum_div_min_max(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)") + def test_sum_div_min_max(self): + self.helper_test_variable( + Variable.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, + 0, + 5, + "((a+b)//2)", + ) - def test_sum_div_factor(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))") + def test_sum_div_factor(self): + self.helper_test_variable( + Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3) * 4]) // 2, + 0, + 20, + "((a*2)+(b*2))", + ) - def test_sum_div_some_factor(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))") + def test_sum_div_some_factor(self): + self.helper_test_variable( + Variable.sum([Variable("a", 0, 7) * 5, Variable("b", 0, 3) * 4]) // 2, + 0, + 23, + "(((a*5)//2)+(b*2))", + ) - def test_sum_div_some_partial_factor(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)") - self.helper_test_variable(Variable.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)") + def test_sum_div_some_partial_factor(self): + self.helper_test_variable( + Variable.sum([Variable("a", 0, 7) * 6, Variable("b", 0, 7) * 6]) // 16, + 0, + 5, + "(((a*3)+(b*3))//8)", + ) + self.helper_test_variable( + Variable.sum( + [NumNode(16), Variable("a", 0, 7) * 6, Variable("b", 0, 7) * 6] + ) + // 16, + 1, + 6, + "((((a*3)+(b*3))//8)+1)", + ) - def test_sum_div_no_factor(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)") + def test_sum_div_no_factor(self): + self.helper_test_variable( + Variable.sum([Variable("a", 0, 7) * 5, Variable("b", 0, 3) * 5]) // 2, + 0, + 25, + "(((a*5)+(b*5))//2)", + ) - def test_mod_factor(self): - # NOTE: even though the mod max is 50, it can't know this without knowing about the mul - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)") + def test_mod_factor(self): + # NOTE: even though the mod max is 50, it can't know this without knowing about the mul + self.helper_test_variable( + Variable.sum([Variable("a", 0, 7) * 100, Variable("b", 0, 3) * 50]) % 100, + 0, + 99, + "((b*50)%100)", + ) - def test_mod_to_sub(self): - # This is mod reduction - self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, (Variable("a",1,2)-1).render()) + def test_mod_to_sub(self): + # This is mod reduction + self.helper_test_variable( + (1 + Variable("a", 1, 2)) % 2, 0, 1, (Variable("a", 1, 2) - 1).render() + ) - def test_sum_div_const(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a") + def test_sum_div_const(self): + self.helper_test_variable( + Variable.sum([Variable("a", 0, 7) * 4, NumNode(3)]) // 4, 0, 7, "a" + ) - def test_sum_div_const_big(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)") + def test_sum_div_const_big(self): + self.helper_test_variable( + Variable.sum([Variable("a", 0, 7) * 4, NumNode(3)]) // 16, 0, 1, "(a//4)" + ) - def test_sum_lt_fold(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)") - self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1, "(((a*4)+b)<16)") + def test_sum_lt_fold(self): + self.helper_test_variable( + Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, + 0, + 1, + "(a<4)", + ) + self.helper_test_variable( + Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, + 0, + 1, + "(((a*4)+b)<16)", + ) - def test_mod_mul(self): - self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a") + def test_mod_mul(self): + self.helper_test_variable((Variable("a", 0, 5) * 10) % 9, 0, 5, "a") - def test_mod_mod(self): - self.helper_test_variable((Variable("a", 0, 31)%12)%4, 0, 3, "(a%4)") - self.helper_test_variable(((4*Variable("a", 0, 31)) % 12) % 4, 0, 0, "0") - self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)") + def test_mod_mod(self): + self.helper_test_variable((Variable("a", 0, 31) % 12) % 4, 0, 3, "(a%4)") + self.helper_test_variable(((4 * Variable("a", 0, 31)) % 12) % 4, 0, 0, "0") + self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)") - def test_mul_mul(self): - self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)") + def test_mul_mul(self): + self.helper_test_variable( + (Variable("a", 0, 5) * 10) * 9, 0, 5 * 10 * 9, "(a*90)" + ) - def test_mul_lt(self): - self.helper_test_variable((Variable("a", 0, 5)*4)<13, 0, 1, "(a<4)") - self.helper_test_variable((Variable("a", 0, 5)*4)<16, 0, 1, "(a<4)") - self.helper_test_variable((Variable("a", 0, 5)*4)>11, 0, 1, "((a*-1)<-2)") - self.helper_test_variable((Variable("a", 0, 5)*4)>12, 0, 1, "((a*-1)<-3)") + def test_mul_lt(self): + self.helper_test_variable((Variable("a", 0, 5) * 4) < 13, 0, 1, "(a<4)") + self.helper_test_variable((Variable("a", 0, 5) * 4) < 16, 0, 1, "(a<4)") + self.helper_test_variable((Variable("a", 0, 5) * 4) > 11, 0, 1, "((a*-1)<-2)") + self.helper_test_variable((Variable("a", 0, 5) * 4) > 12, 0, 1, "((a*-1)<-3)") - def test_div_div(self): - self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)") + def test_div_div(self): + self.helper_test_variable((Variable("a", 0, 1800) // 10) // 9, 0, 20, "(a//90)") - def test_distribute_mul(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))") + def test_distribute_mul(self): + self.helper_test_variable( + Variable.sum([Variable("a", 0, 3), Variable("b", 0, 5)]) * 3, + 0, + 24, + "((a*3)+(b*3))", + ) - def test_mod_mul_sum(self): - self.helper_test_variable(Variable.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)") + def test_mod_mul_sum(self): + self.helper_test_variable( + Variable.sum([Variable("b", 0, 2), Variable("a", 0, 5) * 10]) % 9, + 0, + 7, + "(a+b)", + ) - def test_sum_0(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)]), 0, 7, "a") + def test_sum_0(self): + self.helper_test_variable(Variable.sum([Variable("a", 0, 7)]), 0, 7, "a") - def test_mod_remove(self): - self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a") + def test_mod_remove(self): + self.helper_test_variable(Variable("a", 0, 6) % 100, 0, 6, "a") - def test_big_mod(self): - # NOTE: we no longer support negative variables - #self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)") - #self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)") - #self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)") - self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)") - #self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)") + def test_big_mod(self): + # NOTE: we no longer support negative variables + # self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)") + # self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)") + # self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)") + self.helper_test_variable(Variable("a", 0, 20) % 10, 0, 9, "(a%10)") + # self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)") - def test_gt_remove(self): - self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "0") + def test_gt_remove(self): + self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "0") - def test_lt_remove(self): - self.helper_test_variable(Variable("a", 0, 6) < -3, 0, 0, "0") - self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)") - self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "1") + def test_lt_remove(self): + self.helper_test_variable(Variable("a", 0, 6) < -3, 0, 0, "0") + self.helper_test_variable(Variable("a", 0, 6) < 3, 0, 1, "(a<3)") + self.helper_test_variable(Variable("a", 0, 6) < 8, 1, 1, "1") - def test_lt_sum_remove(self): - self.helper_test_variable((Variable("a", 0, 6) + 2) < 3, 0, 1, "(a<1)") + def test_lt_sum_remove(self): + self.helper_test_variable((Variable("a", 0, 6) + 2) < 3, 0, 1, "(a<1)") - def test_and_fold(self): - self.helper_test_variable(Variable.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0") + def test_and_fold(self): + self.helper_test_variable( + Variable.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0" + ) - def test_and_remove(self): - self.helper_test_variable(Variable.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a") + def test_and_remove(self): + self.helper_test_variable( + Variable.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a" + ) - def test_mod_factor_negative(self): - self.helper_test_variable(Variable.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)") - self.helper_test_variable(Variable.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)") + def test_mod_factor_negative(self): + self.helper_test_variable( + Variable.sum( + [NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10) * 28] + ) + % 28, + 0, + 27, + "((27+a)%28)", + ) + self.helper_test_variable( + Variable.sum( + [NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10) * 28] + ) + % 28, + 0, + 27, + "((27+a)%28)", + ) - def test_sum_combine_num(self): - self.helper_test_variable(Variable.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(6+a)") + def test_sum_combine_num(self): + self.helper_test_variable( + Variable.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), + 6, + 16, + "(6+a)", + ) - def test_sum_num_hoisted_and_factors_cancel_out(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1") + def test_sum_num_hoisted_and_factors_cancel_out(self): + self.helper_test_variable( + Variable.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), + 1, + 1, + "1", + ) - def test_div_factor(self): - self.helper_test_variable(Variable.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)") + def test_div_factor(self): + self.helper_test_variable( + Variable.sum( + [NumNode(-40), Variable("a", 0, 10) * 2, Variable("b", 0, 10) * 40] + ) + // 40, + -1, + 9, + "(-1+b)", + ) - def test_mul_div(self): - self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a") + def test_mul_div(self): + self.helper_test_variable((Variable("a", 0, 10) * 4) // 4, 0, 10, "a") - def test_mul_div_factor_mul(self): - self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)") + def test_mul_div_factor_mul(self): + self.helper_test_variable((Variable("a", 0, 10) * 8) // 4, 0, 20, "(a*2)") - def test_mul_div_factor_div(self): - self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)") + def test_mul_div_factor_div(self): + self.helper_test_variable((Variable("a", 0, 10) * 4) // 8, 0, 5, "(a//2)") - def test_div_remove(self): - self.helper_test_variable(Variable.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0") + def test_div_remove(self): + self.helper_test_variable( + Variable.sum([Variable("idx0", 0, 127) * 4, Variable("idx2", 0, 3)]) // 4, + 0, + 127, + "idx0", + ) - def test_div_numerator_negative(self): - self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)") + def test_div_numerator_negative(self): + self.helper_test_variable( + (Variable("idx", 0, 9) * -10) // 11, -9, 0, "((((idx*-10)+99)//11)+-9)" + ) + + def test_div_into_mod(self): + self.helper_test_variable( + (Variable("idx", 0, 16) * 4) % 8 // 4, 0, 1, "(idx%2)" + ) - def test_div_into_mod(self): - self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)") class TestSymbolicNumeric(unittest.TestCase): - def helper_test_numeric(self, f): - # TODO: why are the negative tests broken? (even if we did support negative variables) - #MIN, MAX = -10, 10 - MIN, MAX = 0, 10 - # one number - for i in range(MIN, MAX): - v = f(NumNode(i)) - #print(i, f(i), v.min, v.max) - self.assertEqual(v.min, v.max) - self.assertEqual(v.min, f(i)) - for kmin in range(MIN, MAX): - for kmax in range(MIN, MAX): - if kmin > kmax: continue - v = f(Variable("tmp", kmin, kmax)) - values = [f(rv) for rv in range(kmin, kmax+1)] - # the min and max may not be exact - self.assertLessEqual(v.min, min(values)) - self.assertGreaterEqual(v.max, max(values)) + def helper_test_numeric(self, f): + # TODO: why are the negative tests broken? (even if we did support negative variables) + # MIN, MAX = -10, 10 + MIN, MAX = 0, 10 + # one number + for i in range(MIN, MAX): + v = f(NumNode(i)) + # print(i, f(i), v.min, v.max) + self.assertEqual(v.min, v.max) + self.assertEqual(v.min, f(i)) + for kmin in range(MIN, MAX): + for kmax in range(MIN, MAX): + if kmin > kmax: + continue + v = f(Variable("tmp", kmin, kmax)) + values = [f(rv) for rv in range(kmin, kmax + 1)] + # the min and max may not be exact + self.assertLessEqual(v.min, min(values)) + self.assertGreaterEqual(v.max, max(values)) + + def test_mod_4(self): + self.helper_test_numeric(lambda x: (x % 4)) + + def test_div_4(self): + self.helper_test_numeric(lambda x: (x // 4)) + + def test_plus_1_div_2(self): + self.helper_test_numeric(lambda x: (x + 1) // 2) + + def test_plus_1_mod_2(self): + self.helper_test_numeric(lambda x: (x + 1) % 2) + + def test_times_2(self): + self.helper_test_numeric(lambda x: x * 2) + + def test_times_2_plus_3(self): + self.helper_test_numeric(lambda x: x * 2 + 3) + + def test_times_2_plus_3_mod_4(self): + self.helper_test_numeric(lambda x: (x * 2 + 3) % 4) + + def test_times_2_plus_3_div_4(self): + self.helper_test_numeric(lambda x: (x * 2 + 3) // 4) + + def test_times_2_plus_3_div_4_mod_4(self): + self.helper_test_numeric(lambda x: ((x * 2 + 3) // 4) % 4) - def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4)) - def test_div_4(self): self.helper_test_numeric(lambda x: (x//4)) - def test_plus_1_div_2(self): self.helper_test_numeric(lambda x: (x+1)//2) - def test_plus_1_mod_2(self): self.helper_test_numeric(lambda x: (x+1)%2) - def test_times_2(self): self.helper_test_numeric(lambda x: x*2) - def test_times_2_plus_3(self): self.helper_test_numeric(lambda x: x*2 + 3) - def test_times_2_plus_3_mod_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)%4) - def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4) - def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4) class TestSymbolicVars(unittest.TestCase): - def test_simple(self): - z = NumNode(0) - a = Variable("a", 0, 10) - b = Variable("b", 0, 10) - c = Variable("c", 0, 10) - assert z.vars() == z.vars() == set() - assert a.vars() == a.vars() == {a} - m = MulNode(a, 3) - assert m.vars() == {a} - s = SumNode([a, b, c]) - assert s.vars() == {a, b, c} + def test_simple(self): + z = NumNode(0) + a = Variable("a", 0, 10) + b = Variable("b", 0, 10) + c = Variable("c", 0, 10) + assert z.vars() == z.vars() == set() + assert a.vars() == a.vars() == {a} + m = MulNode(a, 3) + assert m.vars() == {a} + s = SumNode([a, b, c]) + assert s.vars() == {a, b, c} - def test_compound(self): - a = Variable("a", 0, 10) - b = Variable("b", 0, 10) - c = Variable("c", 0, 10) - assert (a + b * c).vars() == {a, b, c} - assert (a % 3 + b // 5).vars() == {a, b} - assert (a + b + c - a).vars() == {b, c} + def test_compound(self): + a = Variable("a", 0, 10) + b = Variable("b", 0, 10) + c = Variable("c", 0, 10) + assert (a + b * c).vars() == {a, b, c} + assert (a % 3 + b // 5).vars() == {a, b} + assert (a + b + c - a).vars() == {b, c} + + def test_dedup(self): + a = Variable("a", 0, 10) + assert (a * a).vars() == {a} + assert (a // 4 + a // 6).vars() == {a} - def test_dedup(self): - a = Variable("a", 0, 10) - assert (a * a).vars() == {a} - assert (a//4 + a//6).vars() == {a} class TestSymbolicMinMax(unittest.TestCase): - def test_min_max_known(self): - a = Variable("a", 1, 8) - assert max(1, a) == max(a, 1) == a - assert min(1, a) == min(a, 1) == 1 + def test_min_max_known(self): + a = Variable("a", 1, 8) + assert max(1, a) == max(a, 1) == a + assert min(1, a) == min(a, 1) == 1 + class TestSymRender(unittest.TestCase): - def test_sym_render(self): - a = Variable("a", 1, 8) - b = Variable("b", 1, 10) - assert sym_render(a) == "a" - assert sym_render(1) == "1" - assert sym_render(a+1) == "(1+a)" - assert sym_render(a*b) == "(a*b)" + def test_sym_render(self): + a = Variable("a", 1, 8) + b = Variable("b", 1, 10) + assert sym_render(a) == "a" + assert sym_render(1) == "1" + assert sym_render(a + 1) == "(1+a)" + assert sym_render(a * b) == "(a*b)" + class TestSymInfer(unittest.TestCase): - def test_sym_infer(self): - a = Variable("a", 0, 10) - b = Variable("b", 0, 10) - c = Variable("c", 0, 10) - var_vals = {a: 2, b: 3, c: 4} - assert sym_infer(5, var_vals) == 5 - assert sym_infer(a, var_vals) == 2 - assert sym_infer(b, var_vals) == 3 - assert sym_infer(a+b, var_vals) == 5 - assert sym_infer(a-b, var_vals) == -1 - assert sym_infer(a+b+c, var_vals) == 9 - assert sym_infer(a*b, var_vals) == 6 - assert sym_infer(a*b+c, var_vals) == 10 + def test_sym_infer(self): + a = Variable("a", 0, 10) + b = Variable("b", 0, 10) + c = Variable("c", 0, 10) + var_vals = {a: 2, b: 3, c: 4} + assert sym_infer(5, var_vals) == 5 + assert sym_infer(a, var_vals) == 2 + assert sym_infer(b, var_vals) == 3 + assert sym_infer(a + b, var_vals) == 5 + assert sym_infer(a - b, var_vals) == -1 + assert sym_infer(a + b + c, var_vals) == 9 + assert sym_infer(a * b, var_vals) == 6 + assert sym_infer(a * b + c, var_vals) == 10 + class TestSymbolicSymbolicOps(unittest.TestCase): - def test_node_divmod_node(self): - i = Variable("i", 1, 10) - idx0 = Variable("idx0", 0, i*3-1) - assert NumNode(0) // (Variable("i", 1, 10)*128) == 0 - assert NumNode(0) % (Variable("i", 1, 10)*128) == 0 - assert NumNode(127) // (Variable("i", 1, 10)*128) == 0 - assert NumNode(127) % (Variable("i", 1, 10)*128) == 127 - assert 127 // (Variable("i", 1, 10)*128) == 0 - assert 127 % (Variable("i", 1, 10)*128) == 127 - assert NumNode(128) // (Variable("i", 1, 10)*128 + 128) == 0 - assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128 - assert 128 // (Variable("i", 1, 10)*128 + 128) == 0 - assert 128 % (Variable("i", 1, 10)*128 + 128) == 128 - assert 0 // (Variable("i", 1, 10)*128) == 0 - assert 0 % (Variable("i", 1, 10)*128) == 0 - assert idx0 // (i*3) == 0 - assert idx0 % (i*3) == idx0 - assert i // i == 1 - assert i % i == 0 - assert 128 // NumNode(4) == 32 - assert 128 % NumNode(4) == 0 - assert NumNode(128) // NumNode(4) == 32 - assert NumNode(128) % NumNode(4) == 0 + def test_node_divmod_node(self): + i = Variable("i", 1, 10) + idx0 = Variable("idx0", 0, i * 3 - 1) + assert NumNode(0) // (Variable("i", 1, 10) * 128) == 0 + assert NumNode(0) % (Variable("i", 1, 10) * 128) == 0 + assert NumNode(127) // (Variable("i", 1, 10) * 128) == 0 + assert NumNode(127) % (Variable("i", 1, 10) * 128) == 127 + assert 127 // (Variable("i", 1, 10) * 128) == 0 + assert 127 % (Variable("i", 1, 10) * 128) == 127 + assert NumNode(128) // (Variable("i", 1, 10) * 128 + 128) == 0 + assert NumNode(128) % (Variable("i", 1, 10) * 128 + 128) == 128 + assert 128 // (Variable("i", 1, 10) * 128 + 128) == 0 + assert 128 % (Variable("i", 1, 10) * 128 + 128) == 128 + assert 0 // (Variable("i", 1, 10) * 128) == 0 + assert 0 % (Variable("i", 1, 10) * 128) == 0 + assert idx0 // (i * 3) == 0 + assert idx0 % (i * 3) == idx0 + assert i // i == 1 + assert i % i == 0 + assert 128 // NumNode(4) == 32 + assert 128 % NumNode(4) == 0 + assert NumNode(128) // NumNode(4) == 32 + assert NumNode(128) % NumNode(4) == 0 - def test_mulnode_divmod_node(self): - i = Variable("i", 1, 10) - idx0 = Variable("idx0", 0, 31) - assert (idx0*(i*4+4)) // (i+1) == (idx0*4) - assert (idx0*(i*4+4)) % (i+1) == 0 - assert (idx0*i) % i == 0 + def test_mulnode_divmod_node(self): + i = Variable("i", 1, 10) + idx0 = Variable("idx0", 0, 31) + assert (idx0 * (i * 4 + 4)) // (i + 1) == (idx0 * 4) + assert (idx0 * (i * 4 + 4)) % (i + 1) == 0 + assert (idx0 * i) % i == 0 - def test_sumnode_divmod_sumnode(self): - i = Variable("i", 1, 10) - idx0 = Variable("idx0", 0, 7) - idx1 = Variable("idx1", 0, 3) - idx2 = Variable("idx2", 0, i) - assert (idx0*(i*4+4)+idx1*(i+1)+idx2) // (i+1) == idx0*4+idx1 - assert (idx0*(i*4+4)+idx1*(i+1)+idx2) % (i+1) == idx2 - assert (i+1) // (i*128+128) == 0 - assert (i+1) % (i*128+128) == (i+1) - assert (i+1+idx2) // (i+1) == 1 - assert (i+1+idx2) % (i+1) == idx2 - assert (idx0*(i*4+4)+i+1+idx2) // (i+1) == idx0*4+1 - assert (idx0*(i*4+4)+i+1+idx2) % (i+1) == idx2 - assert (i*128+128)*2 // (i*128+128) == 2 - assert (i*128+128)*2 % (i*128+128) == 0 + def test_sumnode_divmod_sumnode(self): + i = Variable("i", 1, 10) + idx0 = Variable("idx0", 0, 7) + idx1 = Variable("idx1", 0, 3) + idx2 = Variable("idx2", 0, i) + assert (idx0 * (i * 4 + 4) + idx1 * (i + 1) + idx2) // ( + i + 1 + ) == idx0 * 4 + idx1 + assert (idx0 * (i * 4 + 4) + idx1 * (i + 1) + idx2) % (i + 1) == idx2 + assert (i + 1) // (i * 128 + 128) == 0 + assert (i + 1) % (i * 128 + 128) == (i + 1) + assert (i + 1 + idx2) // (i + 1) == 1 + assert (i + 1 + idx2) % (i + 1) == idx2 + assert (idx0 * (i * 4 + 4) + i + 1 + idx2) // (i + 1) == idx0 * 4 + 1 + assert (idx0 * (i * 4 + 4) + i + 1 + idx2) % (i + 1) == idx2 + assert (i * 128 + 128) * 2 // (i * 128 + 128) == 2 + assert (i * 128 + 128) * 2 % (i * 128 + 128) == 0 - def test_sumnode_divmod_sumnode_complex(self): - i = Variable("i", 1, 1024) - gidx0 = Variable("gidx0", 0, i) - lidx1 = Variable("lidx1", 0, 7) - ridx2 = Variable("ridx1", 0, 31) - assert ((i*128+128)*2 + gidx0*128 + lidx1*(i*512+512) + ridx2*4) // (i*128+128) == 2 + lidx1*4 - assert ((i*128+128)*2 + gidx0*128 + lidx1*(i*512+512) + ridx2*4) % (i*128+128) == gidx0*128 + ridx2*4 - assert ((gidx0*128+i*128+ridx2*4+129)) // (i*128+128) == 1 - assert ((gidx0*128+i*128+ridx2*4+129)) % (i*128+128) == gidx0*128 + ridx2*4 + 1 - assert (ridx2*(i*4+4)+1+i+gidx0) // (i*128+128) == 0 - assert (ridx2*(i*4+4)+1+i+gidx0) % (i*128+128) == (ridx2*(i*4+4)+1+i+gidx0) + def test_sumnode_divmod_sumnode_complex(self): + i = Variable("i", 1, 1024) + gidx0 = Variable("gidx0", 0, i) + lidx1 = Variable("lidx1", 0, 7) + ridx2 = Variable("ridx1", 0, 31) + assert ( + (i * 128 + 128) * 2 + gidx0 * 128 + lidx1 * (i * 512 + 512) + ridx2 * 4 + ) // (i * 128 + 128) == 2 + lidx1 * 4 + assert ( + (i * 128 + 128) * 2 + gidx0 * 128 + lidx1 * (i * 512 + 512) + ridx2 * 4 + ) % (i * 128 + 128) == gidx0 * 128 + ridx2 * 4 + assert ((gidx0 * 128 + i * 128 + ridx2 * 4 + 129)) // (i * 128 + 128) == 1 + assert ((gidx0 * 128 + i * 128 + ridx2 * 4 + 129)) % ( + i * 128 + 128 + ) == gidx0 * 128 + ridx2 * 4 + 1 + assert (ridx2 * (i * 4 + 4) + 1 + i + gidx0) // (i * 128 + 128) == 0 + assert (ridx2 * (i * 4 + 4) + 1 + i + gidx0) % (i * 128 + 128) == ( + ridx2 * (i * 4 + 4) + 1 + i + gidx0 + ) - def test_mod_node_max(self): - i = Variable("i", 1, 128) - gidx0 = Variable("gidx0", 0, i) - mod = gidx0 % 8 - assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8 - mod = gidx0 % 2 - assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2 + def test_mod_node_max(self): + i = Variable("i", 1, 128) + gidx0 = Variable("gidx0", 0, i) + mod = gidx0 % 8 + assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8 + mod = gidx0 % 2 + assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2 - gidx0 = Variable("gidx0", 0, i*8+7) - mod = gidx0 % 8 - assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8 - mod = gidx0 % 2 - assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2 + gidx0 = Variable("gidx0", 0, i * 8 + 7) + mod = gidx0 % 8 + assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 8 + mod = gidx0 % 2 + assert isinstance(mod, ModNode) and mod.a == gidx0 and mod.b == 2 - def test_node_lt_node(self): - a = Variable("a", 1, 5) - b = Variable("b", 6, 9) - c = Variable("c", 1, 10) - d = Variable("d", 5, 10) - # if the value is always the same, it folds to num - assert (a < b) == 1 - assert (b < a) == 0 - assert (d < a) == 0 - # if it remains as a LtNode, bool is always true and (min, max) == (0, 1) - assert isinstance((a < c), LtNode) and (a < c).min == 0 and (a < c).max == 1 - assert a < c - assert isinstance((a > c), LtNode) and (a > c).min == 0 and (a > c).max == 1 - # same when comparing with a constant - assert a < 3 and (a < 3).min == 0 and (a < 3).max == 1 - assert a > 3 and (a > 3).min == 0 and (a > 3).max == 1 + def test_node_lt_node(self): + a = Variable("a", 1, 5) + b = Variable("b", 6, 9) + c = Variable("c", 1, 10) + d = Variable("d", 5, 10) + # if the value is always the same, it folds to num + assert (a < b) == 1 + assert (b < a) == 0 + assert (d < a) == 0 + # if it remains as a LtNode, bool is always true and (min, max) == (0, 1) + assert isinstance((a < c), LtNode) and (a < c).min == 0 and (a < c).max == 1 + assert a < c + assert isinstance((a > c), LtNode) and (a > c).min == 0 and (a > c).max == 1 + # same when comparing with a constant + assert a < 3 and (a < 3).min == 0 and (a < 3).max == 1 + assert a > 3 and (a > 3).min == 0 and (a > 3).max == 1 - def test_sumnode_mulnode_lt(self): - a = Variable("a", 1, 2) - b = Variable("b", 1, 2) - c = Variable("c", 1, 2) - x = SumNode([MulNode(a, b), c]) - with self.assertRaises(AssertionError): - (x < 3) + def test_sumnode_mulnode_lt(self): + a = Variable("a", 1, 2) + b = Variable("b", 1, 2) + c = Variable("c", 1, 2) + x = SumNode([MulNode(a, b), c]) + with self.assertRaises(AssertionError): + (x < 3) - def test_nested_variable_mod(self): - i = Variable("i", 1, 5) - idx0 = Variable("idx0", 0, i) - with self.assertRaises(AssertionError): - assert idx0 % 2 == idx0 + def test_nested_variable_mod(self): + i = Variable("i", 1, 5) + idx0 = Variable("idx0", 0, i) + with self.assertRaises(AssertionError): + assert idx0 % 2 == idx0 - def test_num_node_mul_node(self): - a = Variable("a", 1, 5) - b = NumNode(2) * a - assert b == a * 2 - assert isinstance(b, MulNode) - b = NumNode(1) * a - assert b == a - assert isinstance(b, Variable) - b = NumNode(0) * a - assert b == 0 - assert isinstance(b, NumNode) + def test_num_node_mul_node(self): + a = Variable("a", 1, 5) + b = NumNode(2) * a + assert b == a * 2 + assert isinstance(b, MulNode) + b = NumNode(1) * a + assert b == a + assert isinstance(b, Variable) + b = NumNode(0) * a + assert b == 0 + assert isinstance(b, NumNode) - def test_num_node_expand(self): - a = NumNode(42) - assert a.expand() == [a] + def test_num_node_expand(self): + a = NumNode(42) + assert a.expand() == [a] - def test_variable_expand(self): - a = Variable("a", 5, 7) - assert a.expand() == [a] + def test_variable_expand(self): + a = Variable("a", 5, 7) + assert a.expand() == [a] - def test_variable_expand_expr_none(self): - a = Variable(None, 5, 7) - assert a.expand() == [NumNode(5), NumNode(6), NumNode(7)] + def test_variable_expand_expr_none(self): + a = Variable(None, 5, 7) + assert a.expand() == [NumNode(5), NumNode(6), NumNode(7)] - def test_mul_node_expand(self): - a = Variable(None, 5, 7) - m = MulNode(a, 3) - assert m.expand() == [NumNode(15), NumNode(18), NumNode(21)] + def test_mul_node_expand(self): + a = Variable(None, 5, 7) + m = MulNode(a, 3) + assert m.expand() == [NumNode(15), NumNode(18), NumNode(21)] - b = Variable("b", 1, 3) - n = MulNode(b, 3) - assert n.expand() == [Variable("b", 1, 3)*3] + b = Variable("b", 1, 3) + n = MulNode(b, 3) + assert n.expand() == [Variable("b", 1, 3) * 3] - def test_sum_node_expand(self): - a = Variable(None, 1, 3) - b = Variable("b", 5, 7) + def test_sum_node_expand(self): + a = Variable(None, 1, 3) + b = Variable("b", 5, 7) - s1 = create_rednode(SumNode, [a, b]) - assert s1.expand() == [Variable.sum([NumNode(i),b]) for i in range(1,4)] + s1 = create_rednode(SumNode, [a, b]) + assert s1.expand() == [Variable.sum([NumNode(i), b]) for i in range(1, 4)] - def test_multi_expand(self): - a = Variable("a", 1, 3) - b = Variable("b", 14, 17) - s1 = create_rednode(SumNode, [a, b]) - # expand increments earlier variables faster than later variables (as specified in the argument) - # this behavior was just copied from before, no idea why this should be true - assert s1.expand((a, b)) == [NumNode(x + y) for x in range(b.min, b.max + 1) for y in range(a.min, a.max + 1)] + def test_multi_expand(self): + a = Variable("a", 1, 3) + b = Variable("b", 14, 17) + s1 = create_rednode(SumNode, [a, b]) + # expand increments earlier variables faster than later variables (as specified in the argument) + # this behavior was just copied from before, no idea why this should be true + assert s1.expand((a, b)) == [ + NumNode(x + y) + for x in range(b.min, b.max + 1) + for y in range(a.min, a.max + 1) + ] - def test_substitute(self): - a = Variable(None, 1, 3) - b = a + 1 - c = b.substitute({a: NumNode(1)}) - assert c == NumNode(2) + def test_substitute(self): + a = Variable(None, 1, 3) + b = a + 1 + c = b.substitute({a: NumNode(1)}) + assert c == NumNode(2) -if __name__ == '__main__': - unittest.main() + +if __name__ == "__main__": + unittest.main() diff --git a/tinygrad/__init__.py b/tinygrad/__init__.py index 0913ae165..f067f83a8 100644 --- a/tinygrad/__init__.py +++ b/tinygrad/__init__.py @@ -1,8 +1,8 @@ -from tinygrad.tensor import Tensor # noqa: F401 -from tinygrad.jit import TinyJit # noqa: F401 +from tinygrad.tensor import Tensor # noqa: F401 +from tinygrad.jit import TinyJit # noqa: F401 from tinygrad.shape.symbolic import Variable # noqa: F401 -from tinygrad.helpers import dtypes # noqa: F401 +from tinygrad.helpers import dtypes # noqa: F401 # NOTE: these should not be relied on to be stable -from tinygrad.device import Device # noqa: F401 -from tinygrad.helpers import GlobalCounters # noqa: F401 \ No newline at end of file +from tinygrad.device import Device # noqa: F401 +from tinygrad.helpers import GlobalCounters # noqa: F401 diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 3bcbc8066..703148d93 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -2,596 +2,1204 @@ from __future__ import annotations import os, math, itertools from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union from tinygrad.lazy import vars_from_ast -from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps +from tinygrad.ops import ( + LazyOp, + FlopCounter, + get_lazyop_info, + UnaryOps, + BinaryOps, + ReduceOps, + MemBuffer, + ConstBuffer, + BufferOps, +) from tinygrad.device import Device, Compiled -from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, ansilen, getenv, prod, DEBUG, round_up +from tinygrad.helpers import ( + dedup, + dtypes, + colored, + ImageDType, + DType, + ansilen, + getenv, + prod, + DEBUG, + round_up, +) from tinygrad.shape.shapetracker import ShapeTracker, get_contraction from tinygrad.shape.symbolic import sint from tinygrad.shape.view import View, strides_for_shape from dataclasses import dataclass from enum import Enum, auto + class OptOps(Enum): - UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702 - def __lt__(self, x:OptOps): return self.value < x.value + UPCAST = auto() + UPCASTMID = auto() + UNROLL = auto() + LOCAL = auto() + LASTLOCAL = auto() + GROUP = auto() + GROUPTOP = auto() + NOLOCALS = auto() + PADTO = auto() # noqa: E702 + + def __lt__(self, x: OptOps): + return self.value < x.value + @dataclass(frozen=True, order=True) class Opt: - op: OptOps - axis: Optional[int] = None - amt: Optional[int] = None - def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})" + op: OptOps + axis: Optional[int] = None + amt: Optional[int] = None + + def __repr__(self): + return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})" + @dataclass(frozen=True) class TensorCore: - device: str - dims: List[int] - dtype_in: DType - dtype_out: DType - threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure - upcast_dim: int # which TC dim to upcast - thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim - thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim - arch: Optional[str] = None - def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>" + device: str + dims: List[int] + dtype_in: DType + dtype_out: DType + threads: List[ + Tuple[int, int] + ] # list of (TC dim,amt) that construct the warp thread structure + upcast_dim: int # which TC dim to upcast + thread_local_aliases: List[ + List[List[int]] + ] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim + thread_local_sizes: List[ + int + ] # in each thread, the number of elements stored in registers for each TC dim + arch: Optional[str] = None + + def __str__(self): + return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>" + tensor_cores: Dict[str, List[TensorCore]] = { - "METAL": [ - TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), - TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), - ], - "HIP": [ - TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), - TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), - ] + "METAL": [ + TensorCore( + device="METAL", + dims=[8, 8, 8], + dtype_in=dtypes.float, + dtype_out=dtypes.float, + upcast_dim=0, + threads=[(0, 2), (1, 4), (0, 2), (1, 2)], + thread_local_sizes=[2, 2, 2], + thread_local_aliases=[ + [[4], [0], [2], [0], [-1, 1, 3], [0]], + [[0], [3], [0], [1], [2, 4], [-1]], + [[4], [3], [2], [1], [0], [-1]], + ], + arch="arm64", + ), + TensorCore( + device="METAL", + dims=[8, 8, 8], + dtype_in=dtypes.half, + dtype_out=dtypes.half, + upcast_dim=0, + threads=[(0, 2), (1, 4), (0, 2), (1, 2)], + thread_local_sizes=[2, 2, 2], + thread_local_aliases=[ + [[4], [0], [2], [0], [-1, 1, 3], [0]], + [[0], [3], [0], [1], [2, 4], [-1]], + [[4], [3], [2], [1], [0], [-1]], + ], + arch="arm64", + ), + ], + "HIP": [ + TensorCore( + device="HIP", + dims=[16, 16, 16], + dtype_in=dtypes.half, + dtype_out=dtypes.float, + upcast_dim=1, + threads=[(0, 16), (1, 2)], + thread_local_sizes=[16, 16, 8], + thread_local_aliases=[ + [[0], [0], [-1], [1]], + [[0], [1], [-1], [0]], + [[0], [1], [0], [2, -1]], + ], + ), + TensorCore( + device="HIP", + dims=[16, 16, 16], + dtype_in=dtypes.half, + dtype_out=dtypes.half, + upcast_dim=1, + threads=[(0, 16), (1, 2)], + thread_local_sizes=[16, 16, 8], + thread_local_aliases=[ + [[0], [0], [-1], [1]], + [[0], [1], [-1], [0]], + [[0], [1], [0], [2, -1]], + ], + ), + ], } + class LocalBuffer(NamedTuple): - name: str - size: int - dtype: DType = dtypes.float32 - realized: None = None - def __str__(self): return f"localbuffer<{self.name}[{self.size}]>" + name: str + size: int + dtype: DType = dtypes.float32 + realized: None = None + + def __str__(self): + return f"localbuffer<{self.name}[{self.size}]>" + class LinearizerOptions(NamedTuple): - device: str = "" - # TODO: make this generic with a list of supported types - supports_float4: bool = True - supports_float4_alu: bool = True - has_local: bool = True - has_shared: bool = True - # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered - global_max: Optional[List[int]] = None - local_max: Optional[List[int]] = None + device: str = "" + # TODO: make this generic with a list of supported types + supports_float4: bool = True + supports_float4_alu: bool = True + has_local: bool = True + has_shared: bool = True + # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered + global_max: Optional[List[int]] = None + local_max: Optional[List[int]] = None + class Kernel: - def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None): - self.opts = opts if opts else (cast(Compiled, Device[Device.DEFAULT]).linearizer_opts if isinstance(Device[Device.DEFAULT], Compiled) else LinearizerOptions()) - self.ast = ast - assert ast.op == BufferOps.STORE, f"kernels must have a store as the output, got {ast.op}" + def __init__(self, ast: LazyOp, opts: Optional[LinearizerOptions] = None): + self.opts = ( + opts + if opts + else ( + cast(Compiled, Device[Device.DEFAULT]).linearizer_opts + if isinstance(Device[Device.DEFAULT], Compiled) + else LinearizerOptions() + ) + ) + self.ast = ast + assert ( + ast.op == BufferOps.STORE + ), f"kernels must have a store as the output, got {ast.op}" - # fetch lazyop info - self.info: FlopCounter = get_lazyop_info(self.ast) + # fetch lazyop info + self.info: FlopCounter = get_lazyop_info(self.ast) - # there's only allowed to be one reduceop - reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps] - assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast" - self.reduceop = reduceops[0] if reduceops else None + # there's only allowed to be one reduceop + reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps] + assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast" + self.reduceop = reduceops[0] if reduceops else None - # create new shapetrackers inside this kernel, we will permute them - self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = dedup([x.arg for x in self.ast.get_lazyops() if x.op in BufferOps]) - assert isinstance(self.bufs[0], MemBuffer) and self.bufs[0].idx == 0, f"buffer 0 is not the store buffer {self.bufs[0]}" + # create new shapetrackers inside this kernel, we will permute them + self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = dedup( + [x.arg for x in self.ast.get_lazyops() if x.op in BufferOps] + ) + assert ( + isinstance(self.bufs[0], MemBuffer) and self.bufs[0].idx == 0 + ), f"buffer 0 is not the store buffer {self.bufs[0]}" - # get earlybufs, before the one reduce op - self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in BufferOps] if self.reduceop else [] - self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0 + # get earlybufs, before the one reduce op + self.earlybufs = ( + [x.arg for x in self.reduceop.get_lazyops() if x.op in BufferOps] + if self.reduceop + else [] + ) + self.full_buf_index: int = ( + self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0 + ) - # create the (permuted) shapetrackers - self.sts: List[ShapeTracker] = [x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)] + # create the (permuted) shapetrackers + self.sts: List[ShapeTracker] = [ + x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs) + ] - # move all reduce axes to the end - reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape))) - permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n]) - self.reshape_and_permute(None, permute) + # move all reduce axes to the end + reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape))) + permute = tuple( + [i for i, (s, n) in reduce if s == n] + + [i for i, (s, n) in reduce if s != n] + ) + self.reshape_and_permute(None, permute) - # parameters for optimization - self.applied_opts: List[Opt] = [] - self.group_for_reduce: List[int] = [] - self.upcasted: int = 0 - self.local_dims: int = 0 - self.local_alias: Dict[int, LocalBuffer] = {} - self.tensor_core: Optional[TensorCore] = None - self.dont_use_locals: bool = False + # parameters for optimization + self.applied_opts: List[Opt] = [] + self.group_for_reduce: List[int] = [] + self.upcasted: int = 0 + self.local_dims: int = 0 + self.local_alias: Dict[int, LocalBuffer] = {} + self.tensor_core: Optional[TensorCore] = None + self.dont_use_locals: bool = False - # group simplifies - self.simplify_ones() - self.simplify_merge_adjacent() + # group simplifies + self.simplify_ones() + self.simplify_merge_adjacent() - # cache - self.applied_opts_cache: Optional[List[Opt]] = None + # cache + self.applied_opts_cache: Optional[List[Opt]] = None - def copy(self): - ret = type(self).__new__(type(self)) + def copy(self): + ret = type(self).__new__(type(self)) - # base linearizer params - ret.opts, ret.ast = self.opts, self.ast + # base linearizer params + ret.opts, ret.ast = self.opts, self.ast - # things downstream of the AST - # NOTE: we copy bufs for local buffers and sts for optimizations - ret.info, ret.reduceop, ret.bufs, ret.earlybufs, ret.full_buf_index, ret.sts = \ - self.info, self.reduceop, self.bufs[:], self.earlybufs, self.full_buf_index, self.sts[:] + # things downstream of the AST + # NOTE: we copy bufs for local buffers and sts for optimizations + ret.info, ret.reduceop, ret.bufs, ret.earlybufs, ret.full_buf_index, ret.sts = ( + self.info, + self.reduceop, + self.bufs[:], + self.earlybufs, + self.full_buf_index, + self.sts[:], + ) - # parameters for optimizations - ret.applied_opts, ret.group_for_reduce, ret.upcasted, ret.local_dims, ret.local_alias, ret.tensor_core, ret.dont_use_locals = \ - self.applied_opts[:], self.group_for_reduce[:], self.upcasted, self.local_dims, self.local_alias.copy(), self.tensor_core, self.dont_use_locals + # parameters for optimizations + ( + ret.applied_opts, + ret.group_for_reduce, + ret.upcasted, + ret.local_dims, + ret.local_alias, + ret.tensor_core, + ret.dont_use_locals, + ) = ( + self.applied_opts[:], + self.group_for_reduce[:], + self.upcasted, + self.local_dims, + self.local_alias.copy(), + self.tensor_core, + self.dont_use_locals, + ) - # uncached since linearize didn't run - ret.applied_opts_cache = None + # uncached since linearize didn't run + ret.applied_opts_cache = None - return ret + return ret - @property - def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)] + @property + def membufs(self) -> List[MemBuffer]: + return [x for x in self.bufs if isinstance(x, MemBuffer)] - # TODO: these need more tests or it might silently be no-op - def shape_offsets(self, i:int): return itertools.product(*[list(range(cast(int, s))) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()] - def float4_axis(self, i:int): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0] + # TODO: these need more tests or it might silently be no-op + def shape_offsets(self, i: int): + return ( + itertools.product( + *[ + list(range(cast(int, s))) + for s in self.sts[i].shape[self.shape_len - self.upcasted :][::-1] + ] + ) + if self.upcasted > 0 + else [tuple()] + ) - def upcasted_axis(self, i:int): - return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:], - self.sts[i].real_strides()[self.shape_len-self.upcasted:], - [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])])) + def float4_axis(self, i: int): + return [ + x - (self.shape_len - self.upcasted) + for x in self.sts[i].unit_stride_axes() + if x >= self.shape_len - self.upcasted and self.sts[i].shape[x] % 4 == 0 + ] - # TODO: is there a better way to write this? - def acc_offsets(self, i:int) -> List[int]: - if self.upcasted == 0: return [0] - upcasted_i = self.upcasted_axis(i) - acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))] - return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])] + def upcasted_axis(self, i: int): + return list( + zip( + self.sts[i].shape[self.shape_len - self.upcasted :], + self.sts[i].real_strides()[self.shape_len - self.upcasted :], + [ + x != y + for x, y in zip( + self.sts[0].shape[self.shape_len - self.upcasted :], + self.full_shape[self.shape_len - self.upcasted :], + ) + ], + ) + ) - def get_upcast_dim(self, i:int) -> List[int]: - should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType)) - return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1] + # TODO: is there a better way to write this? + def acc_offsets(self, i: int) -> List[int]: + if self.upcasted == 0: + return [0] + upcasted_i = self.upcasted_axis(i) + acc_strides = [ + x * (1 - upcasted_i[::-1][i][2]) + for i, x in enumerate( + strides_for_shape(tuple(1 if r else s for s, _, r in upcasted_i[::-1])) + ) + ] + return [ + sum(t) + for t in itertools.product( + *[ + [y * acc_strides[i] for y in range(x[0])] + for i, x in enumerate(upcasted_i[::-1]) + ] + ) + ] - @property - def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True) + def get_upcast_dim(self, i: int) -> List[int]: + should_upcast = self.opts.supports_float4 and ( + self.bufs[i].dtype in [dtypes.float32, dtypes.float16] + or isinstance(self.bufs[i].dtype, ImageDType) + ) + return [ + x + for x in self.sts[i].unit_stride_axes() + if should_upcast + and x >= self.shape_len - self.upcasted + and self.sts[i].shape[x] > 1 + ] - @property - def output_shape(self) -> Tuple[sint, ...]: return self.sts[0].shape + @property + def first_reduce(self) -> int: + return [ + x != y + for x, y in zip( + self.sts[0].shape[: self.shape_len - self.upcasted] + (0,), + self.full_shape[: self.shape_len - self.upcasted] + (1,), + ) + ].index(True) - @property - def full_shape(self) -> Tuple[sint, ...]: return self.sts[self.full_buf_index].shape + @property + def output_shape(self) -> Tuple[sint, ...]: + return self.sts[0].shape - @property - def full_unupcasted_shape(self) -> Tuple[sint, ...]: return self.full_shape[:self.shape_len-self.upcasted] + @property + def full_shape(self) -> Tuple[sint, ...]: + return self.sts[self.full_buf_index].shape - @property - def shape_len(self) -> int: return len(self.sts[0].shape) + @property + def full_unupcasted_shape(self) -> Tuple[sint, ...]: + return self.full_shape[: self.shape_len - self.upcasted] - @property - def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]] + @property + def shape_len(self) -> int: + return len(self.sts[0].shape) - @property - def global_dims(self) -> int: return self.first_reduce-self.local_dims + @property + def upcast_in_mid_reduce_axes(self) -> List[int]: + return [ + j + for j in range( + self.first_reduce, self.first_reduce + len(self.group_for_reduce) + ) + if self.full_shape[j] == self.sts[0].shape[j] + ] - # there's eight chunks of the shape - # blue -- global dims - # cyan -- local dims (warp ones first) - # *** self.first_reduce - # green -- reduce-local dims - # white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes) - # red -- reduce loops - # *** self.upcasted - # purple -- reduce upcasted - # yellow -- normal upcasted dimensions - def colors(self) -> List[str]: - # first non local non reduce dims are global (blue) - colors = ["blue"] * self.global_dims if not self.dont_use_locals else ["BLUE"] * self.global_dims - # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan) - colors += ["cyan"] * self.local_dims - # between first_reduce and first_reduce + group_for_reduce, they are either upcast mid reduce (white), or late upcasted (green) - colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))] - # between first_reduce + group_for_reduce and upcasted, they are reduce (red) - colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce))) - # upcasted dimensions are reduce (magenta) or normal (yellow) - colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)] - assert len(colors) == self.shape_len, "colors size mismatch" - return colors + @property + def global_dims(self) -> int: + return self.first_reduce - self.local_dims - def colored_shape(self, pad:Optional[int]=None, dense=False) -> str: - ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors())) - if pad: ret += ' '*(pad-ansilen(ret)) - return ret + # there's eight chunks of the shape + # blue -- global dims + # cyan -- local dims (warp ones first) + # *** self.first_reduce + # green -- reduce-local dims + # white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes) + # red -- reduce loops + # *** self.upcasted + # purple -- reduce upcasted + # yellow -- normal upcasted dimensions + def colors(self) -> List[str]: + # first non local non reduce dims are global (blue) + colors = ( + ["blue"] * self.global_dims + if not self.dont_use_locals + else ["BLUE"] * self.global_dims + ) + # after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan) + colors += ["cyan"] * self.local_dims + # between first_reduce and first_reduce + group_for_reduce, they are either upcast mid reduce (white), or late upcasted (green) + colors += [ + "white" if i in self.upcast_in_mid_reduce_axes else "green" + for i in range( + self.first_reduce, self.first_reduce + len(self.group_for_reduce) + ) + ] + # between first_reduce + group_for_reduce and upcasted, they are reduce (red) + colors += ["red"] * ( + (self.shape_len - self.upcasted) + - (self.first_reduce + len(self.group_for_reduce)) + ) + # upcasted dimensions are reduce (magenta) or normal (yellow) + colors += [ + "magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" + for i in range(self.shape_len - self.upcasted, self.shape_len) + ] + assert len(colors) == self.shape_len, "colors size mismatch" + return colors - # ******************** base simplifiers ******************** + def colored_shape(self, pad: Optional[int] = None, dense=False) -> str: + ret = " ".join( + colored(s, color) + for s, color in zip( + [ + f"{s:4d}" if isinstance(s, int) and not dense else s + for s in self.full_shape + ], + self.colors(), + ) + ) + if pad: + ret += " " * (pad - ansilen(ret)) + return ret - # apply reshape and permute to all shapetrackers - def reshape_and_permute(self, new_shape_fxn, axis): - new_sts = [] - for st in self.sts: - if new_shape_fxn is not None: st = st.reshape(tuple(new_shape_fxn(st.shape))) - if axis is not None: st = st.permute(tuple(axis)) - new_sts.append(st) - self.sts = new_sts + # ******************** base simplifiers ******************** - # drops the final dimension - def upcast(self): - assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1" - self.upcasted += 1 + # apply reshape and permute to all shapetrackers + def reshape_and_permute(self, new_shape_fxn, axis): + new_sts = [] + for st in self.sts: + if new_shape_fxn is not None: + st = st.reshape(tuple(new_shape_fxn(st.shape))) + if axis is not None: + st = st.permute(tuple(axis)) + new_sts.append(st) + self.sts = new_sts - # axis : the axis to pull from - # amount : the amount to take - # top : if you want to pull that amount from the top - # insert_before : place to insert the new stuff - def shift_to(self, axis, amount, top=False, insert_before=None): - if insert_before is None: insert_before = self.shape_len - move_axis = axis if top else axis+1 - if move_axis < insert_before: insert_before += 1 - self.reshape_and_permute( - lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]), - [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis]) + # drops the final dimension + def upcast(self): + assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1" + self.upcasted += 1 - # ******************** complex simplifiers ******************** + # axis : the axis to pull from + # amount : the amount to take + # top : if you want to pull that amount from the top + # insert_before : place to insert the new stuff + def shift_to(self, axis, amount, top=False, insert_before=None): + if insert_before is None: + insert_before = self.shape_len + move_axis = axis if top else axis + 1 + if move_axis < insert_before: + insert_before += 1 + self.reshape_and_permute( + lambda x: list(x[0:axis]) + + ( + ([amount, x[axis] // amount] if top else [x[axis] // amount, amount]) + if x[axis] > 1 + else [1, 1] + ) + + list(x[axis + 1 :]), + [i for i in range(insert_before) if i != move_axis] + + [move_axis] + + [i for i in range(insert_before, self.shape_len + 1) if i != move_axis], + ) - def simplify_ones(self) -> bool: - # remove places where the shape is all ones - # TODO: this should be factored in to multi shape stride - if self.shape_len == 0: return False - all_ones = [s==1 for s in self.full_shape] - self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce]) - self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:]) - self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None) - return any(all_ones) + # ******************** complex simplifiers ******************** - def simplify_merge_adjacent(self): - if self.shape_len == 0: return - shapes, strides = [x.shape for x in self.sts], [x.real_strides() for x in self.sts] + def simplify_ones(self) -> bool: + # remove places where the shape is all ones + # TODO: this should be factored in to multi shape stride + if self.shape_len == 0: + return False + all_ones = [s == 1 for s in self.full_shape] + self.local_dims -= sum( + all_ones[self.first_reduce - self.local_dims : self.first_reduce] + ) + self.upcasted -= sum(all_ones[self.shape_len - self.upcasted :]) + self.reshape_and_permute( + lambda shape: [x for i, x in enumerate(shape) if not all_ones[i]], None + ) + return any(all_ones) - # if it's an image, insert fake strides such that this fusion doesn't happen across image axes - if isinstance(self.bufs[0].dtype, ImageDType): - base_shape = self.bufs[0].dtype.shape - if shape_idx_groups := get_contraction(self.output_shape, base_shape): - special_strides: Tuple[int, ...] = tuple() - for i,g in enumerate(shape_idx_groups): - shape_piece = tuple(self.output_shape[x] for x in g) - assert prod(shape_piece) == base_shape[i], f"get_contraction was wrong? {shape_piece} != {base_shape[i]}" - special_strides += strides_for_shape(shape_piece) - # adding the fake image shape - shapes.append(self.output_shape) - strides.append(special_strides) + def simplify_merge_adjacent(self): + if self.shape_len == 0: + return + shapes, strides = [x.shape for x in self.sts], [ + x.real_strides() for x in self.sts + ] - # merge dimensions if we can, multi get_shape_strides - # TODO: does this always preserve the reduce dimension, NO - # TODO: move this into shapetracker, with tests! - rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))] - for i in range(1, len(shapes[0])): - can_merge = [] - for j in range(len(shapes)): - # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case - can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0))) - # more can merge than this - mergeable = all(can_merge) and i != self.first_reduce - for j in range(len(shapes)): - if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i]) - else: rets[j].append((shapes[j][i], strides[j][i])) + # if it's an image, insert fake strides such that this fusion doesn't happen across image axes + if isinstance(self.bufs[0].dtype, ImageDType): + base_shape = self.bufs[0].dtype.shape + if shape_idx_groups := get_contraction(self.output_shape, base_shape): + special_strides: Tuple[int, ...] = tuple() + for i, g in enumerate(shape_idx_groups): + shape_piece = tuple(self.output_shape[x] for x in g) + assert ( + prod(shape_piece) == base_shape[i] + ), f"get_contraction was wrong? {shape_piece} != {base_shape[i]}" + special_strides += strides_for_shape(shape_piece) + # adding the fake image shape + shapes.append(self.output_shape) + strides.append(special_strides) - # do the reshapes - for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x])) + # merge dimensions if we can, multi get_shape_strides + # TODO: does this always preserve the reduce dimension, NO + # TODO: move this into shapetracker, with tests! + rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))] + for i in range(1, len(shapes[0])): + can_merge = [] + for j in range(len(shapes)): + # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case + can_merge.append( + strides[j][i] is not None + and ( + ( + strides[j][i] != 0 + and rets[j][-1][1] + == shapes[j][i] * cast(int, strides[j][i]) + ) + or (strides[j][i] == 0 and rets[j][-1][1] == 0) + ) + ) + # more can merge than this + mergeable = all(can_merge) and i != self.first_reduce + for j in range(len(shapes)): + if mergeable: + rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i]) + else: + rets[j].append((shapes[j][i], strides[j][i])) - # ******************** GPU simplifiers ******************** + # do the reshapes + for i, x in enumerate(rets[: len(self.sts)]): + self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x])) - def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]: - new_shape,dims = list(x), len(x) - for i in range(dims): - next_idx = (i + 1) % dims - while new_shape[i] > max_size[i]: - new_shape[i] = new_shape[i] // 2 - if (new_shape[next_idx] <= max_size[next_idx]): - new_shape[next_idx] = new_shape[next_idx] * 2 + # ******************** GPU simplifiers ******************** + + def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]: + new_shape, dims = list(x), len(x) + for i in range(dims): + next_idx = (i + 1) % dims + while new_shape[i] > max_size[i]: + new_shape[i] = new_shape[i] // 2 + if new_shape[next_idx] <= max_size[next_idx]: + new_shape[next_idx] = new_shape[next_idx] * 2 + else: + next_idx = (next_idx + 1) % dims + new_shape[next_idx] = new_shape[next_idx] * 2 + return tuple(new_shape) + + def limit_dims_to_max(self, global_max: List[int], local_max: List[int]): + # Check the global allocation limit, current the global_size will be flipped during codegen + # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write + global_dims = self.first_reduce - self.local_dims + if global_dims > 0: + if global_max: + tmp = global_max[:global_dims] + ( + local_max[: self.local_dims] if local_max else [] + ) + if max(global_max) < max(self.full_shape[:global_dims]): + self.reshape_and_permute( + lambda x: self._limit_size( + x, tmp + [math.inf] * (len(self.full_shape) - len(tmp)) + ), + None, + ) + assert max(global_max) >= max( + self.full_shape[:global_dims] + ), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}" + for i in range(global_dims - 1): + if i < len(global_max) and self.full_shape[i] > global_max[i]: + order = list(range(len(self.full_shape))) + order[i], order[global_dims - 1] = order[global_dims - 1], order[i] + self.reshape_and_permute(None, order) + if DEBUG >= 3: + print( + "permuted global dim", + order, + "due to allocation exceeds global limit", + ) + + def alias_buffer(self, i, pattern): + assert len(pattern) == len( + self.sts[i].shape + ), f"must include a pattern for each shape {pattern} {self.sts[i].shape}" + + bst = 1 + real_strides = self.sts[i].real_strides() + shp, stride = [ + (s if p != 0 else 1) for s, p in zip(self.sts[i].shape, pattern) + ], [0] * len(pattern) + for priority in range( + 1, max(pattern) + 1 + ): # priority. 0 is non local and ignored + for j, p in enumerate(pattern): + if priority == p and real_strides[j] != 0: + stride[j] = bst + bst *= shp[j] + + self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),))) + self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size())) + if DEBUG >= 4: + print("aliasing buffer", self.sts[i]) + self.local_alias[i] = cast(LocalBuffer, self.bufs[-1]) + + # ******************** high level optimizers ******************** + + def apply_tensor_cores( + self, use_tensor_cores=1, extra_opts: Optional[List[Opt]] = None + ): + if ( + use_tensor_cores + and self.opts.has_local + and self.reduceop + and self.reduceop.op == ReduceOps.SUM + and self.opts.device in tensor_cores + ): + for tc in tensor_cores[self.opts.device]: + if not ( + (tc.arch is None or tc.arch == os.uname().machine) + and isinstance(self.reduceop.src[0], LazyOp) + ): + continue + has_cast = tc.dtype_in != tc.dtype_out + + if has_cast and not ( + isinstance(self.reduceop.src[0], LazyOp) + and self.reduceop.src[0].op == UnaryOps.CAST + and self.reduceop.src[0].arg[0] == tc.dtype_out + ): + continue + mul_op = ( + self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0] + ) + + if not (isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL): + continue + if not ( + isinstance(mul_op.src[0], LazyOp) + and mul_op.src[0].op == BufferOps.LOAD + and mul_op.src[0].arg.dtype == tc.dtype_in + ): + continue + if not ( + isinstance(mul_op.src[1], LazyOp) + and mul_op.src[1].op == BufferOps.LOAD + and mul_op.src[1].arg.dtype == tc.dtype_in + ): + continue + buf0, buf1 = self.bufs.index( + cast(MemBuffer, mul_op.src[0].arg) + ), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg)) + buf0_strides, buf1_strides = ( + self.sts[buf0].real_strides(), + self.sts[buf1].real_strides(), + ) + axis_buf0 = [ + (i, self.full_shape[i], buf1_strides[i]) + for i, s in enumerate(buf0_strides[: self.first_reduce]) + if s == 0 and self.full_shape[i] % tc.dims[0] == 0 + ] + axis_buf1 = [ + (i, self.full_shape[i], buf0_strides[i]) + for i, s in enumerate(buf1_strides[: self.first_reduce]) + if s == 0 and self.full_shape[i] % tc.dims[1] == 0 + ] + + if not ( + axis_buf0 + and axis_buf1 + and self.full_shape[self.first_reduce] % tc.dims[2] == 0 + and self.full_shape[self.first_reduce] >= tc.dims[2] + and (self.shape_len - self.first_reduce) == 1 + ): + continue + + if DEBUG >= 3: + print("TENSOR CORES", axis_buf0, axis_buf1, tc) + + s0, s1 = ( + axis_buf0[-1][0], + axis_buf1[-1][0], + ) # TODO: select axis in smart way + s0_exists, s1_exists = True, True + assert ( + s0 != s1 + and self.full_shape[s0] % tc.dims[0] == 0 + and self.full_shape[s1] % tc.dims[1] == 0 + ) + + def fix(needed, ax): + nonlocal s0, s1, s0_exists, s1_exists + if not needed: + return + if s0_exists and ax == s0: + if s1_exists and s0 < s1: + s1 -= 1 + s0_exists = False + elif s1_exists and ax == s1: + if s0_exists and s1 < s0: + s0 -= 1 + s1_exists = False + + # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern + self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2])) + self.apply_opt( + Opt( + OptOps.UPCAST, + s0 if tc.upcast_dim == 0 else s1, + (tc.dims[0] * tc.dims[2]) // prod([a[1] for a in tc.threads]), + ) + ) + for tc_dim, tc_amt in tc.threads: + fix( + self.apply_opt( + Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt) + ), + s0 if tc_dim == 0 else s1, + ) + + # assert tensor core and prevent extra_opts from altering the key shape structure + if use_tensor_cores == 1: + self.tensor_core = tc # TC=2 will do the shape ops without the WMMA + + if extra_opts is not None: + for opt in extra_opts: + self.apply_opt(opt) + else: + # hand-coded TC opts + if s1_exists: + s1_div = [ + upc + for upc in [5, 4, 3, 2, 1] + if self.full_shape[s1] % upc == 0 + ][0] + if s1_div != 1: + fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1) + if s0_exists: + s0_div = [ + upc + for upc in [5, 4, 3, 2, 1] + if self.full_shape[s0] % upc == 0 + ][0] + if s0_div != 1: + fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0) + if self.tensor_core and s0_exists: + for upc in [4, 2]: + if self.full_shape[s0] % upc == 0: + self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc)) + break + + # alias buffer + alias_pattern = ( + [0] * (self.global_dims + (self.local_dims - len(tc.threads))) + + [2] * (len(tc.threads)) + + [0] * (self.shape_len - self.upcasted - self.first_reduce) + + [1, 1] + + [3] * (self.upcasted - 2) + ) + self.alias_buffer(buf0, alias_pattern) + self.alias_buffer(buf1, alias_pattern) + return True + return False + + def apply_opt(self, opt: Opt): + assert not self.dont_use_locals or opt.op not in { + OptOps.LOCAL, + OptOps.LASTLOCAL, + OptOps.GROUP, + OptOps.GROUPTOP, + OptOps.UPCASTMID, + }, "not using locals" + self.applied_opts.append(opt) + if opt.axis is not None: + axis = opt.axis + ( + self.first_reduce + if opt.op == OptOps.UNROLL + else ( + self.first_reduce + len(self.group_for_reduce) + if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP + else 0 + ) + ) else: - next_idx = (next_idx + 1) % dims - new_shape[next_idx] = new_shape[next_idx] * 2 - return tuple(new_shape) - - def limit_dims_to_max(self, global_max: List[int], local_max: List[int]): - # Check the global allocation limit, current the global_size will be flipped during codegen - # and then padded right with 1s if its length < 3 which makes this part a bit awkward to write - global_dims = self.first_reduce-self.local_dims - if global_dims > 0: - if global_max: - tmp = global_max[:global_dims] + (local_max[:self.local_dims] if local_max else []) - if max(global_max) < max(self.full_shape[:global_dims]): self.reshape_and_permute(lambda x: self._limit_size(x, tmp + [math.inf] * (len(self.full_shape)-len(tmp))), None) - assert max(global_max) >= max(self.full_shape[:global_dims]), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}" - for i in range(global_dims-1): - if i < len(global_max) and self.full_shape[i] > global_max[i]: - order = list(range(len(self.full_shape))) - order[i], order[global_dims-1] = order[global_dims-1], order[i] - self.reshape_and_permute(None, order) - if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit") - - def alias_buffer(self, i, pattern): - assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}" - - bst = 1 - real_strides = self.sts[i].real_strides() - shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern) - for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored - for j,p in enumerate(pattern): - if priority == p and real_strides[j] != 0: - stride[j] = bst - bst *= shp[j] - - self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),))) - self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size())) - if DEBUG >= 4: print("aliasing buffer", self.sts[i]) - self.local_alias[i] = cast(LocalBuffer, self.bufs[-1]) - - # ******************** high level optimizers ******************** - - def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None): - if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores: - for tc in tensor_cores[self.opts.device]: - if not((tc.arch is None or tc.arch == os.uname().machine) and isinstance(self.reduceop.src[0], LazyOp)): continue - has_cast = tc.dtype_in != tc.dtype_out - - if has_cast and not(isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue - mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0] - - if not(isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL): continue - if not(isinstance(mul_op.src[0], LazyOp) and mul_op.src[0].op == BufferOps.LOAD and mul_op.src[0].arg.dtype == tc.dtype_in): continue - if not(isinstance(mul_op.src[1], LazyOp) and mul_op.src[1].op == BufferOps.LOAD and mul_op.src[1].arg.dtype == tc.dtype_in): continue - buf0, buf1 = self.bufs.index(cast(MemBuffer, mul_op.src[0].arg)), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg)) - buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides() - axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] - axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] - - if not(axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%tc.dims[2] == 0 and self.full_shape[self.first_reduce] >= tc.dims[2] and (self.shape_len-self.first_reduce) == 1): continue - - if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc) - - s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0] # TODO: select axis in smart way - s0_exists, s1_exists = True, True - assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0 - def fix(needed, ax): - nonlocal s0, s1, s0_exists, s1_exists - if not needed: return - if s0_exists and ax == s0: - if s1_exists and s0 < s1: s1 -= 1 - s0_exists = False - elif s1_exists and ax == s1: - if s0_exists and s1 < s0: s0 -= 1 - s1_exists = False - - # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern - self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2])) - self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads]))) - for (tc_dim, tc_amt) in tc.threads: - fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1) - - # assert tensor core and prevent extra_opts from altering the key shape structure - if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA - - if extra_opts is not None: - for opt in extra_opts: - self.apply_opt(opt) + axis = -1 + if opt.amt is not None: + amt = opt.amt if opt.amt != 0 else self.full_shape[axis] + assert ( + isinstance(amt, int) and amt != 1 + ), "shift/padto of amt 1 or Node is meaningless" + if opt.op != OptOps.PADTO: + assert self.full_shape[axis] % amt == 0, "no longer valid shift" else: - # hand-coded TC opts - if s1_exists: - s1_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s1]%upc == 0][0] - if s1_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1) - if s0_exists: - s0_div = [upc for upc in [5,4,3,2,1] if self.full_shape[s0]%upc == 0][0] - if s0_div != 1: fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0) - if self.tensor_core and s0_exists: - for upc in [4,2]: - if self.full_shape[s0] % upc == 0: - self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc)) - break + amt = -1 + if opt.op == OptOps.LOCAL: # cyan + assert self.opts.has_local, "target does not support local" + assert axis < self.first_reduce, "can't local a reduce" + assert not (self.tensor_core), "can't local with tensor cores" + self.shift_to(axis, amt, insert_before=self.first_reduce) + self.local_dims += 1 + elif opt.op == OptOps.LASTLOCAL: # cyan + assert self.opts.has_local, "target does not support local" + assert axis < self.first_reduce, "can't local a reduce" + self.shift_to(axis, amt, insert_before=self.first_reduce - self.local_dims) + self.local_dims += 1 + elif opt.op == OptOps.GROUP: # green + assert ( + self.opts.has_local and self.opts.has_shared + ), "target does not support local or shared mem" + assert ( + axis >= self.first_reduce + len(self.group_for_reduce) + and axis < self.shape_len - self.upcasted + ), "must be reduce axis to group" + assert not (self.tensor_core), "can't group with tensor cores" + self.shift_to( + axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce) + ) + self.group_for_reduce.append(amt) + elif opt.op == OptOps.GROUPTOP: # green + assert ( + self.opts.has_local and self.opts.has_shared + ), "target does not support local or shared mem" + assert ( + axis >= self.first_reduce + len(self.group_for_reduce) + and axis < self.shape_len - self.upcasted + ), "must be reduce axis to group" + assert not (self.tensor_core), "can't group with tensor cores" + self.shift_to( + axis, + amt, + top=True, + insert_before=self.first_reduce + len(self.group_for_reduce), + ) + self.group_for_reduce.append(amt) + elif opt.op == OptOps.UNROLL: # purple + assert ( + axis < self.shape_len - self.upcasted + ), "can't upcasted already upcasted" + assert amt <= 32, "don't unroll more than 32" + self.shift_to(axis, amt, insert_before=None) + self.upcast() + elif opt.op == OptOps.UPCAST: # yellow + assert axis < self.first_reduce, "upcast is for non-reduce" + assert amt <= 8, "don't upcast more than 8" + self.shift_to(axis, amt, insert_before=None) + self.upcast() + elif opt.op == OptOps.UPCASTMID: # white + assert ( + self.bufs[0].dtype.name.startswith("image") + and not self.float4_axis(0) + and self.group_for_reduce + and self.first_reduce <= 2 + and prod(self.sts[0].shape) > 1 + ), "invalid upcast mid reduce" + axes = self.sts[0].unit_stride_axes() + assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" + assert axes[0] == axis, "wrong axis" + assert amt == 4, "don't upcast mid anything but 4" + self.shift_to( + axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce) + ) + self.group_for_reduce.append(amt) + elif opt.op == OptOps.NOLOCALS: + assert ( + self.opts.has_local + ), "target does not support local, so this optimization is meaningless" + assert ( + self.local_dims == 0 and len(self.group_for_reduce) == 0 + ), "can't have no locals with locals" + assert not self.dont_use_locals, "already not using locals" + self.dont_use_locals = True + elif opt.op == OptOps.PADTO: + assert not vars_from_ast(self.ast), "does not work with symbolic shape" + assert all( + op.op is not ReduceOps.MAX for op in self.ast.get_lazyops() + ), "cannot pad with MAX" + padded = False + for i, st in enumerate(self.sts): + if self.sts[i].shape[axis] != 1: + assert ( + self.sts[i].shape[axis] > amt // 2 + ), "pad adds more than double the work" + if ( + ru := round_up(self.sts[i].shape[axis], amt) + - self.sts[i].shape[axis] + ): + # pad right seems to be faster + self.sts[i] = st.pad( + ((0, 0),) * axis + + ((0, ru),) + + ((0, 0),) * (len(st.shape) - axis - 1) + ) + padded = True + assert padded, "nothing was padded" + return self.simplify_ones() - # alias buffer - alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) - self.alias_buffer(buf0, alias_pattern) - self.alias_buffer(buf1, alias_pattern) - return True - return False + def hand_coded_optimizations(self): + # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat + MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = ( + getenv("MV_BLOCKSIZE", 4), + getenv("MV_THREADS_PER_ROW", 8), + getenv("MV_ROWS_PER_THREAD", 4), + ) + if ( + self.opts.has_local + and getenv("MV", 1) != 0 + and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) + and self.reduceop + and self.reduceop.op == ReduceOps.SUM + and len(self.full_shape) >= 2 + and self.opts.has_shared + and isinstance(self.reduceop.src[0], LazyOp) + and self.reduceop.src[0].op == BinaryOps.MUL + and self.reduceop.src[0].src[0].op == BufferOps.LOAD + and self.reduceop.src[0].src[1].op == BufferOps.LOAD + ): + buf0 = self.bufs.index(self.reduceop.src[0].src[0].arg) + buf1 = self.bufs.index(self.reduceop.src[0].src[1].arg) + buf0_strides = self.sts[buf0].real_strides() + buf1_strides = self.sts[buf1].real_strides() - def apply_opt(self, opt:Opt): - assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" - self.applied_opts.append(opt) - if opt.axis is not None: - axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+len(self.group_for_reduce) if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP else 0)) - else: - axis = -1 - if opt.amt is not None: - amt = opt.amt if opt.amt != 0 else self.full_shape[axis] - assert isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless" - if opt.op != OptOps.PADTO: assert self.full_shape[axis] % amt == 0, "no longer valid shift" - else: - amt = -1 - if opt.op == OptOps.LOCAL: # cyan - assert self.opts.has_local, "target does not support local" - assert axis < self.first_reduce, "can't local a reduce" - assert not(self.tensor_core), "can't local with tensor cores" - self.shift_to(axis, amt, insert_before=self.first_reduce) - self.local_dims += 1 - elif opt.op == OptOps.LASTLOCAL: # cyan - assert self.opts.has_local, "target does not support local" - assert axis < self.first_reduce, "can't local a reduce" - self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims) - self.local_dims += 1 - elif opt.op == OptOps.GROUP: # green - assert self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem" - assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group" - assert not(self.tensor_core), "can't group with tensor cores" - self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce)) - self.group_for_reduce.append(amt) - elif opt.op == OptOps.GROUPTOP: # green - assert self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem" - assert axis >= self.first_reduce + len(self.group_for_reduce) and axis < self.shape_len-self.upcasted, "must be reduce axis to group" - assert not(self.tensor_core), "can't group with tensor cores" - self.shift_to(axis, amt, top=True, insert_before=self.first_reduce + len(self.group_for_reduce)) - self.group_for_reduce.append(amt) - elif opt.op == OptOps.UNROLL: # purple - assert axis < self.shape_len-self.upcasted, "can't upcasted already upcasted" - assert amt <= 32, "don't unroll more than 32" - self.shift_to(axis, amt, insert_before=None) - self.upcast() - elif opt.op == OptOps.UPCAST: # yellow - assert axis < self.first_reduce, "upcast is for non-reduce" - assert amt <= 8, "don't upcast more than 8" - self.shift_to(axis, amt, insert_before=None) - self.upcast() - elif opt.op == OptOps.UPCASTMID: # white - assert self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce" - axes = self.sts[0].unit_stride_axes() - assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" - assert axes[0] == axis, "wrong axis" - assert amt == 4, "don't upcast mid anything but 4" - self.shift_to(axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce)) - self.group_for_reduce.append(amt) - elif opt.op == OptOps.NOLOCALS: - assert self.opts.has_local, "target does not support local, so this optimization is meaningless" - assert self.local_dims == 0 and len(self.group_for_reduce) == 0, "can't have no locals with locals" - assert not self.dont_use_locals, "already not using locals" - self.dont_use_locals = True - elif opt.op == OptOps.PADTO: - assert not vars_from_ast(self.ast), "does not work with symbolic shape" - assert all(op.op is not ReduceOps.MAX for op in self.ast.get_lazyops()), "cannot pad with MAX" - padded = False - for i,st in enumerate(self.sts): - if self.sts[i].shape[axis] != 1: - assert self.sts[i].shape[axis] > amt//2, "pad adds more than double the work" - if (ru := round_up(self.sts[i].shape[axis], amt) - self.sts[i].shape[axis]): - # pad right seems to be faster - self.sts[i] = st.pad(((0,0),) * axis + ((0,ru),) + ((0,0),) * (len(st.shape)-axis-1)) - padded = True - assert padded, "nothing was padded" - return self.simplify_ones() + def has_expanded_axis(s, st): + return any(x > 1 and y == 0 for x, y in zip(s, st)) - def hand_coded_optimizations(self): - # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat - MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4) - if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \ - self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \ - isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \ - self.reduceop.src[0].src[0].op == BufferOps.LOAD and self.reduceop.src[0].src[1].op == BufferOps.LOAD: - buf0 = self.bufs.index(self.reduceop.src[0].src[0].arg) - buf1 = self.bufs.index(self.reduceop.src[0].src[1].arg) - buf0_strides = self.sts[buf0].real_strides() - buf1_strides = self.sts[buf1].real_strides() - def has_expanded_axis(s, st): return any(x > 1 and y == 0 for x,y in zip(s,st)) - if buf0_strides[self.first_reduce] == 1 and not (has_expanded_axis(self.sts[buf0].shape, buf0_strides) and has_expanded_axis(self.sts[buf1].shape, buf1_strides)): - for global_idx in range(self.global_dims): - if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0: - if DEBUG >= 3: print(f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}") - if MV_THREADS_PER_ROW > 1: - self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW)) - if MV_BLOCKSIZE > 1: - self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE)) - if MV_ROWS_PER_THREAD > 1: - self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD)) + if buf0_strides[self.first_reduce] == 1 and not ( + has_expanded_axis(self.sts[buf0].shape, buf0_strides) + and has_expanded_axis(self.sts[buf1].shape, buf1_strides) + ): + for global_idx in range(self.global_dims): + if ( + self.full_shape[self.first_reduce] % MV_THREADS_PER_ROW == 0 + and self.full_shape[global_idx] + % (MV_BLOCKSIZE * MV_ROWS_PER_THREAD) + == 0 + ): + if DEBUG >= 3: + print( + f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}" + ) + if MV_THREADS_PER_ROW > 1: + self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW)) + if MV_BLOCKSIZE > 1: + self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE)) + if MV_ROWS_PER_THREAD > 1: + self.apply_opt( + Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD) + ) + return + + if ( + self.opts.has_local + and self.opts.has_shared + and all(isinstance(s, int) for s in self.sts[0].shape[: self.first_reduce]) + ): + # are we grouping? (requires local shape support) + if ( + not self.float4_axis(0) + and self.first_reduce <= 2 + and self.first_reduce + 1 <= self.shape_len + and prod(self.sts[0].shape[: self.first_reduce]) <= 2048 + ): + # TODO: use 1024 if it's allowed in a smarter way + for sz in ( + ([256, 16]) + if prod(self.sts[0].shape[: self.first_reduce]) <= 32 + else [16] + ): + if all( + st.shape[self.first_reduce] % sz == 0 + or st.shape[self.first_reduce] == 1 + for st in self.sts + ): + self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz)) + break + + # are we upcasting in mid reduce? (only for images) + if ( + self.bufs[0].dtype.name.startswith("image") + and not self.float4_axis(0) + and self.group_for_reduce + and self.first_reduce <= 2 + and prod(self.sts[0].shape) > 1 + ): + axes = self.sts[0].unit_stride_axes() + assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" + if self.sts[0].shape[axes[0]] % 4 == 0: + self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4)) + + # upcast float4 images + for buf_index, buf in enumerate(self.bufs): + unit_stride_axes_mul_4 = [ + i + for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) + if self.sts[buf_index].shape[i] % 4 == 0 + ] + if buf.dtype.__class__ is ImageDType: + # assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}" + if ( + len(unit_stride_axes_mul_4) + and all( + x < (self.shape_len - self.upcasted) + for x in unit_stride_axes_mul_4 + ) + and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes + ): + if unit_stride_axes_mul_4[0] < self.first_reduce: + self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4)) + else: + self.apply_opt( + Opt( + OptOps.UNROLL, + unit_stride_axes_mul_4[0] - self.first_reduce, + 4, + ) + ) + + # no more opt if we are grouping + if self.group_for_reduce: return - if self.opts.has_local and self.opts.has_shared and all(isinstance(s, int) for s in self.sts[0].shape[:self.first_reduce]): - # are we grouping? (requires local shape support) - if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: - # TODO: use 1024 if it's allowed in a smarter way - for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]): - if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts): - self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz)) - break + # **** below this line need to be optional and benchmarked **** - # are we upcasting in mid reduce? (only for images) - if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1: - axes = self.sts[0].unit_stride_axes() - assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}" - if self.sts[0].shape[axes[0]]%4 == 0: - self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4)) + # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx) + # to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below + # expression and run test/test_ops.py with IMAGE=2 + # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack) + # this can be made much smarter + to_upcast: List[int] = [] + # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first) + for axis in range(self.first_reduce): + # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent + # for now skip upcasting here if there is a symbolic axis + if ( + isinstance(self.full_shape[axis], int) + and self.full_shape[axis] <= 7 + and any(st.axis_is_masked(axis) for st in self.sts) + and prod(self.full_shape[self.shape_len - self.upcasted :]) + * prod(self.full_shape[j] for j in to_upcast) + * self.full_shape[axis] + <= 7 * 7 + ): + if DEBUG >= 4: + print(f"upcasting masked axis : {axis}") + to_upcast.append(axis) + for axis in to_upcast[::-1]: + self.apply_opt(Opt(OptOps.UPCAST, axis, 0)) - # upcast float4 images - for buf_index,buf in enumerate(self.bufs): - unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0] - if buf.dtype.__class__ is ImageDType: - #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}" - if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: - if unit_stride_axes_mul_4[0] < self.first_reduce: - self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4)) - else: - self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4)) + # potentially do more upcasts of non reduce axes based on a heuristic + upcasted_axis = set() + while prod(self.sts[0].shape[: self.first_reduce]) >= 1024: + xb_choices = [] + for axis, upcast_amount in itertools.product( + range(self.first_reduce), [3, 4] + ): # consider all the non reduce axes, and a 3 or 4 reduce + # if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already + if ( + axis not in upcasted_axis + and isinstance(self.full_shape[axis], int) + and self.full_shape[axis] % upcast_amount == 0 + and any( + st.views[-1].strides[axis] == 0 + and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) + for buf_index, st in enumerate(self.sts) + ) + ): + xb_choices.append( + ( + sum(st.views[-1].strides[axis] > 0 for st in self.sts), + sum(st.views[-1].strides[axis] for st in self.sts), + axis, + upcast_amount, + ) + ) + if xb_choices: + xb_choices = sorted(xb_choices) + if DEBUG >= 4: + print(f"float4 merging axis : {xb_choices}") + self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3])) + upcasted_axis.add(xb_choices[0][2]) + else: + break - # no more opt if we are grouping - if self.group_for_reduce: return + # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS + if ( + self.first_reduce < (self.shape_len - self.upcasted) + and ( + len(list(self.shape_offsets(self.full_buf_index))) <= 4 + or not any(r for _, _, r in self.upcasted_axis(self.full_buf_index)) + ) + and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted :]) < 64) + ): + if (s := self.full_unupcasted_shape[-1]) <= 32 and isinstance( + s, int + ): # NOTE: cannot loop unroll symbolic axis + self.apply_opt( + Opt( + OptOps.UNROLL, + len(self.full_unupcasted_shape) - 1 - self.first_reduce, + 0, + ) + ) + # if it's small, upcast a second reduce dimension too + if ( + self.first_reduce < (self.shape_len - self.upcasted) + and s <= 3 + and (s2 := self.full_unupcasted_shape[-1]) <= 3 + and isinstance(s2, int) + ): + self.apply_opt( + Opt( + OptOps.UNROLL, + len(self.full_unupcasted_shape) - 1 - self.first_reduce, + 0, + ) + ) + else: + for splits in [4]: + if self.full_unupcasted_shape[-1] % splits == 0: + self.apply_opt( + Opt( + OptOps.UNROLL, + len(self.full_unupcasted_shape) - 1 - self.first_reduce, + splits, + ) + ) + break - # **** below this line need to be optional and benchmarked **** - - # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx) - # to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below - # expression and run test/test_ops.py with IMAGE=2 - # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack) - # this can be made much smarter - to_upcast: List[int] = [] - # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first) - for axis in range(self.first_reduce): - # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent - # for now skip upcasting here if there is a symbolic axis - if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \ - prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7: - if DEBUG >= 4: print(f"upcasting masked axis : {axis}") - to_upcast.append(axis) - for axis in to_upcast[::-1]: - self.apply_opt(Opt(OptOps.UPCAST, axis, 0)) - - # potentially do more upcasts of non reduce axes based on a heuristic - upcasted_axis = set() - while prod(self.sts[0].shape[:self.first_reduce]) >= 1024: - xb_choices = [] - for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce - # if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already - if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): - xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) - if xb_choices: - xb_choices = sorted(xb_choices) - if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}") - self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3])) - upcasted_axis.add(xb_choices[0][2]) - else: - break - - # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS - if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): - if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis - self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0)) - # if it's small, upcast a second reduce dimension too - if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int): - self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0)) - else: + # if nothing at all is upcasted and it's easy to, do an upcast + # TODO: this is breaking the tests for splits in [4]: - if self.full_unupcasted_shape[-1]%splits == 0: - self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits)) - break + if ( + self.upcasted == 0 + and self.full_unupcasted_shape + and self.full_unupcasted_shape[-1] % splits == 0 + ): + self.apply_opt( + Opt(OptOps.UPCAST, len(self.full_unupcasted_shape) - 1, splits) + ) - # if nothing at all is upcasted and it's easy to, do an upcast - # TODO: this is breaking the tests - for splits in [4]: - if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0: - self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits)) + # **** local groups **** - # **** local groups **** - - if self.opts.has_local: - if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce: - self.apply_opt(Opt(OptOps.NOLOCALS)) - else: - # prioritize making expand axes local - local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] - to_local: List[Tuple[int, int]] = [] - for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])): - local_size = prod(sz for _, sz in to_local) - local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) - if local_sz is not None: to_local.append((axis, local_sz)) - deleted_shape = 0 - for axis, local_sz in sorted(to_local[:3]): - axis = axis - deleted_shape - will_delete_shape = local_sz == self.full_shape[axis] - self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz)) - if will_delete_shape: deleted_shape += 1 + if self.opts.has_local: + if ( + getenv("NOLOCALS") + and self.local_dims == 0 + and not self.group_for_reduce + ): + self.apply_opt(Opt(OptOps.NOLOCALS)) + else: + # prioritize making expand axes local + local_axis_ranking = [ + ( + any( + self.sts[buf_index].views[-1].strides[axis] == 0 + for buf_index in range(len(self.sts)) + ), + axis, + ) + for axis in range(len(self.full_shape[: self.first_reduce])) + ] + to_local: List[Tuple[int, int]] = [] + for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])): + local_size = prod(sz for _, sz in to_local) + local_sz: Optional[int] = next( + ( + x + for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) + if self.full_shape[axis] % x == 0 and local_size * x <= 128 + ), + None, + ) + if local_sz is not None: + to_local.append((axis, local_sz)) + deleted_shape = 0 + for axis, local_sz in sorted(to_local[:3]): + axis = axis - deleted_shape + will_delete_shape = local_sz == self.full_shape[axis] + self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz)) + if will_delete_shape: + deleted_shape += 1 diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index bbc08614f..32456dd96 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -1,524 +1,1201 @@ from __future__ import annotations -from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Sequence, Final, Set +from typing import ( + List, + Tuple, + Any, + Optional, + cast, + DefaultDict, + Dict, + Union, + Sequence, + Final, + Set, +) import itertools, math, functools from collections import defaultdict from enum import Enum, auto from dataclasses import dataclass -from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same, to_function_name, flatten -from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps +from tinygrad.helpers import ( + colored, + ImageDType, + DEBUG, + dtypes, + DType, + prod, + PtrDType, + getenv, + all_same, + to_function_name, + flatten, +) +from tinygrad.ops import ( + LazyOp, + UnaryOps, + BinaryOps, + TernaryOps, + ReduceOps, + ConstBuffer, + MemBuffer, + BufferOps, +) from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode +from tinygrad.shape.symbolic import ( + Variable, + NumNode, + VariableOrNum, + Node, + SumNode, + MulNode, + DivNode, + ModNode, + LtNode, + AndNode, +) from tinygrad.codegen.kernel import LocalBuffer, Kernel from tinygrad.lazy import vars_from_ast from tinygrad.features.image import to_image_idx + # bottom ones are asm only class UOps(Enum): - LOOP = auto(); IF = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702 - DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702 - LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702 - ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702 + LOOP = auto() + IF = auto() + END = auto() + SPECIAL = auto() # loops can be global, local, or other # noqa: E702 + DEFINE_GLOBAL = auto() + DEFINE_LOCAL = auto() + DEFINE_ACC = auto() # this defines buffers # noqa: E702 + LOAD = auto() + STORE = auto() + CONST = auto() + BARRIER = auto() + PHI = auto() # noqa: E702 + ALU = auto() + WMMA = auto() + CAST = auto() + GEP = auto() # noqa: E702 + @dataclass(eq=False) class UOp: - uop: UOps - dtype: Optional[DType] - vin: Tuple[UOp, ...] - arg: Any - def __repr__(self): return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}" + uop: UOps + dtype: Optional[DType] + vin: Tuple[UOp, ...] + arg: Any + + def __repr__(self): + return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}" + + +def get_grouped_dims(prefix, start_dim, local_dims, maxdim: int = 0): + local_idxs = loop_local_idxs = [ + Variable(f"{prefix}{start_dim+i}", 0, s - 1) + for i, s in enumerate( + local_dims[0 : maxdim - 1] + (prod(local_dims[maxdim - 1 :]),) + if len(local_dims) > maxdim + else local_dims + ) + ] + if maxdim != 0 and len(local_dims) > maxdim: + dd = local_idxs[maxdim - 1] + nli = [] + for s in local_dims[maxdim - 1 :][::-1]: + nli.append(dd % s) + dd //= s + local_idxs = local_idxs[0 : maxdim - 1] + nli[::-1] + return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)] -def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0): - local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)] - if maxdim != 0 and len(local_dims) > maxdim: - dd = local_idxs[maxdim-1] - nli = [] - for s in local_dims[maxdim-1:][::-1]: - nli.append(dd % s) - dd //= s - local_idxs = local_idxs[0:maxdim-1] + nli[::-1] - return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)] class Linearizer(Kernel): - def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32): - render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx)) - return self.uop(UOps.ALU, dtype, (a, render_b), op) + def uop_alu_idx(self, a: UOp, b, ops, ctx: Linearizer, op, dtype=dtypes.int32): + render_b: UOp = cast( + UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx) + ) + return self.uop(UOps.ALU, dtype, (a, render_b), op) - # NOTE: the consts have to be cached for deduping of downstream uops to work - def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp: return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before) - def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val + # NOTE: the consts have to be cached for deduping of downstream uops to work + def const( + self, b: Union[int, float], dtype=dtypes.int32, insert_before=None + ) -> UOp: + return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before) - def get_reduce_acc(self, op, dtype:DType): - if op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0 - elif op == ReduceOps.MAX: return -math.inf if dtypes.is_float(dtype) else -2**31 if dtypes.is_int(dtype) else False + def cast(self, val: UOp, dtype) -> UOp: + return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val - render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b), - MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL), - DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV), - ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD), - LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT, dtype=dtypes.bool), - SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)), - AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) } + def get_reduce_acc(self, op, dtype: DType): + if op == ReduceOps.SUM: + return 0.0 if dtypes.is_float(dtype) else 0 + elif op == ReduceOps.MAX: + return ( + -math.inf + if dtypes.is_float(dtype) + else -(2**31) + if dtypes.is_int(dtype) + else False + ) - def global_load(self, i:int, idxs:Sequence[Node], acc=None, barrier:Optional[UOp]=None) -> List[UOp]: - buf = self.bufs[i] - const = buf.val if isinstance(buf, ConstBuffer) else acc + render_ops: Any = { + Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], + NumNode: lambda self, ops, ctx: ctx.const(self.b), + MulNode: lambda self, ops, ctx: ctx.uop_alu_idx( + self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL + ), + DivNode: lambda self, ops, ctx: ctx.uop_alu_idx( + self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV + ), + ModNode: lambda self, ops, ctx: ctx.uop_alu_idx( + self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD + ), + LtNode: lambda self, ops, ctx: ctx.uop_alu_idx( + self.a.render(ops, ctx), + self.b, + ops, + ctx, + BinaryOps.CMPLT, + dtype=dtypes.bool, + ), + SumNode: lambda self, ops, ctx: functools.reduce( + lambda a, b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), + self.nodes[1:], + self.nodes[0].render(ops, ctx), + ), + AndNode: lambda self, ops, ctx: functools.reduce( + lambda a, b: ctx.uop_alu_idx( + a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool + ), + self.nodes[1:], + self.nodes[0].render(ops, ctx), + ), + } - def rename_var(v: VariableOrNum, expr: str): return v if isinstance(v, NumNode) else Variable(expr, v.min, v.max) + def global_load( + self, i: int, idxs: Sequence[Node], acc=None, barrier: Optional[UOp] = None + ) -> List[UOp]: + buf = self.bufs[i] + const = buf.val if isinstance(buf, ConstBuffer) else acc - amt, dim = 1, None - upcast_dim = self.get_upcast_dim(i) - if len(upcast_dim) == 1 and len(float4_expand := idxs[upcast_dim[0]].expand()) in [4,2]: - dim, amt = upcast_dim[0], len(float4_expand) + def rename_var(v: VariableOrNum, expr: str): + return v if isinstance(v, NumNode) else Variable(expr, v.min, v.max) - expand_vars = tuple([rename_var(idx.expand_idx(), f"_uidx{j}") for j, idx in enumerate(idxs)]) - fake_idxs = [idx.substitute({idx.expand_idx(): ev}) for idx, ev in zip(idxs, expand_vars)] - if dim is not None: - g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs[:dim] + [float4_expand[0]] + fake_idxs[dim+1:]) - if (g_idx // amt * amt).render() != g_idx.render(): - (g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None - else: - g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs) - localtype = buf.dtype if amt == 1 else buf.dtype.vec(amt) - if isinstance(buf.dtype, ImageDType): localtype = dtypes.float if amt == 1 else dtypes.float.vec(amt) + amt, dim = 1, None + upcast_dim = self.get_upcast_dim(i) + if len(upcast_dim) == 1 and len( + float4_expand := idxs[upcast_dim[0]].expand() + ) in [4, 2]: + dim, amt = upcast_dim[0], len(float4_expand) - e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars) - - ret = [] - invalid_value = 0 if dtypes.is_int(buf.dtype) else 0.0 - for idx, valid, rep_idx in zip(e_idxs, e_valids, Node.iter_idxs(expand_vars)): - this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid) - key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" - if key not in self.load_cache: - if acc is not None: - self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const, cachable=False) - elif this_const is not None: - self.load_cache[key] = self.const(this_const, localtype) - if valid.min == 0 and valid.max == 1: - valid_rendered = valid.render(self.render_ops, self) - self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE) - elif isinstance(buf.dtype, ImageDType): - buf_uop = self.buf_uops[i] - assert buf_uop is not None, f"buffer {i} wasn't UOped" - image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid) - rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), (image_idx[0].render(self.render_ops, self), image_idx[1].render(self.render_ops, self))) - valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, dtypes.float32.vec(4))) if valid.min == 0 else tuple() - self.load_cache[key] = self.uop(UOps.LOAD, dtypes.float32.vec(4), (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ())) - idx_small = idx%4 - res = idx_small.render(self.render_ops, self) - if localtype == localtype.scalar(): - out = self.uop(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max) - for ix in range(idx_small.max, idx_small.min, -1): - rvv = self.uop(UOps.GEP, localtype, (self.load_cache[key],), ix-1) - sel = self.uop(UOps.ALU, res.dtype, (res, self.const(ix)), BinaryOps.CMPLT) - out = self.uop(UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE) - self.load_cache[key] = out + expand_vars = tuple( + [rename_var(idx.expand_idx(), f"_uidx{j}") for j, idx in enumerate(idxs)] + ) + fake_idxs = [ + idx.substitute({idx.expand_idx(): ev}) for idx, ev in zip(idxs, expand_vars) + ] + if dim is not None: + g_idx, g_valid = self.sts[i].expr_idxs( + fake_idxs[:dim] + [float4_expand[0]] + fake_idxs[dim + 1 :] + ) + if (g_idx // amt * amt).render() != g_idx.render(): + (g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None else: - buf_uop = self.buf_uops[i] - assert buf_uop is not None, f"buffer {i} wasn't UOped" - rendered_idx = idx.render(self.render_ops, self) - valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple() - self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ())) - ret.append(self.uop(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key]) - return ret + g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs) + localtype = buf.dtype if amt == 1 else buf.dtype.vec(amt) + if isinstance(buf.dtype, ImageDType): + localtype = dtypes.float if amt == 1 else dtypes.float.vec(amt) - def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]: - buf = self.bufs[i] - buf_uop = self.buf_uops[i] - assert buf_uop is not None, f"buffer {i} wasn't UOped" + e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars) - expanded_nodes = [idx.expand() for idx in idxs] - _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])] - store_offset = dict(zip(_idxs, store)) + ret = [] + invalid_value = 0 if dtypes.is_int(buf.dtype) else 0.0 + for idx, valid, rep_idx in zip(e_idxs, e_valids, Node.iter_idxs(expand_vars)): + this_const, idx, valid = ( + (invalid_value, NumNode(0), NumNode(1)) + if valid.max == 0 + else (const, idx, valid) + ) + key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" + if key not in self.load_cache: + if acc is not None: + self.load_cache[key] = self.uop( + UOps.DEFINE_ACC, localtype, (), this_const, cachable=False + ) + elif this_const is not None: + self.load_cache[key] = self.const(this_const, localtype) + if valid.min == 0 and valid.max == 1: + valid_rendered = valid.render(self.render_ops, self) + self.load_cache[key] = self.uop( + UOps.ALU, + localtype, + ( + valid_rendered, + self.load_cache[key], + self.const(invalid_value, localtype), + ), + TernaryOps.WHERE, + ) + elif isinstance(buf.dtype, ImageDType): + buf_uop = self.buf_uops[i] + assert buf_uop is not None, f"buffer {i} wasn't UOped" + image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid) + rendered_idx = self.uop( + UOps.CAST, + dtypes.int.vec(2), + ( + image_idx[0].render(self.render_ops, self), + image_idx[1].render(self.render_ops, self), + ), + ) + valid_tuple = ( + ( + valid.render(self.render_ops, self), + self.const(invalid_value, dtypes.float32.vec(4)), + ) + if valid.min == 0 + else tuple() + ) + self.load_cache[key] = self.uop( + UOps.LOAD, + dtypes.float32.vec(4), + (buf_uop, rendered_idx) + + valid_tuple + + ((barrier,) if barrier else ()), + ) + idx_small = idx % 4 + res = idx_small.render(self.render_ops, self) + if localtype == localtype.scalar(): + out = self.uop( + UOps.GEP, localtype, (self.load_cache[key],), idx_small.max + ) + for ix in range(idx_small.max, idx_small.min, -1): + rvv = self.uop( + UOps.GEP, localtype, (self.load_cache[key],), ix - 1 + ) + sel = self.uop( + UOps.ALU, + res.dtype, + (res, self.const(ix)), + BinaryOps.CMPLT, + ) + out = self.uop( + UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE + ) + self.load_cache[key] = out + else: + buf_uop = self.buf_uops[i] + assert buf_uop is not None, f"buffer {i} wasn't UOped" + rendered_idx = idx.render(self.render_ops, self) + valid_tuple = ( + ( + valid.render(self.render_ops, self), + self.const(invalid_value, localtype), + ) + if valid.min == 0 + else tuple() + ) + self.load_cache[key] = self.uop( + UOps.LOAD, + localtype, + (buf_uop, rendered_idx) + + valid_tuple + + ((barrier,) if barrier else ()), + ) + ret.append( + self.uop( + UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim] + ) + if dim is not None + else self.load_cache[key] + ) + return ret - # float4 grouping - upcast_dim = self.get_upcast_dim(i) - if len(upcast_dim) == 1 and len(expanded_nodes[upcast_dim[0]]) in [2,4]: - grouped_store_offset = defaultdict(list) - for k in store_offset: - _idx = k[:upcast_dim[0]] + (expanded_nodes[upcast_dim[0]][0],) + k[upcast_dim[0]+1:] - grouped_store_offset[_idx].append(store_offset[k]) - store_offset_new = {} - for k,out_tokens in grouped_store_offset.items(): - amt = len(out_tokens) - idx, valid = self.sts[i].expr_idxs(k) - assert idx.render() == ((idx//amt)*amt).render(), "float4 stores are always aligned" - store_offset_new[k] = self.uop(UOps.CAST, dtypes.float.vec(amt), tuple(out_tokens)) - store_offset = store_offset_new + def global_store(self, i: int, idxs: List[Node], store: List[UOp]) -> List[UOp]: + buf = self.bufs[i] + buf_uop = self.buf_uops[i] + assert buf_uop is not None, f"buffer {i} wasn't UOped" - stores = [] - for idx, var in store_offset.items(): - idx, valid = self.sts[i].expr_idxs(idx) - if isinstance(buf.dtype, ImageDType): - idx, valid = to_image_idx(buf.dtype.shape, idx, valid) - rendered_idx = self.uop(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in idx)) - else: - rendered_idx = idx.render(self.render_ops, self) - if valid.min == 1: stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var))) - else: stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self)))) - return stores + expanded_nodes = [idx.expand() for idx in idxs] + _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])] + store_offset = dict(zip(_idxs, store)) - kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int) - def linearize(self): - # no new opts and we already ran? skip relinearizing - if self.applied_opts == self.applied_opts_cache: return self + # float4 grouping + upcast_dim = self.get_upcast_dim(i) + if len(upcast_dim) == 1 and len(expanded_nodes[upcast_dim[0]]) in [2, 4]: + grouped_store_offset = defaultdict(list) + for k in store_offset: + _idx = ( + k[: upcast_dim[0]] + + (expanded_nodes[upcast_dim[0]][0],) + + k[upcast_dim[0] + 1 :] + ) + grouped_store_offset[_idx].append(store_offset[k]) + store_offset_new = {} + for k, out_tokens in grouped_store_offset.items(): + amt = len(out_tokens) + idx, valid = self.sts[i].expr_idxs(k) + assert ( + idx.render() == ((idx // amt) * amt).render() + ), "float4 stores are always aligned" + store_offset_new[k] = self.uop( + UOps.CAST, dtypes.float.vec(amt), tuple(out_tokens) + ) + store_offset = store_offset_new - # save backups - sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduce[:], self.upcasted + stores = [] + for idx, var in store_offset.items(): + idx, valid = self.sts[i].expr_idxs(idx) + if isinstance(buf.dtype, ImageDType): + idx, valid = to_image_idx(buf.dtype.shape, idx, valid) + rendered_idx = self.uop( + UOps.CAST, + dtypes.int.vec(2), + tuple(x.render(self.render_ops, self) for x in idx), + ) + else: + rendered_idx = idx.render(self.render_ops, self) + if valid.min == 1: + stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var))) + else: + stores.append( + self.uop( + UOps.STORE, + None, + ( + buf_uop, + rendered_idx, + var, + valid.render(self.render_ops, self), + ), + ) + ) + return stores - # global uop cache - self.saved_exprs: Dict[Tuple, UOp] = dict() + kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int) - # limit dims if we need to - if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max) + def linearize(self): + # no new opts and we already ran? skip relinearizing + if self.applied_opts == self.applied_opts_cache: + return self - # uops - self.uops: List[UOp] = [] - self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs) - self.loop_uops: Dict[str, UOp] = {} + # save backups + sts_backup, gfr_backup, upc_backup = ( + self.sts[:], + self.group_for_reduce[:], + self.upcasted, + ) - # add global buffers - for i,buf in enumerate(self.bufs): - if isinstance(buf, MemBuffer): - self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype)) - # add var vals - for var in vars_from_ast(self.ast): - assert var.expr is not None - self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32)) - # define local buffers - for lb in self.local_alias.values(): - self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size())) - # add a local buffer for multistage reduce. # TODO: use local alias - if self.group_for_reduce: - # TODO: the strides of this can be controlled - self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+len(self.group_for_reduce)]) + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) - self.bufs.append(LocalBuffer("temp", self.sts[-1].size())) - self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size()))) + # global uop cache + self.saved_exprs: Dict[Tuple, UOp] = dict() - # kernel name (before late upcast) - self.name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())]) + # limit dims if we need to + if self.opts.global_max and self.opts.local_max: + self.limit_dims_to_max(self.opts.global_max, self.opts.local_max) - # name the function something unique - Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1 - suffix = f"{'n'+str(Linearizer.kernel_cnt[function_name]-1)}" if Linearizer.kernel_cnt[function_name] > 1 else "" - self.name = self.name+colored(suffix, 'BLACK') + # uops + self.uops: List[UOp] = [] + self.buf_uops: List[Optional[UOp]] = [None] * len(self.bufs) + self.loop_uops: Dict[str, UOp] = {} - # define indexes - global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0) - local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+len(self.group_for_reduce)], 3 if self.opts.has_local else 0) - full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]] - upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]] + # add global buffers + for i, buf in enumerate(self.bufs): + if isinstance(buf, MemBuffer): + self.buf_uops[i] = self.uop( + UOps.DEFINE_GLOBAL, + PtrDType(buf.dtype) + if not isinstance(buf.dtype, ImageDType) + else buf.dtype, + (), + (f"data{buf.idx}", buf.dtype), + ) + # add var vals + for var in vars_from_ast(self.ast): + assert var.expr is not None + self.loop_uops[var.expr] = self.uop( + UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32) + ) + # define local buffers + for lb in self.local_alias.values(): + self.buf_uops[self.bufs.index(lb)] = self.uop( + UOps.DEFINE_LOCAL, + PtrDType(dtypes.float32), + (), + (lb.name, self.sts[self.bufs.index(lb)].size()), + ) + # add a local buffer for multistage reduce. # TODO: use local alias + if self.group_for_reduce: + # TODO: the strides of this can be controlled + self.sts.append( + ShapeTracker.from_shape( + tuple( + [1] * self.global_dims + + list( + self.full_shape[ + self.global_dims : self.global_dims + + self.local_dims + + len(self.group_for_reduce) + ] + ) + + [1] + * ( + self.shape_len + - self.upcasted + - len(self.group_for_reduce) + - self.first_reduce + ) + + [x[0] for x in self.upcasted_axis(0)] + ) + ) + ) + self.bufs.append(LocalBuffer("temp", self.sts[-1].size())) + self.buf_uops.append( + self.uop( + UOps.DEFINE_LOCAL, + PtrDType(dtypes.float32), + (), + ("temp", self.sts[-1].size()), + ) + ) - # global and local loops - def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]: - new_loops = {x.expr:self.uop(UOps.LOOP, dtypes.int32, ( - self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self), - self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} - self.loop_uops.update(new_loops) - return tuple(new_loops.values()) + # kernel name (before late upcast) + self.name = ("r_" if self.reduceop else "E_") + colored("_", "BLACK").join( + [colored(str(x), c) for x, c in zip(self.full_shape, self.colors())] + ) - # set global/local size - self.global_size: Optional[List[int]] = None - self.local_size: Optional[List[int]] = None - if self.dont_use_locals: - self.global_size = [x.max+1 for x in loop_global_idxs][::-1] - self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) - elif self.opts.has_local: - self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1] - self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) - self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) - else: - render_loop(loop_global_idxs+loop_local_idxs) + # name the function something unique + Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1 + suffix = ( + f"{'n'+str(Linearizer.kernel_cnt[function_name]-1)}" + if Linearizer.kernel_cnt[function_name] > 1 + else "" + ) + self.name = self.name + colored(suffix, "BLACK") - # parse AST - loaded_buffers = {} - acc: List[UOp] = [] - self.load_cache: Dict[str, UOp] = {} + # define indexes + global_idxs, loop_global_idxs = get_grouped_dims( + "gidx", + 0, + self.full_shape[: self.global_dims], + 3 if self.opts.has_local else 0, + ) + local_idxs, loop_local_idxs = get_grouped_dims( + "lidx", + self.global_dims, + self.full_shape[ + self.global_dims : self.first_reduce + len(self.group_for_reduce) + ], + 3 if self.opts.has_local else 0, + ) + full_upcast_idxs = [ + Variable(None, 0, s - 1) + for s in self.full_shape[self.shape_len - self.upcasted :] + ] + upcast_idxs = [ + Variable(None, 0, s - 1) + for s in self.output_shape[self.shape_len - self.upcasted :] + ] - # reduce op - fake_reduce_idxs: List[Variable] = [] - if self.reduceop is not None: - # define indexes - reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)] - fake_reduce_idxs = [x*0 for x in reduce_idxs] + # global and local loops + def render_loop(xx: List[Variable]) -> Tuple[UOp, ...]: + new_loops = { + x.expr: self.uop( + UOps.LOOP, + dtypes.int32, + ( + self.const(x.min) + if isinstance(x.min, int) + else cast(Node, x.min).render(self.render_ops, self), + self.const(x.max + 1) + if isinstance(x.max, int) + else cast(Node, x.max + 1).render(self.render_ops, self), + ), + cachable=False, + ) + for x in xx + if not isinstance(x, NumNode) and x.expr is not None + } + self.loop_uops.update(new_loops) + return tuple(new_loops.values()) - # define accumulator - acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[0].dtype)) + # set global/local size + self.global_size: Optional[List[int]] = None + self.local_size: Optional[List[int]] = None + if self.dont_use_locals: + self.global_size = [x.max + 1 for x in loop_global_idxs][::-1] + self.loop_uops.update( + { + x.expr: self.uop( + UOps.SPECIAL, + dtypes.int32, + (), + ( + len(loop_global_idxs) - 1 - i, + x.expr.replace("gidx", "idx"), + x.max + 1, + ), + ) + for i, x in enumerate(loop_global_idxs) + } + ) + elif self.opts.has_local: + self.global_size, self.local_size = [x.max + 1 for x in loop_global_idxs][ + ::-1 + ], [x.max + 1 for x in loop_local_idxs][::-1] + self.loop_uops.update( + { + x.expr: self.uop( + UOps.SPECIAL, + dtypes.int32, + (), + (len(loop_global_idxs) - 1 - i, x.expr, x.max + 1), + ) + for i, x in enumerate(loop_global_idxs) + } + ) + self.loop_uops.update( + { + x.expr: self.uop( + UOps.SPECIAL, + dtypes.int32, + (), + (len(loop_local_idxs) - 1 - i, x.expr, x.max + 1), + ) + for i, x in enumerate(loop_local_idxs) + } + ) + else: + render_loop(loop_global_idxs + loop_local_idxs) - if self.tensor_core: - def calc_tc_idxs(local_size: int, aliases: List[List[int]]): - replace_idxs = [] - for alias in aliases: - full_var, full_var_sz = NumNode(0), 1 - if alias[0] != 0: - for i in alias: - next_var = local_idxs[-i] if i > 0 else Variable(None, 0, local_size-1) - full_var += next_var * full_var_sz - full_var_sz *= next_var.max+1 - replace_idxs.append(full_var) - return replace_idxs - replace_acc_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[2], self.tensor_core.thread_local_aliases[2]) - for n in range(len(self.tensor_core.threads)): - local_idxs[self.local_dims-len(self.tensor_core.threads)+n] = replace_acc_idxs[n] # replace locals - for n in range(len(replace_acc_idxs)-len(self.tensor_core.threads)): - upcast_idxs[n] = replace_acc_idxs[len(self.tensor_core.threads)+n] # replace upcasts + # parse AST + loaded_buffers = {} + acc: List[UOp] = [] + self.load_cache: Dict[str, UOp] = {} - # reduce loop - loop_ctx = render_loop(reduce_idxs) + # reduce op + fake_reduce_idxs: List[Variable] = [] + if self.reduceop is not None: + # define indexes + reduce_idxs = [ + Variable(f"ridx{i}", 0, self.full_shape[i] - 1) + for i in range( + self.first_reduce + len(self.group_for_reduce), + self.shape_len - self.upcasted, + ) + ] + fake_reduce_idxs = [x * 0 for x in reduce_idxs] - # barrier for fast GEMM - if self.tensor_core: self.uop(UOps.BARRIER, None, (), cachable=False) + # define accumulator + acc = self.global_load( + 0, + global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs, + self.get_reduce_acc(self.reduceop.op, self.bufs[0].dtype), + ) - # compute local aliases - locals_to_store = [] - for i in self.local_alias: - localbuf_idx = self.bufs.index(self.local_alias[i]) - buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())] - if self.tensor_core: - min_alias_idx = min(self.local_alias.keys()) - replace_input_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[i-min_alias_idx], self.tensor_core.thread_local_aliases[i-min_alias_idx]) - for n in range(len(self.tensor_core.threads)): - buf_idxs[self.first_reduce-len(self.tensor_core.threads)+n] = replace_input_idxs[n] # replace locals - for n in range(len(replace_input_idxs)-len(self.tensor_core.threads)): - buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(self.tensor_core.threads)+n] # replace upcasts - if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs) - ll = self.global_load(i, buf_idxs) - locals_to_store.append((localbuf_idx, buf_idxs, ll)) + if self.tensor_core: - # copy in any global buffers - if self.tensor_core: - wmma_sz = self.tensor_core.thread_local_sizes - # calculate the number of local accumulator reduces and render WMMAs: this is bad... this needs to come from someplace else - nx, ny, nacc = (len(locals_to_store[0][2])//wmma_sz[0]), (len(locals_to_store[1][2])//wmma_sz[1]), (len(acc)//wmma_sz[2]) - acc_reds = math.isqrt((nx*ny)//nacc) - i, bx, by = 0, nx//acc_reds, ny//acc_reds - for y in range(by): - for x in range(bx): - for j in range(acc_reds): - op1, op2, op3 = locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]], locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]], acc[i:i+wmma_sz[2]] - if self.opts.device != "HIP": - ops = tuple(op1+op2+op3) - else: - ops = (self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op1)), - self.uop(UOps.CAST, dtypes.half.vec(16), tuple(op2)), - self.uop(UOps.CAST, dtypes.float.vec(8), tuple(op3))) - ret = self.uop(UOps.WMMA, dtypes.float.vec(2) if wmma_sz[2] == 2 else dtypes.float.vec(8), ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,)) - for z in range(cast(DType, ret.dtype).sz): - acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx) - i += wmma_sz[2] - else: - if locals_to_store: - self.uop(UOps.BARRIER, None, (), cachable=False) - for i, idxs, ll in locals_to_store: self.global_store(i, idxs, ll) - self.uop(UOps.BARRIER, None, (), cachable=False) + def calc_tc_idxs(local_size: int, aliases: List[List[int]]): + replace_idxs = [] + for alias in aliases: + full_var, full_var_sz = NumNode(0), 1 + if alias[0] != 0: + for i in alias: + next_var = ( + local_idxs[-i] + if i > 0 + else Variable(None, 0, local_size - 1) + ) + full_var += next_var * full_var_sz + full_var_sz *= next_var.max + 1 + replace_idxs.append(full_var) + return replace_idxs - # load earlybufs - loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs}) + replace_acc_idxs = calc_tc_idxs( + self.tensor_core.thread_local_sizes[2], + self.tensor_core.thread_local_aliases[2], + ) + for n in range(len(self.tensor_core.threads)): + local_idxs[ + self.local_dims - len(self.tensor_core.threads) + n + ] = replace_acc_idxs[ + n + ] # replace locals + for n in range(len(replace_acc_idxs) - len(self.tensor_core.threads)): + upcast_idxs[n] = replace_acc_idxs[ + len(self.tensor_core.threads) + n + ] # replace upcasts - # run early AST (with reduce) - self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) + # reduce loop + loop_ctx = render_loop(reduce_idxs) - # end the reduce loop - self.load_cache.clear() + # barrier for fast GEMM + if self.tensor_core: + self.uop(UOps.BARRIER, None, (), cachable=False) - # end the local loop, do the local reduce - if self.group_for_reduce: - fake_global_idxs = [x*0 for x in global_idxs] - stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators - barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False) - if self.opts.has_local: - fake_idxs = [NumNode(0)]*len(self.sts[-1].shape) - fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:] - if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self) - barrier = self.uop(UOps.IF, None, (if_cond, barrier), cachable=False) + # compute local aliases + locals_to_store = [] + for i in self.local_alias: + localbuf_idx = self.bufs.index(self.local_alias[i]) + buf_idxs = [ + idx * 0 if s == 0 else idx + for idx, s in zip( + global_idxs + local_idxs + reduce_idxs + full_upcast_idxs, + self.sts[i].real_strides(), + ) + ] + if self.tensor_core: + min_alias_idx = min(self.local_alias.keys()) + replace_input_idxs = calc_tc_idxs( + self.tensor_core.thread_local_sizes[i - min_alias_idx], + self.tensor_core.thread_local_aliases[i - min_alias_idx], + ) + for n in range(len(self.tensor_core.threads)): + buf_idxs[ + self.first_reduce - len(self.tensor_core.threads) + n + ] = replace_input_idxs[ + n + ] # replace locals + for n in range( + len(replace_input_idxs) - len(self.tensor_core.threads) + ): + buf_idxs[ + self.shape_len - self.upcasted + n + ] = replace_input_idxs[ + len(self.tensor_core.threads) + n + ] # replace upcasts + if DEBUG >= 3: + print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs) + ll = self.global_load(i, buf_idxs) + locals_to_store.append((localbuf_idx, buf_idxs, ll)) - # create new late reduce local loops and replace local_idxs that have been used - end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))] - local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:] + # copy in any global buffers + if self.tensor_core: + wmma_sz = self.tensor_core.thread_local_sizes + # calculate the number of local accumulator reduces and render WMMAs: this is bad... this needs to come from someplace else + nx, ny, nacc = ( + (len(locals_to_store[0][2]) // wmma_sz[0]), + (len(locals_to_store[1][2]) // wmma_sz[1]), + (len(acc) // wmma_sz[2]), + ) + acc_reds = math.isqrt((nx * ny) // nacc) + i, bx, by = 0, nx // acc_reds, ny // acc_reds + for y in range(by): + for x in range(bx): + for j in range(acc_reds): + op1, op2, op3 = ( + locals_to_store[0][2][ + (x + (j * bx)) + * wmma_sz[0] : (x + (j * bx) + 1) + * wmma_sz[0] + ], + locals_to_store[1][2][ + (y + (j * by)) + * wmma_sz[1] : (y + (j * by) + 1) + * wmma_sz[1] + ], + acc[i : i + wmma_sz[2]], + ) + if self.opts.device != "HIP": + ops = tuple(op1 + op2 + op3) + else: + ops = ( + self.uop( + UOps.CAST, dtypes.half.vec(16), tuple(op1) + ), + self.uop( + UOps.CAST, dtypes.half.vec(16), tuple(op2) + ), + self.uop( + UOps.CAST, dtypes.float.vec(8), tuple(op3) + ), + ) + ret = self.uop( + UOps.WMMA, + dtypes.float.vec(2) + if wmma_sz[2] == 2 + else dtypes.float.vec(8), + ops, + ( + self.opts.device, + self.tensor_core.dtype_in, + self.tensor_core.dtype_out, + ), + ) + for z in range(cast(DType, ret.dtype).sz): + acc[i + z] = self.uop( + UOps.PHI, + dtypes.float, + ( + op3[z], + self.uop(UOps.GEP, dtypes.float, (ret,), z), + ) + + loop_ctx, + ) + i += wmma_sz[2] + else: + if locals_to_store: + self.uop(UOps.BARRIER, None, (), cachable=False) + for i, idxs, ll in locals_to_store: + self.global_store(i, idxs, ll) + self.uop(UOps.BARRIER, None, (), cachable=False) - # if any group_for_reduce items aren't reduces, upcast them here - for j in self.upcast_in_mid_reduce_axes: - self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j]) - self.upcast() - self.group_for_reduce.pop() - local_idxs = local_idxs[:-1] - end_local_idxs = end_local_idxs[:-1] - # regenerate upcast_idxs - upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]] + # load earlybufs + loaded_buffers.update( + { + b: self.global_load( + self.bufs.index(self.local_alias[i]) + if i in self.local_alias + else i, + global_idxs + local_idxs + reduce_idxs + full_upcast_idxs, + ) + for i, b in enumerate(self.bufs[1:], start=1) + if b in self.earlybufs + } + ) - # NOTE: this structure is the same as the reduce op above + # run early AST (with reduce) + self.ast_parse( + self.reduceop, + acc, + self.acc_offsets(self.full_buf_index), + loaded_buffers, + do_reduce=True, + loop_ctx=loop_ctx, + ) - # define late accumulator - acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype)) + # end the reduce loop + self.load_cache.clear() - # late reduce loop - loop_ctx = render_loop(end_local_idxs) + # end the local loop, do the local reduce + if self.group_for_reduce: + fake_global_idxs = [x * 0 for x in global_idxs] + stores = self.global_store( + -1, + fake_global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs, + acc, + ) # store accumulators + barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False) + if self.opts.has_local: + fake_idxs = [NumNode(0)] * len(self.sts[-1].shape) + fake_idxs[ + self.global_dims + + self.local_dims : self.global_dims + + len(local_idxs) + ] = local_idxs[self.local_dims :] + if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0] < 1).render( + self.render_ops, self + ) + barrier = self.uop( + UOps.IF, None, (if_cond, barrier), cachable=False + ) - # load localbufs - loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier) + # create new late reduce local loops and replace local_idxs that have been used + end_local_idxs = [ + Variable( + f"tidx{i}", + 0, + self.full_shape[i] - 1 + if i >= self.first_reduce + and i not in self.upcast_in_mid_reduce_axes + else 0, + ) + for i in range(0, self.first_reduce + len(self.group_for_reduce)) + ] + local_idxs = ( + local_idxs[: self.local_dims] + + end_local_idxs[self.global_dims + self.local_dims :] + ) - # there's no AST here (and there's no shape for the reduce LazyOp) - self.ast_parse(LazyOp(self.reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) + # if any group_for_reduce items aren't reduces, upcast them here + for j in self.upcast_in_mid_reduce_axes: + self.reshape_and_permute( + None, [i for i in range(self.shape_len) if i != j] + [j] + ) + self.upcast() + self.group_for_reduce.pop() + local_idxs = local_idxs[:-1] + end_local_idxs = end_local_idxs[:-1] + # regenerate upcast_idxs + upcast_idxs = [ + Variable(None, 0, s - 1) + for s in self.output_shape[self.shape_len - self.upcasted :] + ] - # end the late reduce loop - self.load_cache.clear() + # NOTE: this structure is the same as the reduce op above - # load latebufs - loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer}) + # define late accumulator + acc = self.global_load( + -1, + fake_global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs, + self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype), + ) - # run late AST (without the store) - val = self.ast_parse(cast(LazyOp, self.ast.src[0]), acc, None, loaded_buffers) + # late reduce loop + loop_ctx = render_loop(end_local_idxs) - # store - self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val) + # load localbufs + loaded_buffers[self.bufs[-1]] = self.global_load( + -1, + fake_global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs, + barrier=barrier, + ) - # graph helper functions - @functools.lru_cache(None) - def get_recursive_parents(x:UOp) -> Set[UOp]: return set.union(set(x.vin), *[get_recursive_parents(p) for p in x.vin]) + # there's no AST here (and there's no shape for the reduce LazyOp) + self.ast_parse( + LazyOp( + self.reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),) + ), + acc, + self.acc_offsets(-1), + loaded_buffers, + do_reduce=True, + loop_ctx=loop_ctx, + ) - def get_recursive_children(x:UOp) -> Set[UOp]: - deps = set([x]) - ssize = 0 - while ssize != len(deps): - ssize = len(deps) + # end the late reduce loop + self.load_cache.clear() + + # load latebufs + loaded_buffers.update( + { + b: self.global_load( + i, global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs + ) + for i, b in enumerate(self.bufs) + if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer + } + ) + + # run late AST (without the store) + val = self.ast_parse(cast(LazyOp, self.ast.src[0]), acc, None, loaded_buffers) + + # store + self.global_store( + 0, global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs, val + ) + + # graph helper functions + @functools.lru_cache(None) + def get_recursive_parents(x: UOp) -> Set[UOp]: + return set.union(set(x.vin), *[get_recursive_parents(p) for p in x.vin]) + + def get_recursive_children(x: UOp) -> Set[UOp]: + deps = set([x]) + ssize = 0 + while ssize != len(deps): + ssize = len(deps) + for u in self.uops: + if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])): + deps.add(u) + return deps + + def replace_op(old: UOp, new: UOp): + for u in self.uops: + u.vin = tuple(new if x is old else x for x in u.vin) + self.uops.remove(old) + + # fix loop scope, push CONST and ALU upward out of loop if it does not depend on the loop + loop_stack: List[List[UOp]] = [[]] for u in self.uops: - if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])): - deps.add(u) - return deps + if not loop_stack[-1]: + loop_stack[-1].append(u) + elif u.uop == UOps.LOOP: + loop_stack.append([u]) + elif u.uop not in [UOps.CONST, UOps.ALU, UOps.CAST]: + loop_stack[-1].append(u) + else: + parents = get_recursive_parents(u) + for i in reversed(range(len(loop_stack))): + # check backwards and put the uop in the first encounter with some dependency + if any(x in parents for x in loop_stack[i]) or i == 0: + loop_stack[i].append(u) + break + self.uops = flatten(loop_stack) - def replace_op(old:UOp, new:UOp): - for u in self.uops: - u.vin = tuple(new if x is old else x for x in u.vin) - self.uops.remove(old) + # uops optimization + changed_something = True + while changed_something: + changed_something = False + for u in self.uops: + if u.uop == UOps.PHI and len(u.vin) == 3: + # if the parents of the PHI node don't have the LOOP in their parents, it can be folded + # TODO: ADD becomes a MUL, MAX can just become nothing + if ( + all( + x.uop != UOps.LOOP + for x in get_recursive_parents( + UOp(u.uop, u.dtype, u.vin[0:2], u.arg) + ) + ) + and u.vin[1].arg == BinaryOps.ADD + ): + if DEBUG >= 4: + print(f"removing PHI node {u}") + del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)] + # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype + loop_len = self.uop( + UOps.ALU, + u.vin[2].vin[1].dtype, + (u.vin[2].vin[1], u.vin[2].vin[0]), + BinaryOps.SUB, + insert_before=self.uops.index(u), + ) + if loop_len.dtype != u.dtype: + loop_len = self.uop( + UOps.CAST, + u.dtype, + (loop_len,), + insert_before=self.uops.index(u), + ) + replace_op( + u, + self.uop( + UOps.ALU, + u.dtype, + ( + u.vin[1], + loop_len, + ), + BinaryOps.MUL, + insert_before=self.uops.index(u), + ), + ) + changed_something = True - # fix loop scope, push CONST and ALU upward out of loop if it does not depend on the loop - loop_stack: List[List[UOp]] = [[]] - for u in self.uops: - if not loop_stack[-1]: loop_stack[-1].append(u) - elif u.uop == UOps.LOOP: loop_stack.append([u]) - elif u.uop not in [UOps.CONST, UOps.ALU, UOps.CAST]: loop_stack[-1].append(u) - else: - parents = get_recursive_parents(u) - for i in reversed(range(len(loop_stack))): - # check backwards and put the uop in the first encounter with some dependency - if any(x in parents for x in loop_stack[i]) or i == 0: - loop_stack[i].append(u) - break - self.uops = flatten(loop_stack) + # (recursively) remove childless uops + # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that + UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL} + while 1: + has_child: Set[UOp] = set() + for ru in self.uops: + for vu in ru.vin: + has_child.add(vu) + nu: List[UOp] = [ + x for x in self.uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS + ] + if len(nu) == len(self.uops): + break + if DEBUG >= 4: + print(f"reduced UOp count from {len(self.uops)} to {len(nu)}") + self.uops = nu + del nu - # uops optimization - changed_something = True - while changed_something: - changed_something = False - for u in self.uops: - if u.uop == UOps.PHI and len(u.vin) == 3: - # if the parents of the PHI node don't have the LOOP in their parents, it can be folded - # TODO: ADD becomes a MUL, MAX can just become nothing - if all(x.uop != UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) and u.vin[1].arg == BinaryOps.ADD: - if DEBUG >= 4: print(f"removing PHI node {u}") - del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)] - # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype - loop_len = self.uop(UOps.ALU, u.vin[2].vin[1].dtype, (u.vin[2].vin[1], u.vin[2].vin[0]), BinaryOps.SUB, insert_before=self.uops.index(u)) - if loop_len.dtype != u.dtype: loop_len = self.uop(UOps.CAST, u.dtype, (loop_len,), insert_before=self.uops.index(u)) - replace_op(u, self.uop(UOps.ALU, u.dtype, (u.vin[1], loop_len,), BinaryOps.MUL, insert_before=self.uops.index(u))) - changed_something = True + # add UOps.END + for u in self.uops: + if u.uop == UOps.LOOP: + # add END of loops after the last thing that (recursively) depends on them + self.uop( + UOps.END, + None, + (u,), + cachable=False, + insert_before=self.uops.index( + sorted(list(get_recursive_children(u)), key=self.uops.index)[-1] + ) + + 1, + ) + elif u.uop == UOps.IF: + # END any if statements at the end of the uops + self.uop(UOps.END, None, (u,), cachable=False) - # (recursively) remove childless uops - # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that - UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL} - while 1: - has_child: Set[UOp] = set() - for ru in self.uops: - for vu in ru.vin: - has_child.add(vu) - nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS] - if len(nu) == len(self.uops): break - if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}") - self.uops = nu - del nu + # maybe graph the uops + if DEBUG >= 5: + for u in self.uops: + print( + f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}" + ) + if getenv("GRAPHUOPS"): + from tinygrad.graph import graph_uops - # add UOps.END - for u in self.uops: - if u.uop == UOps.LOOP: - # add END of loops after the last thing that (recursively) depends on them - self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(u)), key=self.uops.index)[-1])+1) - elif u.uop == UOps.IF: - # END any if statements at the end of the uops - self.uop(UOps.END, None, (u,), cachable=False) + graph_uops(self.uops) - # maybe graph the uops - if DEBUG >= 5: - for u in self.uops: - print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") - if getenv("GRAPHUOPS"): - from tinygrad.graph import graph_uops - graph_uops(self.uops) + # restore backups + self.sts, self.group_for_reduce, self.upcasted = ( + sts_backup, + gfr_backup, + upc_backup, + ) - # restore backups - self.sts, self.group_for_reduce, self.upcasted = sts_backup, gfr_backup, upc_backup + # set cache and return + self.applied_opts_cache = self.applied_opts[:] + return self - # set cache and return - self.applied_opts_cache = self.applied_opts[:] - return self + def uop( + self, + uop: UOps, + dtype: Optional[DType] = None, + vin: Tuple[UOp, ...] = tuple(), + arg: Any = None, + cachable=True, + insert_before=None, + simplify=True, + ) -> UOp: + key = (uop, dtype, vin, arg) + if uop == UOps.PHI and vin[1].dtype != dtype: + vin = (vin[0], self.cast(vin[1], dtype)) + vin[1:] + if uop == UOps.ALU: # upcast vins to the same dtype + upcast_dtype = ( + dtypes.float + if arg == TernaryOps.MULACC + else max(cast(DType, x.dtype) for x in vin) + ) # MULACC is only supported in float + if arg == TernaryOps.WHERE: + vin = (vin[0],) + tuple( + self.cast(x, upcast_dtype) for x in vin[1:] + ) # the first arg is always bool + else: + vin = tuple(self.cast(x, upcast_dtype) for x in vin) + dtype = dtype or upcast_dtype # some ops like BinaryOps.CMPLT return bool + if simplify: + if uop == UOps.PHI and len(vin) == 2: + return vin[1] # a phi without loops is a noop + if uop == UOps.GEP and vin[0].uop == UOps.CONST: + return self.const(vin[0].arg, dtype, insert_before) + if ( + uop == UOps.CAST + and all(x.uop == UOps.CONST for x in vin) + and all_same([x.arg for x in vin]) + ): + return self.const(vin[0].arg, dtype, insert_before) + if uop == UOps.ALU: + # rewrites. NOTE: the rewritten NEG op is still around... + if ( + arg == BinaryOps.ADD + and vin[1].uop == UOps.ALU + and vin[1].arg == UnaryOps.NEG + ): + return self.uop( + UOps.ALU, + dtype, + (vin[0], vin[1].vin[0]), + BinaryOps.SUB, + cachable=cachable, + insert_before=insert_before, + ) + # constant folding + if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: + return self.const(-vin[0].arg, dtype, insert_before) + if arg == TernaryOps.WHERE and vin[1] == vin[2]: + return vin[ + 1 + ] # a conditional with the same results either way is a noop + # zero folding + for x in [0, 1]: + if ( + arg == BinaryOps.ADD + and vin[x].uop == UOps.CONST + and vin[x].arg == 0.0 + ): + return vin[1 - x] + if ( + arg == BinaryOps.MUL + and vin[x].uop == UOps.CONST + and vin[x].arg == 1.0 + ): + return vin[1 - x] + if ( + arg == BinaryOps.MUL + and vin[x].uop == UOps.CONST + and vin[x].arg == 0.0 + ): + return vin[x] + if ( + arg == BinaryOps.SUB + and vin[1].uop == UOps.CONST + and vin[1].arg == 0.0 + ): + return vin[0] + if ( + arg == BinaryOps.DIV + and vin[1].uop == UOps.CONST + and vin[1].arg == 1.0 + ): + return vin[0] - def uop(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp: - key = (uop, dtype, vin, arg) - if uop == UOps.PHI and vin[1].dtype != dtype: vin = (vin[0], self.cast(vin[1], dtype)) + vin[1:] - if uop == UOps.ALU: # upcast vins to the same dtype - upcast_dtype = dtypes.float if arg == TernaryOps.MULACC else max(cast(DType, x.dtype) for x in vin) # MULACC is only supported in float - if arg == TernaryOps.WHERE: vin = (vin[0],) + tuple(self.cast(x, upcast_dtype) for x in vin[1:]) # the first arg is always bool - else: vin = tuple(self.cast(x, upcast_dtype) for x in vin) - dtype = dtype or upcast_dtype # some ops like BinaryOps.CMPLT return bool - if simplify: - if uop == UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop - if uop == UOps.GEP and vin[0].uop == UOps.CONST: return self.const(vin[0].arg, dtype, insert_before) - if uop == UOps.CAST and all(x.uop == UOps.CONST for x in vin) and all_same([x.arg for x in vin]): return self.const(vin[0].arg, dtype, insert_before) - if uop == UOps.ALU: - # rewrites. NOTE: the rewritten NEG op is still around... - if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable, insert_before=insert_before) - # constant folding - if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.const(-vin[0].arg, dtype, insert_before) - if arg == TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop - # zero folding - for x in [0,1]: - if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x] - if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x] - if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x] - if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0] - if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0] + # When insert_before is set, need to check if the cached expr is valid with the given insert place. + if ( + cachable + and (expr := self.saved_exprs.get(key, None)) is not None + and (insert_before is None or self.uops.index(expr) <= insert_before) + ): + return expr + ret = UOp(uop, dtype, vin, arg) + if insert_before is not None: + self.uops.insert(insert_before, ret) + else: + self.uops.append(ret) + if cachable: + self.saved_exprs[key] = ret + return ret - # When insert_before is set, need to check if the cached expr is valid with the given insert place. - if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and (insert_before is None or self.uops.index(expr) <= insert_before): return expr - ret = UOp(uop, dtype, vin, arg) - if insert_before is not None: - self.uops.insert(insert_before, ret) - else: - self.uops.append(ret) - if cachable: self.saved_exprs[key] = ret - return ret - - def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple()) -> List[UOp]: - if x.op in BufferOps: return loaded_buffers[x.arg] - if x.op == UnaryOps.CAST: return [self.uop(UOps.CAST, x.arg[0], (u,), x.arg) if not isinstance(x.arg[0], ImageDType) else u for u in self.ast_parse(cast(LazyOp, x.src[0]), acc, offs, loaded_buffers)] - if x.op in ReduceOps and not do_reduce: - assert offs is None, "not available if we aren't doing reduce" - return acc - # MULACC fusion. TODO: this is copied from Interpreted - if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL: - x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg) - if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL: - x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg) - values = [self.ast_parse(cast(LazyOp, v), acc, offs, loaded_buffers, loop_ctx=loop_ctx) for v in x.src] - ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC} - if x.op in ops: - ret: List[UOp] = [] - input_acc = acc[:] - for val, off in zip(zip(*values), cast(List[int], offs)): - acc[off] = self.uop(UOps.ALU, vin=val+(acc[off],), arg=ops[x.op]) - ret.append(acc[off]) - for off in range(len(acc)): - if input_acc[off] != acc[off]: - acc[off] = self.uop(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx)) - else: - ret = [self.uop(UOps.ALU, dtype=dtypes.bool if x.op == BinaryOps.CMPLT else None, vin=val, arg=x.op) for val in zip(*values)] - return ret + def ast_parse( + self, + x: LazyOp, + acc: List[UOp], + offs: Optional[List[int]], + loaded_buffers: Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], + do_reduce=False, + loop_ctx=tuple(), + ) -> List[UOp]: + if x.op in BufferOps: + return loaded_buffers[x.arg] + if x.op == UnaryOps.CAST: + return [ + self.uop(UOps.CAST, x.arg[0], (u,), x.arg) + if not isinstance(x.arg[0], ImageDType) + else u + for u in self.ast_parse( + cast(LazyOp, x.src[0]), acc, offs, loaded_buffers + ) + ] + if x.op in ReduceOps and not do_reduce: + assert offs is None, "not available if we aren't doing reduce" + return acc + # MULACC fusion. TODO: this is copied from Interpreted + if ( + x.op == ReduceOps.SUM + and x.src[0].__class__ is LazyOp + and x.src[0].op == BinaryOps.MUL + ): + x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg) + if ( + x.op == ReduceOps.SUM + and x.src[0].__class__ is LazyOp + and x.src[0].op == UnaryOps.CAST + and x.src[0].src[0].__class__ is LazyOp + and x.src[0].src[0].op == BinaryOps.MUL + ): + x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg) + values = [ + self.ast_parse( + cast(LazyOp, v), acc, offs, loaded_buffers, loop_ctx=loop_ctx + ) + for v in x.src + ] + ops = { + ReduceOps.SUM: BinaryOps.ADD, + ReduceOps.MAX: BinaryOps.MAX, + TernaryOps.MULACC: TernaryOps.MULACC, + } + if x.op in ops: + ret: List[UOp] = [] + input_acc = acc[:] + for val, off in zip(zip(*values), cast(List[int], offs)): + acc[off] = self.uop(UOps.ALU, vin=val + (acc[off],), arg=ops[x.op]) + ret.append(acc[off]) + for off in range(len(acc)): + if input_acc[off] != acc[off]: + acc[off] = self.uop( + UOps.PHI, + input_acc[off].dtype, + (input_acc[off], acc[off]) + tuple(loop_ctx), + ) + else: + ret = [ + self.uop( + UOps.ALU, + dtype=dtypes.bool if x.op == BinaryOps.CMPLT else None, + vin=val, + arg=x.op, + ) + for val in zip(*values) + ] + return ret diff --git a/tinygrad/device.py b/tinygrad/device.py index 8494eb28d..5d2db5e5d 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -3,304 +3,643 @@ import numpy as np from collections import defaultdict from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable import importlib, inspect, functools, pathlib, time, re, ctypes -from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name, DType, from_mv, dtypes, flat_mv, ImageDType, round_up +from tinygrad.helpers import ( + ansilen, + DEBUG, + getenv, + GlobalCounters, + colored, + BEAM, + NOOPT, + all_int, + to_function_name, + DType, + from_mv, + dtypes, + flat_mv, + ImageDType, + round_up, +) from tinygrad.shape.symbolic import Variable, sym_infer, sint -from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op +from tinygrad.ops import ( + LazyOp, + TernaryOps, + get_lazyop_info, + ReduceOps, + BufferOps, + BinaryOps, + UnaryOps, + Op, +) if TYPE_CHECKING: - from tinygrad.codegen.linearizer import Linearizer - from tinygrad.codegen.kernel import LinearizerOptions + from tinygrad.codegen.linearizer import Linearizer + from tinygrad.codegen.kernel import LinearizerOptions # **************** Device **************** + class _Device: - def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] - def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT - def __getitem__(self, ix:str) -> Union[Interpreted, Compiled]: return self.__get_canonicalized_item(self.canonicalize(ix)) - @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none - def __get_canonicalized_item(self, ix:str) -> Union[Interpreted, Compiled]: - x = ix.split(":")[0].upper() - ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._buffers][0] - if isinstance(ret, type): ret = ret(ix) - return ret - @functools.cached_property - def DEFAULT(self) -> str: - device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) # type: ignore - if device_from_env: return device_from_env - for device in ["METAL", "CUDA", "GPU"]: - try: - if self[device]: return device - except Exception: pass - return "CPU" + def __init__(self) -> None: + self._buffers: List[str] = [ + x.stem[len("ops_") :].upper() + for x in (pathlib.Path(__file__).parent / "runtime").iterdir() + if x.stem.startswith("ops_") + ] + + def canonicalize(self, device: Optional[str]) -> str: + return ( + ( + device.split(":", 1)[0].upper() + + ((":" + device.split(":", 1)[1]) if ":" in device else "") + ).replace(":0", "") + if device is not None + else self.DEFAULT + ) + + def __getitem__(self, ix: str) -> Union[Interpreted, Compiled]: + return self.__get_canonicalized_item(self.canonicalize(ix)) + + @functools.lru_cache( + maxsize=None + ) # this class is a singleton, pylint: disable=method-cache-max-size-none + def __get_canonicalized_item(self, ix: str) -> Union[Interpreted, Compiled]: + x = ix.split(":")[0].upper() + ret = [ + cls + for cname, cls in inspect.getmembers( + importlib.import_module(f"tinygrad.runtime.ops_{x.lower()}") + ) + if (cname.lower() == x.lower() + "device") and x in self._buffers + ][0] + if isinstance(ret, type): + ret = ret(ix) + return ret + + @functools.cached_property + def DEFAULT(self) -> str: + device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) # type: ignore + if device_from_env: + return device_from_env + for device in ["METAL", "CUDA", "GPU"]: + try: + if self[device]: + return device + except Exception: + pass + return "CPU" + + Device = _Device() # **************** base Runner + helpers **************** -class JITRunner: - def __init__(self): - self.op_estimate, self.mem_estimate = 0, 0 - def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]: - var_vals = var_vals if var_vals is not None else {} - from tinygrad.jit import CacheCollector - et = self(rawbufs, var_vals) - CacheCollector.add(self, rawbufs, var_vals) - return et - def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]: - raise NotImplementedError("override this") -def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra: Optional[Dict]=None): - if var_vals is None: var_vals = {} - op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals) - GlobalCounters.kernel_count += num_kernels - GlobalCounters.global_ops += op_estimate - GlobalCounters.global_mem += mem_estimate - if et is not None: GlobalCounters.time_sum_s += et - if DEBUG >= 2: - print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + - (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) +class JITRunner: + def __init__(self): + self.op_estimate, self.mem_estimate = 0, 0 + + def exec( + self, rawbufs: List[Buffer], var_vals: Optional[Dict[Variable, int]] = None + ) -> Optional[float]: + var_vals = var_vals if var_vals is not None else {} + from tinygrad.jit import CacheCollector + + et = self(rawbufs, var_vals) + CacheCollector.add(self, rawbufs, var_vals) + return et + + def __call__( + self, + rawbufs: List[Buffer], + var_vals: Dict[Variable, int], + wait=False, + jit=False, + ) -> Optional[float]: + raise NotImplementedError("override this") + + +def update_stats( + name: str, + op_estimate: sint, + mem_estimate: sint, + var_vals: Optional[Dict[Variable, int]], + et: Optional[float], + buf_count, + jit=False, + num_kernels=1, + lra: Optional[Dict] = None, +): + if var_vals is None: + var_vals = {} + op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer( + mem_estimate, var_vals + ) + GlobalCounters.kernel_count += num_kernels + GlobalCounters.global_ops += op_estimate + GlobalCounters.global_mem += mem_estimate + if et is not None: + GlobalCounters.time_sum_s += et + if DEBUG >= 2: + print( + f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + + ( + str() + if et is None + else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)" + ) + ) + # **************** Buffer / Allocator **************** + class Buffer: - def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None): - assert isinstance(dtype, DType) - self.device, self.size, self.dtype = device, size, dtype - self.allocator = Device[self.device].allocator - # TODO: image hack shouldn't be here. where should it be? - if isinstance(dtype, ImageDType) and hasattr(self.allocator, "_cast_image"): - assert opaque is None - row_pitch_items = round_up(dtype.shape[1], 256) * 4 - self.size = row_pitch_items * dtype.shape[0] # adjust the size to include the image padding - self._real_buf = self.allocator.alloc(self.size * dtype.itemsize) - self._buf = self.allocator._cast_image(self._real_buf, dtype, row_pitch_items * dtype.itemsize) - else: - self._buf = opaque if opaque is not None else self.allocator.alloc(size * dtype.itemsize) - # TODO: mem_used for all devices - if self.device == Device.DEFAULT: GlobalCounters.mem_used += self.size * self.dtype.itemsize - def __del__(self): - if self.device == Device.DEFAULT: GlobalCounters.mem_used -= self.size * self.dtype.itemsize - if isinstance(self.dtype, ImageDType): - self.allocator._free(self._buf) - self.allocator.free(self._real_buf, self.size * self.dtype.itemsize) - else: - self.allocator.free(self._buf, self.size * self.dtype.itemsize) - def __repr__(self): return f"" - def copyin(self, mv:memoryview): - mv = flat_mv(mv) - assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}" - self.allocator.copyin(self._buf, mv) - return self - @staticmethod - def fromCPU(device:str, x:np.ndarray): return Buffer(device, x.size, dtypes.from_np(x.dtype)).copyin(x.data) - def toCPU(self) -> np.ndarray: - # zero copy with as_buffer - if hasattr(self.allocator, 'as_buffer'): return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf})) # type: ignore - ret = np.empty(self.size, self.dtype.np) - if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf) - return ret + def __init__(self, device: str, size: int, dtype: DType, opaque: Any = None): + assert isinstance(dtype, DType) + self.device, self.size, self.dtype = device, size, dtype + self.allocator = Device[self.device].allocator + # TODO: image hack shouldn't be here. where should it be? + if isinstance(dtype, ImageDType) and hasattr(self.allocator, "_cast_image"): + assert opaque is None + row_pitch_items = round_up(dtype.shape[1], 256) * 4 + self.size = ( + row_pitch_items * dtype.shape[0] + ) # adjust the size to include the image padding + self._real_buf = self.allocator.alloc(self.size * dtype.itemsize) + self._buf = self.allocator._cast_image( + self._real_buf, dtype, row_pitch_items * dtype.itemsize + ) + else: + self._buf = ( + opaque + if opaque is not None + else self.allocator.alloc(size * dtype.itemsize) + ) + # TODO: mem_used for all devices + if self.device == Device.DEFAULT: + GlobalCounters.mem_used += self.size * self.dtype.itemsize + + def __del__(self): + if self.device == Device.DEFAULT: + GlobalCounters.mem_used -= self.size * self.dtype.itemsize + if isinstance(self.dtype, ImageDType): + self.allocator._free(self._buf) + self.allocator.free(self._real_buf, self.size * self.dtype.itemsize) + else: + self.allocator.free(self._buf, self.size * self.dtype.itemsize) + + def __repr__(self): + return f"" + + def copyin(self, mv: memoryview): + mv = flat_mv(mv) + assert ( + len(mv) == self.size * self.dtype.itemsize + ), f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}" + self.allocator.copyin(self._buf, mv) + return self + + @staticmethod + def fromCPU(device: str, x: np.ndarray): + return Buffer(device, x.size, dtypes.from_np(x.dtype)).copyin(x.data) + + def toCPU(self) -> np.ndarray: + # zero copy with as_buffer + if hasattr(self.allocator, "as_buffer"): + return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf})) # type: ignore + ret = np.empty(self.size, self.dtype.np) + if self.size > 0: + self.allocator.copyout(flat_mv(ret.data), self._buf) + return ret + class _BufferCopy(JITRunner): - # TODO: make wait work - def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): - dest, src = rawbufs - assert dest.size == src.size and dest.dtype == src.dtype, "buffer copy size/dtype mismatch" - if DEBUG >= 2: print(f"*** copy {dest.device} <- {src.device} size {dest.size:<16d} dtype {dest.dtype}") - if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator): - # fast path, used on HIP between GPUs - # NOTE: it's important we use the dest device here to ensure the transfer is ready - dest.allocator.transfer(dest._buf, src._buf, dest.size*dest.dtype.itemsize) - return - if getenv("FROM_BUFFER") and hasattr(dest.allocator, 'from_buffer') and hasattr(dest.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'): - # fast path, used on Metal in OS X Sonoma - # NOTE: this is *only* faster if the pages from disk are already loaded into memory - fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf)) - if fb: - dest.allocator.transfer(dest._buf, fb, dest.size*dest.dtype.itemsize) - return - if hasattr(dest.allocator, 'as_buffer'): - # fast(ish) path, uses readinto in diskbuffers - src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf) - elif hasattr(src.allocator, 'as_buffer'): - dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf)) - else: - # slow path, allocates a CPU buffer - dest.copyin(src.toCPU().data) + # TODO: make wait work + def __call__( + self, + rawbufs: List[Buffer], + var_vals: Dict[Variable, int], + wait=False, + jit=False, + ): + dest, src = rawbufs + assert ( + dest.size == src.size and dest.dtype == src.dtype + ), "buffer copy size/dtype mismatch" + if DEBUG >= 2: + print( + f"*** copy {dest.device} <- {src.device} size {dest.size:<16d} dtype {dest.dtype}" + ) + if hasattr(dest.allocator, "transfer") and type(dest.allocator) is type( + src.allocator + ): + # fast path, used on HIP between GPUs + # NOTE: it's important we use the dest device here to ensure the transfer is ready + dest.allocator.transfer( + dest._buf, src._buf, dest.size * dest.dtype.itemsize + ) + return + if ( + getenv("FROM_BUFFER") + and hasattr(dest.allocator, "from_buffer") + and hasattr(dest.allocator, "transfer") + and hasattr(src.allocator, "as_buffer") + ): + # fast path, used on Metal in OS X Sonoma + # NOTE: this is *only* faster if the pages from disk are already loaded into memory + fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf)) + if fb: + dest.allocator.transfer(dest._buf, fb, dest.size * dest.dtype.itemsize) + return + if hasattr(dest.allocator, "as_buffer"): + # fast(ish) path, uses readinto in diskbuffers + src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf) + elif hasattr(src.allocator, "as_buffer"): + dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf)) + else: + # slow path, allocates a CPU buffer + dest.copyin(src.toCPU().data) + + BufferCopy = _BufferCopy() + # TODO: size, dest, src are the same type. can we enforce this? class Allocator: - def alloc(self, size:int): - assert size > 0, f"alloc size must be positve, getting {size}" - return self._alloc(size) - def _alloc(self, size:int): raise NotImplementedError("need alloc") - def free(self, opaque, size:int): self._free(opaque) # if you are returning a Python object, you don't need a free - def _free(self, opaque): pass - def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin") - def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout") + def alloc(self, size: int): + assert size > 0, f"alloc size must be positve, getting {size}" + return self._alloc(size) + + def _alloc(self, size: int): + raise NotImplementedError("need alloc") + + def free(self, opaque, size: int): + self._free( + opaque + ) # if you are returning a Python object, you don't need a free + + def _free(self, opaque): + pass + + def copyin(self, dest, src: memoryview): + raise NotImplementedError("need copyin") + + def copyout(self, dest: memoryview, src): + raise NotImplementedError("need copyout") + class LRUAllocator(Allocator): # pylint: disable=abstract-method - def __init__(self): self.cache: Dict[int, Any] = defaultdict(list) - def alloc(self, size:int): - if len(c := self.cache[size]): return c.pop() - try: - return super().alloc(size) - except MemoryError: - self.free_cache() - return super().alloc(size) - def free_cache(self): - for opaques in self.cache.values(): - for opaque in opaques: self._free(opaque) - opaques.clear() - def free(self, opaque:Any, size:int): - if getenv("LRU", 1): self.cache[size].append(opaque) - else: self._free(opaque) + def __init__(self): + self.cache: Dict[int, Any] = defaultdict(list) + + def alloc(self, size: int): + if len(c := self.cache[size]): + return c.pop() + try: + return super().alloc(size) + except MemoryError: + self.free_cache() + return super().alloc(size) + + def free_cache(self): + for opaques in self.cache.values(): + for opaque in opaques: + self._free(opaque) + opaques.clear() + + def free(self, opaque: Any, size: int): + if getenv("LRU", 1): + self.cache[size].append(opaque) + else: + self._free(opaque) + class _MallocAllocator(LRUAllocator): - def _alloc(self, size:int): return (ctypes.c_uint8 * size)() - def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src)) - def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src)) - def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest)) + def _alloc(self, size: int): + return (ctypes.c_uint8 * size)() + + def as_buffer(self, src) -> memoryview: + return flat_mv(memoryview(src)) + + def copyin(self, dest, src: memoryview): + ctypes.memmove(dest, from_mv(src), len(src)) + + def copyout(self, dest: memoryview, src): + ctypes.memmove(from_mv(dest), src, len(dest)) + + MallocAllocator = _MallocAllocator() # **************** for Interpreted Devices **************** -class InterpretedASTRunner(JITRunner): - def __init__(self, ast:LazyOp, fxn:Callable): - super().__init__() - self.fxn = fxn - info = get_lazyop_info(ast) - self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate - def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float: - st = time.perf_counter() - rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs], var_vals) - et = time.perf_counter() - st - update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit) - return et +class InterpretedASTRunner(JITRunner): + def __init__(self, ast: LazyOp, fxn: Callable): + super().__init__() + self.fxn = fxn + info = get_lazyop_info(ast) + self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate + + def __call__( + self, + rawbufs: List[Buffer], + var_vals: Dict[Variable, int], + wait=False, + jit=False, + ) -> float: + st = time.perf_counter() + rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs], var_vals) + et = time.perf_counter() - st + update_stats( + f"", + self.op_estimate, + self.mem_estimate, + var_vals, + et, + len(rawbufs), + jit, + ) + return et + class Interpreted: - def __init__(self, allocator: Allocator, fxn_for_op:Dict[Op, Callable]): - self.allocator, self.fxn_for_op = allocator, fxn_for_op - self.synchronize, self.codegen, self.graph = lambda: None, None, None + def __init__(self, allocator: Allocator, fxn_for_op: Dict[Op, Callable]): + self.allocator, self.fxn_for_op = allocator, fxn_for_op + self.synchronize, self.codegen, self.graph = lambda: None, None, None - @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none - def get_runner(self, ast:LazyOp) -> InterpretedASTRunner: return _get_interpreted_fxn(self.fxn_for_op, ast) + @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none + def get_runner(self, ast: LazyOp) -> InterpretedASTRunner: + return _get_interpreted_fxn(self.fxn_for_op, ast) -def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner: - if DEBUG >= 3: - from tinygrad.graph import print_tree - print_tree(ast) - tglob: Dict[str, Any] = {"Variable": Variable} - @functools.lru_cache(None) - def gstr(x:Any, nm=None) -> str: - if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg): - str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg) - # TODO: (Variable - Variable) might create NumNode. can we remove it? - return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg) - ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}" - tglob[ret] = x - return ret +def _get_interpreted_fxn( + fxn_for_op: Dict[Op, Callable], ast: LazyOp +) -> InterpretedASTRunner: + if DEBUG >= 3: + from tinygrad.graph import print_tree - lines: List[str] = [] - @functools.lru_cache(None) - def _interpret_ast(ast:LazyOp) -> str: - # TODO: shortcutted store won't work with strides - if ast.op == BufferOps.STORE: return _interpret_ast(ast.src[0]) - if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL: - ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg) + print_tree(ast) + tglob: Dict[str, Any] = {"Variable": Variable} - if ast.op in BufferOps: - if ast.op == ast.op == BufferOps.CONST: tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" - else: tmp = f"{gstr(fxn_for_op[UnaryOps.CAST], UnaryOps.CAST)}(inputs[{ast.arg.idx}], ({gstr(ast.arg.dtype)}, True))" - for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})" - else: - tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})" + @functools.lru_cache(None) + def gstr(x: Any, nm=None) -> str: + if "Variable" in (str_arg := repr(x)) or "NumNode" in str_arg: + str_arg = re.sub( + r"Variable\(.*?\)", lambda m: f"var_vals[{str(m.group(0))}]", str_arg + ) + # TODO: (Variable - Variable) might create NumNode. can we remove it? + return re.sub(r"NumNode\((.*?)\)", r"\1", str_arg) + ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}" + tglob[ret] = x + return ret - ret = f"a{len(lines)}" - lines.append(f" {ret} = {tmp}") - return ret + lines: List[str] = [] + + @functools.lru_cache(None) + def _interpret_ast(ast: LazyOp) -> str: + # TODO: shortcutted store won't work with strides + if ast.op == BufferOps.STORE: + return _interpret_ast(ast.src[0]) + if ( + TernaryOps.MULACC in fxn_for_op + and ast.op == ReduceOps.SUM + and isinstance(ast.src[0], LazyOp) + and ast.src[0].op == BinaryOps.MUL + ): + ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg) + + if ast.op in BufferOps: + if ast.op == ast.op == BufferOps.CONST: + tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" + else: + tmp = f"{gstr(fxn_for_op[UnaryOps.CAST], UnaryOps.CAST)}(inputs[{ast.arg.idx}], ({gstr(ast.arg.dtype)}, True))" + for mop, arg in ast.arg.st.to_movement_ops(): + tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})" + else: + tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})" + + ret = f"a{len(lines)}" + lines.append(f" {ret} = {tmp}") + return ret + + ret = _interpret_ast(ast) + src = "\n".join(["def run(inputs, var_vals):"] + lines + [f" return {ret}"]) + if DEBUG >= 4: + print( + functools.reduce( + lambda x, y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), + tglob.items(), + src, + ) + ) + exec(compile(src, "", "exec"), tglob) # pylint: disable=exec-used + return InterpretedASTRunner(ast, tglob["run"]) - ret = _interpret_ast(ast) - src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {ret}"]) - if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src)) - exec(compile(src, "", "exec"), tglob) # pylint: disable=exec-used - return InterpretedASTRunner(ast, tglob['run']) # **************** for Compiled Devices **************** + class CompiledASTRunner(JITRunner): - def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, runtime_args:Optional[dict]=None): - super().__init__() - if DEBUG >= 4: print(prg) - if global_size is not None: global_size = global_size + [1]*(3-len(global_size)) - if local_size is not None: local_size = local_size + [1]*(3-len(local_size)) - self.name, self.display_name, self.prg, self.global_size, self.local_size, self.runtime_args = \ - to_function_name(name), name, prg, global_size, local_size, runtime_args if runtime_args is not None else {} - self.vars: List[Variable] = [] - if ast: - info = get_lazyop_info(ast) - self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate - from tinygrad.lazy import vars_from_ast - self.vars = vars_from_ast(ast) - assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}" + def __init__( + self, + ast: Optional[LazyOp], + name: str, + prg: str, + global_size: Optional[List[int]] = None, + local_size: Optional[List[int]] = None, + runtime_args: Optional[dict] = None, + ): + super().__init__() + if DEBUG >= 4: + print(prg) + if global_size is not None: + global_size = global_size + [1] * (3 - len(global_size)) + if local_size is not None: + local_size = local_size + [1] * (3 - len(local_size)) + ( + self.name, + self.display_name, + self.prg, + self.global_size, + self.local_size, + self.runtime_args, + ) = ( + to_function_name(name), + name, + prg, + global_size, + local_size, + runtime_args if runtime_args is not None else {}, + ) + self.vars: List[Variable] = [] + if ast: + info = get_lazyop_info(ast) + self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate + from tinygrad.lazy import vars_from_ast - def build(self, compiler, runtime): - self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg) - self.clprg = runtime(self.name, self.lib) - return self + self.vars = vars_from_ast(ast) + assert all( + v._val is None for v in self.vars + ), f"ASTRunner contains bound Variable {self.vars}" - def launch_dims(self, var_vals): - global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size - local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size - return global_size, local_size + def build(self, compiler, runtime): + self.lib = ( + compiler.__wrapped__(self.prg) + if getenv("DISABLE_COMPILER_CACHE") + else compiler(self.prg) + ) + self.clprg = runtime(self.name, self.lib) + return self + + def launch_dims(self, var_vals): + global_size = ( + [sym_infer(sz, var_vals) for sz in self.global_size] + if self.global_size is not None + else self.global_size + ) + local_size = ( + [sym_infer(sz, var_vals) for sz in self.local_size] + if self.local_size is not None + else self.local_size + ) + return global_size, local_size + + def __call__( + self, + rawbufs: List[Buffer], + var_vals: Dict[Variable, int], + wait=False, + jit=False, + ) -> Optional[float]: + global_size, local_size = self.launch_dims(var_vals) + if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type] + # TODO: this is copied from get_program + from tinygrad.features.search import optimize_local_size + + local_size = self.local_size = optimize_local_size( + self.clprg, global_size, rawbufs + ) + global_size = self.global_size = [ + g // l if g % l == 0 else g / l for g, l in zip(global_size, local_size) + ] + lra = self.runtime_args.copy() + if global_size: + lra["global_size"] = global_size + if local_size and "local_size" not in lra: + lra["local_size"] = local_size + et = self.clprg( + *[x._buf for x in rawbufs], + **lra, + vals=tuple(var_vals[k] for k in self.vars), + wait=wait or DEBUG >= 2, + ) + update_stats( + self.display_name, + self.op_estimate, + self.mem_estimate, + var_vals, + et, + len(rawbufs), + jit, + lra=lra, + ) + return et - def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]: - global_size, local_size = self.launch_dims(var_vals) - if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type] - # TODO: this is copied from get_program - from tinygrad.features.search import optimize_local_size - local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs) - global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)] - lra = self.runtime_args.copy() - if global_size: lra['global_size'] = global_size - if local_size and 'local_size' not in lra: lra['local_size'] = local_size - et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2) - update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra) - return et class Compiled: - def __init__(self, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, runtime, graph=None): - self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph = allocator, linearizer_opts, renderer, compiler, runtime, graph - def synchronize(self): pass # override this in your device + def __init__( + self, + allocator: Allocator, + linearizer_opts: LinearizerOptions, + renderer, + compiler, + runtime, + graph=None, + ): + ( + self.allocator, + self.linearizer_opts, + self.renderer, + self.compiler, + self.runtime, + self.graph, + ) = (allocator, linearizer_opts, renderer, compiler, runtime, graph) - def to_program(self, k:Linearizer) -> CompiledASTRunner: - k.linearize() - src, runtime_args = self.renderer(to_function_name(k.name), k.uops) - return CompiledASTRunner(k.ast, k.name, src, k.global_size, k.local_size, runtime_args).build(self.compiler, self.runtime) + def synchronize(self): + pass # override this in your device - def get_linearizer(self, ast:LazyOp) -> Linearizer: - if DEBUG >= 3: - from tinygrad.graph import print_tree - print_tree(ast) - from tinygrad.codegen.linearizer import Linearizer - k = Linearizer(ast, self.linearizer_opts) - if not NOOPT: - if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations() - if BEAM >= 1: - lins = [(("tc" if used_tensor_cores else "hc"), k)] - if used_tensor_cores: - lins.append(("hc", Linearizer(ast, self.linearizer_opts))) - lins[-1][1].hand_coded_optimizations() - kb = Linearizer(ast, self.linearizer_opts) - from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin - # TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions - test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization - lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1))))) - timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2]) - if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed)) - k = timed[0][1] - return k + def to_program(self, k: Linearizer) -> CompiledASTRunner: + k.linearize() + src, runtime_args = self.renderer(to_function_name(k.name), k.uops) + return CompiledASTRunner( + k.ast, k.name, src, k.global_size, k.local_size, runtime_args + ).build(self.compiler, self.runtime) - @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none - def get_runner(self, ast:LazyOp) -> CompiledASTRunner: return self.to_program(self.get_linearizer(ast)) + def get_linearizer(self, ast: LazyOp) -> Linearizer: + if DEBUG >= 3: + from tinygrad.graph import print_tree + + print_tree(ast) + from tinygrad.codegen.linearizer import Linearizer + + k = Linearizer(ast, self.linearizer_opts) + if not NOOPT: + if not (used_tensor_cores := k.apply_tensor_cores(getenv("TC", 1))): + k.hand_coded_optimizations() + if BEAM >= 1: + lins = [(("tc" if used_tensor_cores else "hc"), k)] + if used_tensor_cores: + lins.append(("hc", Linearizer(ast, self.linearizer_opts))) + lins[-1][1].hand_coded_optimizations() + kb = Linearizer(ast, self.linearizer_opts) + from tinygrad.features.search import ( + beam_search, + time_linearizer, + bufs_from_lin, + ) + + # TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions + test_rawbuffers = bufs_from_lin( + kb + ) # allocate scratch buffers for optimization + lins.append( + ( + f"beam{BEAM.value}", + beam_search( + kb, + test_rawbuffers, + BEAM.value, + bool(getenv("BEAM_ESTIMATE", 1)), + ), + ) + ) + timed = sorted( + [ + ( + nm, + tk, + time_linearizer( + tk, + test_rawbuffers, + allow_test_size=False, + clear_l2=True, + ), + ) + for nm, tk in lins + ], + key=lambda x: x[2], + ) + if DEBUG >= 1: + print( + " < ".join( + f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" + for nm, lin, tm in timed + ) + ) + k = timed[0][1] + return k + + @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none + def get_runner(self, ast: LazyOp) -> CompiledASTRunner: + return self.to_program(self.get_linearizer(ast)) diff --git a/tinygrad/features/graph/cuda.py b/tinygrad/features/graph/cuda.py index 9832a41a3..9893f262f 100644 --- a/tinygrad/features/graph/cuda.py +++ b/tinygrad/features/graph/cuda.py @@ -5,68 +5,177 @@ from tinygrad.helpers import init_c_var, encode_args_cuda_style from tinygrad.device import CompiledASTRunner, update_stats, Buffer from tinygrad.runtime.ops_cuda import check, cu_time_execution from tinygrad.shape.symbolic import Variable -from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException +from tinygrad.jit import ( + JitItem, + get_input_replace, + get_jit_stats, + get_jc_idxs_with_updatable_launch_dims, + get_jc_idxs_with_updatable_var_vals, + GraphException, +) + class CUDAGraph: - def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): - if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException + def __init__( + self, + jit_cache: List[JitItem], + input_rawbuffers: List[Buffer], + var_vals: Dict[Variable, int], + ): + if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): + raise GraphException - self.jit_cache = jit_cache - self.input_replace = get_input_replace(jit_cache, input_rawbuffers) - self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) - self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache) - self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache) - self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()])) - self.updatable_nodes: Dict[int, Tuple[Any, Any, Any]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params) + self.jit_cache = jit_cache + self.input_replace = get_input_replace(jit_cache, input_rawbuffers) + self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) + self.jc_idxs_with_updatable_launch_dims = ( + get_jc_idxs_with_updatable_launch_dims(jit_cache) + ) + self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals( + jit_cache + ) + self.jc_idxs_with_updatable_rawbufs = list( + set([x[0] for x in self.input_replace.keys()]) + ) + self.updatable_nodes: Dict[ + int, Tuple[Any, Any, Any] + ] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params) - self.graph = self.graph_create() - graph_node: Optional[ctypes._CData] = None + self.graph = self.graph_create() + graph_node: Optional[ctypes._CData] = None - for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name] - for j,ji in enumerate(self.jit_cache): - prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg) + for (j, i), input_name in self.input_replace.items(): + self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name] + for j, ji in enumerate(self.jit_cache): + prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg) - c_deps = (type(graph_node)*1)(*(graph_node,)) if graph_node is not None else None - c_kernel_input_config, c_input_params = encode_args_cuda_style([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars], *self.encode_args_info()) - c_node_params = self.build_kernel_node_params(prg, *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), c_kernel_input_config) - graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params) + c_deps = ( + (type(graph_node) * 1)(*(graph_node,)) + if graph_node is not None + else None + ) + c_kernel_input_config, c_input_params = encode_args_cuda_style( + [cast(Buffer, x)._buf for x in ji.rawbufs], + [var_vals[x] for x in prg.vars], + *self.encode_args_info(), + ) + c_node_params = self.build_kernel_node_params( + prg, + *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), + c_kernel_input_config, + ) + graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params) - if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs: - self.updatable_nodes[j] = (graph_node, c_node_params, c_input_params) + if ( + j in self.jc_idxs_with_updatable_launch_dims + or j in self.jc_idxs_with_updatable_var_vals + or j in self.jc_idxs_with_updatable_rawbufs + ): + self.updatable_nodes[j] = (graph_node, c_node_params, c_input_params) - self.instance = self.graph_instantiate(self.graph) + self.instance = self.graph_instantiate(self.graph) - def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]: - # Update rawbuffers in the c_input_params struct. - for (j,i),input_idx in self.input_replace.items(): - setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf) + def __call__( + self, + input_rawbuffers: List[Buffer], + var_vals: Dict[Variable, int], + wait=False, + jit=False, + ) -> Optional[float]: + # Update rawbuffers in the c_input_params struct. + for (j, i), input_idx in self.input_replace.items(): + setattr( + self.updatable_nodes[j][2], f"f{i}", input_rawbuffers[input_idx]._buf + ) - # Update var_vals in the c_input_params struct. - for j in self.jc_idxs_with_updatable_var_vals: - for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars): - setattr(self.updatable_nodes[j][2], f'f{len(self.jit_cache[j].rawbufs) + i}', var_vals[v]) + # Update var_vals in the c_input_params struct. + for j in self.jc_idxs_with_updatable_var_vals: + for i, v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars): + setattr( + self.updatable_nodes[j][2], + f"f{len(self.jit_cache[j].rawbufs) + i}", + var_vals[v], + ) - # Update launch dims in the c_node_params struct. - for j in self.jc_idxs_with_updatable_launch_dims: - self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)) + # Update launch dims in the c_node_params struct. + for j in self.jc_idxs_with_updatable_launch_dims: + self.set_kernel_node_launch_dims( + self.updatable_nodes[j][1], + *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals), + ) - # Update graph nodes with the updated structs. - for node, c_node_params, _ in self.updatable_nodes.values(): - self.graph_exec_kernel_node_set_params(self.instance, node, ctypes.byref(c_node_params)) + # Update graph nodes with the updated structs. + for node, c_node_params, _ in self.updatable_nodes.values(): + self.graph_exec_kernel_node_set_params( + self.instance, node, ctypes.byref(c_node_params) + ) - et = self.graph_launch(self.instance, None, wait=wait) - update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) - return et + et = self.graph_launch(self.instance, None, wait=wait) + update_stats( + f"", + self.op_estimate, + self.mem_estimate, + var_vals, + et, + buf_count=len(input_rawbuffers), + jit=jit, + num_kernels=len(self.jit_cache), + ) + return et - def __del__(self): - check(cuda.cuGraphDestroy(self.graph)) - check(cuda.cuGraphExecDestroy(self.instance)) + def __del__(self): + check(cuda.cuGraphDestroy(self.graph)) + check(cuda.cuGraphExecDestroy(self.instance)) - def encode_args_info(self): return (cuda.CUdeviceptr_v2, (1,2,0)) - def graph_create(self): return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0))) - def graph_instantiate(self, graph): return init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0))) - def graph_add_kernel_node(self, graph, c_deps, c_node_params): return init_c_var(cuda.CUgraphNode(), lambda x: check(cuda.cuGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_node_params)))) - def graph_launch(self, *args, wait=False): return cu_time_execution(lambda: check(cuda.cuGraphLaunch(*args)), enable=wait) - def graph_exec_kernel_node_set_params(self, *args): return check(cuda.cuGraphExecKernelNodeSetParams(*args)) - def build_kernel_node_params(self, prg, global_size, local_size, c_kernel_config): return cuda.CUDA_KERNEL_NODE_PARAMS(prg.clprg.prg, *global_size, *local_size, 0, None, c_kernel_config) - def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]): node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size + def encode_args_info(self): + return (cuda.CUdeviceptr_v2, (1, 2, 0)) + + def graph_create(self): + return init_c_var( + cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)) + ) + + def graph_instantiate(self, graph): + return init_c_var( + cuda.CUgraphExec(), + lambda x: check( + cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0) + ), + ) + + def graph_add_kernel_node(self, graph, c_deps, c_node_params): + return init_c_var( + cuda.CUgraphNode(), + lambda x: check( + cuda.cuGraphAddKernelNode( + ctypes.byref(x), + graph, + c_deps, + ctypes.sizeof(c_deps) // 8 if c_deps else 0, + ctypes.byref(c_node_params), + ) + ), + ) + + def graph_launch(self, *args, wait=False): + return cu_time_execution(lambda: check(cuda.cuGraphLaunch(*args)), enable=wait) + + def graph_exec_kernel_node_set_params(self, *args): + return check(cuda.cuGraphExecKernelNodeSetParams(*args)) + + def build_kernel_node_params(self, prg, global_size, local_size, c_kernel_config): + return cuda.CUDA_KERNEL_NODE_PARAMS( + prg.clprg.prg, *global_size, *local_size, 0, None, c_kernel_config + ) + + def set_kernel_node_launch_dims( + self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int] + ): + ( + node.blockDimX, + node.blockDimY, + node.blockDimZ, + node.gridDimX, + node.gridDimY, + node.gridDimZ, + ) = (*local_size, *global_size) diff --git a/tinygrad/features/graph/hip.py b/tinygrad/features/graph/hip.py index 3252ca6ef..82522f42f 100644 --- a/tinygrad/features/graph/hip.py +++ b/tinygrad/features/graph/hip.py @@ -5,16 +5,66 @@ from tinygrad.helpers import init_c_var from tinygrad.runtime.ops_hip import check, hip_time_execution from tinygrad.features.graph.cuda import CUDAGraph -class HIPGraph(CUDAGraph): - def __del__(self): - check(hip.hipGraphDestroy(self.graph)) - check(hip.hipGraphExecDestroy(self.instance)) - def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3)) - def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0))) - def graph_instantiate(self, graph): return init_c_var(hip.hipGraphExec_t(), lambda x: check(hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0))) - def graph_add_kernel_node(self, graph, c_deps, c_params): return init_c_var(hip.hipGraphNode_t(), lambda x: check(hip.hipGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_params)))) - def graph_launch(self, *args, wait=False): return hip_time_execution(lambda: check(hip.hipGraphLaunch(*args)), enable=wait) - def graph_exec_kernel_node_set_params(self, *args): return check(hip.hipGraphExecKernelNodeSetParams(*args)) - def build_kernel_node_params(self, prg, global_size, local_size, c_config): return hip.hipKernelNodeParams(hip.dim3(*local_size), c_config, ctypes.cast(prg.clprg.prg, ctypes.c_void_p), hip.dim3(*global_size), None, 0) - def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]): node.blockDim.x, node.blockDim.y, node.blockDim.z, node.gridDim.x, node.gridDim.y, node.gridDim.z = *local_size, *global_size +class HIPGraph(CUDAGraph): + def __del__(self): + check(hip.hipGraphDestroy(self.graph)) + check(hip.hipGraphExecDestroy(self.instance)) + + def encode_args_info(self): + return (hip.hipDeviceptr_t, (1, 2, 3)) + + def graph_create(self): + return init_c_var( + hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0)) + ) + + def graph_instantiate(self, graph): + return init_c_var( + hip.hipGraphExec_t(), + lambda x: check( + hip.hipGraphInstantiate(ctypes.byref(x), graph, None, None, 0) + ), + ) + + def graph_add_kernel_node(self, graph, c_deps, c_params): + return init_c_var( + hip.hipGraphNode_t(), + lambda x: check( + hip.hipGraphAddKernelNode( + ctypes.byref(x), + graph, + c_deps, + ctypes.sizeof(c_deps) // 8 if c_deps else 0, + ctypes.byref(c_params), + ) + ), + ) + + def graph_launch(self, *args, wait=False): + return hip_time_execution(lambda: check(hip.hipGraphLaunch(*args)), enable=wait) + + def graph_exec_kernel_node_set_params(self, *args): + return check(hip.hipGraphExecKernelNodeSetParams(*args)) + + def build_kernel_node_params(self, prg, global_size, local_size, c_config): + return hip.hipKernelNodeParams( + hip.dim3(*local_size), + c_config, + ctypes.cast(prg.clprg.prg, ctypes.c_void_p), + hip.dim3(*global_size), + None, + 0, + ) + + def set_kernel_node_launch_dims( + self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int] + ): + ( + node.blockDim.x, + node.blockDim.y, + node.blockDim.z, + node.gridDim.x, + node.gridDim.y, + node.gridDim.z, + ) = (*local_size, *global_size) diff --git a/tinygrad/features/graph/metal.py b/tinygrad/features/graph/metal.py index b554c42ec..92c6e9a07 100644 --- a/tinygrad/features/graph/metal.py +++ b/tinygrad/features/graph/metal.py @@ -3,76 +3,148 @@ import numpy as np import Metal from tinygrad.helpers import dtypes, dedup, unwrap2 from tinygrad.device import Buffer, CompiledASTRunner, update_stats -from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, GraphException +from tinygrad.jit import ( + JitItem, + get_input_replace, + get_jit_stats, + get_jc_idxs_with_updatable_launch_dims, + GraphException, +) from tinygrad.shape.symbolic import Variable from tinygrad.runtime.ops_metal import MetalDevice + class MetalGraph: - def __init__(self, device:MetalDevice, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): - if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException + def __init__( + self, + device: MetalDevice, + jit_cache: List[JitItem], + input_rawbuffers: List[Buffer], + var_vals: Dict[Variable, int], + ): + if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): + raise GraphException - self.jit_cache = jit_cache - self.input_replace = get_input_replace(jit_cache, input_rawbuffers) - self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) - self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache) - self.device: MetalDevice = device + self.jit_cache = jit_cache + self.input_replace = get_input_replace(jit_cache, input_rawbuffers) + self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) + self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims( + jit_cache + ) + self.device: MetalDevice = device - # create metal batch exec - icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new() - icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch)) - icb_descriptor.setInheritBuffers_(False) - icb_descriptor.setInheritPipelineState_(False) - icb_descriptor.setMaxKernelBufferBindCount_(31) - self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0)) - if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?") + # create metal batch exec + icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new() + icb_descriptor.setCommandTypes_( + Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch) + ) + icb_descriptor.setInheritBuffers_(False) + icb_descriptor.setInheritPipelineState_(False) + icb_descriptor.setMaxKernelBufferBindCount_(31) + self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_( + icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0) + ) + if self.icb is None: + raise GraphException( + "create indirect command buffer failed, does your system support this?" + ) - if len(var_vals): self.int_buf = self.device.allocator.alloc(len(var_vals)*dtypes.int32.itemsize) - all_resources = [self.int_buf] if len(var_vals) else [] - for j,ji in enumerate(self.jit_cache): - prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg) - descriptor = Metal.MTLComputePipelineDescriptor.new() - descriptor.setComputeFunction_(prg.clprg.fxn) - descriptor.setSupportIndirectCommandBuffers_(True) - pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)) - icb_command = self.icb.indirectComputeCommandAtIndex_(j) - icb_command.setComputePipelineState_(pipeline_state) - for i,b in enumerate(ji.rawbufs): - if b is not None: - icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i) - all_resources.append(b._buf) - var_vals_keys = list(var_vals.keys()) - for i,v in enumerate(prg.vars): - icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i) - if j not in self.jc_idx_with_updatable_launch_dims: - global_size, local_size = prg.launch_dims(var_vals) - icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) - icb_command.setBarrier() - self.all_resources = dedup(all_resources) - self.command_buffer: Any = None - if len(var_vals): self.int_buf_view = np.frombuffer(self.int_buf.contents().as_buffer(self.int_buf.length()), np.int32) + if len(var_vals): + self.int_buf = self.device.allocator.alloc( + len(var_vals) * dtypes.int32.itemsize + ) + all_resources = [self.int_buf] if len(var_vals) else [] + for j, ji in enumerate(self.jit_cache): + prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg) + descriptor = Metal.MTLComputePipelineDescriptor.new() + descriptor.setComputeFunction_(prg.clprg.fxn) + descriptor.setSupportIndirectCommandBuffers_(True) + pipeline_state = unwrap2( + self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_( + descriptor, Metal.MTLPipelineOption(0), None, None + ) + ) + icb_command = self.icb.indirectComputeCommandAtIndex_(j) + icb_command.setComputePipelineState_(pipeline_state) + for i, b in enumerate(ji.rawbufs): + if b is not None: + icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i) + all_resources.append(b._buf) + var_vals_keys = list(var_vals.keys()) + for i, v in enumerate(prg.vars): + icb_command.setKernelBuffer_offset_atIndex_( + self.int_buf, var_vals_keys.index(v) * 4, len(ji.rawbufs) + i + ) + if j not in self.jc_idx_with_updatable_launch_dims: + global_size, local_size = prg.launch_dims(var_vals) + icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_( + Metal.MTLSize(*global_size), Metal.MTLSize(*local_size) + ) + icb_command.setBarrier() + self.all_resources = dedup(all_resources) + self.command_buffer: Any = None + if len(var_vals): + self.int_buf_view = np.frombuffer( + self.int_buf.contents().as_buffer(self.int_buf.length()), np.int32 + ) - def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]: - # NOTE: you at least can't update the ints if this is running - if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted() - all_resources = self.all_resources + [x._buf for x in input_rawbuffers] - for (j,i),input_idx in self.input_replace.items(): - self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i) - for j in self.jc_idx_with_updatable_launch_dims: - global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals) - self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) - if len(var_vals): self.int_buf_view[:] = list(var_vals.values()) - command_buffer = self.device.mtl_queue.commandBuffer() - encoder = command_buffer.computeCommandEncoder() - encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite) - encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache))) - encoder.endEncoding() - command_buffer.commit() - self.command_buffer = command_buffer - if wait: - command_buffer.waitUntilCompleted() - et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime() - else: - self.device.mtl_buffers_in_flight.append(command_buffer) - et = None - update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) - return et \ No newline at end of file + def __call__( + self, + input_rawbuffers: List[Buffer], + var_vals: Dict[Variable, int], + wait=False, + jit=False, + ) -> Optional[float]: + # NOTE: you at least can't update the ints if this is running + if ( + self.command_buffer is not None + and self.command_buffer in self.device.mtl_buffers_in_flight + ): + self.command_buffer.waitUntilCompleted() + all_resources = self.all_resources + [x._buf for x in input_rawbuffers] + for (j, i), input_idx in self.input_replace.items(): + self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_( + input_rawbuffers[input_idx]._buf, 0, i + ) + for j in self.jc_idx_with_updatable_launch_dims: + global_size, local_size = cast( + CompiledASTRunner, self.jit_cache[j].prg + ).launch_dims(var_vals) + self.icb.indirectComputeCommandAtIndex_( + j + ).concurrentDispatchThreadgroups_threadsPerThreadgroup_( + Metal.MTLSize(*global_size), Metal.MTLSize(*local_size) + ) + if len(var_vals): + self.int_buf_view[:] = list(var_vals.values()) + command_buffer = self.device.mtl_queue.commandBuffer() + encoder = command_buffer.computeCommandEncoder() + encoder.useResources_count_usage_( + all_resources, + len(all_resources), + Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite, + ) + encoder.executeCommandsInBuffer_withRange_( + self.icb, + Metal.MTLIndirectCommandBufferExecutionRangeMake(0, len(self.jit_cache)), + ) + encoder.endEncoding() + command_buffer.commit() + self.command_buffer = command_buffer + if wait: + command_buffer.waitUntilCompleted() + et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime() + else: + self.device.mtl_buffers_in_flight.append(command_buffer) + et = None + update_stats( + f"", + self.op_estimate, + self.mem_estimate, + var_vals, + et, + buf_count=len(input_rawbuffers), + jit=jit, + num_kernels=len(self.jit_cache), + ) + return et diff --git a/tinygrad/features/image.py b/tinygrad/features/image.py index 82d92f44a..dd9baaedf 100644 --- a/tinygrad/features/image.py +++ b/tinygrad/features/image.py @@ -3,147 +3,236 @@ from tinygrad.helpers import prod, IMAGE, getenv, dtypes, DEBUG # *** image Tensor function replacements *** -def image_dot(self, w): - # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1) - n1, n2 = len(self.shape), len(w.shape) - assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" - assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" - bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2]) - cin, cout = w.shape[-2], w.shape[-1] - out_shape_t = self.shape[0:-2] + (cout,-1) - if len(self.shape) > 1: - order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2) - else: - order, out_shape_t = (0,), (cout, ) - worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2) - # NOTE: with NHWC we can remove the transposes - # bs x groups*cin x H x W - cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1)) - # groups*cout x cin x H, W - cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1)) - return image_conv2d(cx, cw, groups=groups).reshape(shape=out_shape_t).permute(order=order) +def image_dot(self, w): + # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1) + n1, n2 = len(self.shape), len(w.shape) + assert ( + n1 != 0 and n2 != 0 + ), f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" + assert ( + self.shape[-1] == w.shape[-min(n2, 2)] + ), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" + bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2]) + cin, cout = w.shape[-2], w.shape[-1] + out_shape_t = self.shape[0:-2] + (cout, -1) + if len(self.shape) > 1: + order = tuple(range(len(self.shape) - 2)) + ( + len(self.shape) - 1, + len(self.shape) - 2, + ) + else: + order, out_shape_t = (0,), (cout,) + worder = tuple(range(len(w.shape) - 2)) + (len(w.shape) - 1, len(w.shape) - 2) + + # NOTE: with NHWC we can remove the transposes + # bs x groups*cin x H x W + cx = self.permute(order=order).reshape(shape=(bs // groups, groups * cin, -1, 1)) + # groups*cout x cin x H, W + cw = w.permute(order=worder).reshape(shape=(groups * cout, cin, 1, 1)) + return ( + image_conv2d(cx, cw, groups=groups) + .reshape(shape=out_shape_t) + .permute(order=order) + ) + def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0): - base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef + base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef - (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape - rcout = cout//groups - x, w = self, weight.reshape(groups, rcout, cin, H, W) + (bs, _, iy, ix), (cout, cin, H, W) = self.shape, weight.shape + rcout = cout // groups + x, w = self, weight.reshape(groups, rcout, cin, H, W) - # hack for non multiples of 4 on cin - if cin % 4 != 0 and not (cin == 1 and groups%4 == 0): - x = x.reshape(bs, groups, cin, iy, ix) # do this always? - added_input_channels = 4 - (cin % 4) - w = w.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(w.shape)))) - x = x.pad(tuple((0, added_input_channels) if i == 2 else (0, 0) for i in range(len(x.shape)))) - cin = cin + added_input_channels - x = x.reshape(bs, groups*cin, iy, ix) + # hack for non multiples of 4 on cin + if cin % 4 != 0 and not (cin == 1 and groups % 4 == 0): + x = x.reshape(bs, groups, cin, iy, ix) # do this always? + added_input_channels = 4 - (cin % 4) + w = w.pad( + tuple( + (0, added_input_channels) if i == 2 else (0, 0) + for i in range(len(w.shape)) + ) + ) + x = x.pad( + tuple( + (0, added_input_channels) if i == 2 else (0, 0) + for i in range(len(x.shape)) + ) + ) + cin = cin + added_input_channels + x = x.reshape(bs, groups * cin, iy, ix) - # hack for non multiples of 4 on rcout - added_output_channels = 0 - if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0): - added_output_channels = 4 - (rcout % 4) - rcout += added_output_channels - cout = groups * rcout - w = w.slice(tuple((0, rcout) if i == 1 else (0, s) for i,s in enumerate(w.shape))) + # hack for non multiples of 4 on rcout + added_output_channels = 0 + if rcout % 4 != 0 and not (rcout == 1 and groups % 4 == 0): + added_output_channels = 4 - (rcout % 4) + rcout += added_output_channels + cout = groups * rcout + w = w.slice( + tuple((0, rcout) if i == 1 else (0, s) for i, s in enumerate(w.shape)) + ) - # packed (note: flipping bs and iy would make the auto-padding work) - x = x.permute(0,2,3,1) - cin_last = iy == 1 and ix == 1 - if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1) - elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3) - else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1) + # packed (note: flipping bs and iy would make the auto-padding work) + x = x.permute(0, 2, 3, 1) + cin_last = iy == 1 and ix == 1 + if cin == 1: + w = w.reshape(cout // 4, 4, H, W).permute(0, 2, 3, 1) + elif cin_last: + w = w.reshape(cout // 4, 4, cin // 4, 4, H, W).permute(0, 4, 2, 5, 1, 3) + else: + w = w.reshape(cout // 4, 4, cin // 4, 4, H, W).permute(0, 4, 2, 5, 3, 1) - # contiguous creates the image, and early realize static weights (TODO: test for the static weight) - if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4))) - x, w = x.contiguous(), w.contiguous() + # contiguous creates the image, and early realize static weights (TODO: test for the static weight) + if IMAGE >= 2: + x, w = x.cast(base_image_type((bs * iy, ix * groups * cin // 4, 4))), w.cast( + base_image_type((cout // 4, H * W * cin, 4)) + ) + x, w = x.contiguous(), w.contiguous() - # expand out - rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1 - cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1] - x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo) - if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo) - else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4) + # expand out + rcin_hi, rcin_lo = cin // 4 if cin >= 4 else 1, 4 if cin >= 4 else 1 + cout_expand = [ + groups // 4 if cin == 1 else groups, + 4 if cin == 1 else 1, + rcout // 4 if rcout >= 4 else 1, + 4 if rcout >= 4 else 1, + ] + x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo) + if cin_last: + w = w.reshape(cout // 4, H, rcin_hi, W, 4, rcin_lo) + else: + w = w.reshape(cout // 4, H, rcin_hi, W, rcin_lo, 4).permute(0, 1, 2, 3, 5, 4) - # padding - padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]]) - x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None)) + # padding + padding_ = ( + [padding] * 4 + if isinstance(padding, int) + else ( + padding + if len(padding) == 4 + else [padding[1], padding[1], padding[0], padding[0]] + ) + ) + x = x.slice( + ( + None, + (-padding_[2], x.shape[1] + padding_[3]), + (-padding_[0], x.shape[2] + padding_[1]), + None, + None, + None, + ) + ) - # prepare input - x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W) - oy, ox = x.shape[4:6] - x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W) - x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W) + # prepare input + x = x.permute(0, 3, 4, 5, 1, 2)._pool( + (H, W), stride, dilation + ) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W) + oy, ox = x.shape[4:6] + x = x.permute(0, 4, 5, 1, 2, 3, 6, 7).reshape( + bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W + ) + x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W) - # prepare weights - w = w.permute(0,4,2,5,1,3) - w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)).expand(x.shape) + # prepare weights + w = w.permute(0, 4, 2, 5, 1, 3) + w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)).expand(x.shape) - # the conv! (+ the bias) - ret = x*w - if IMAGE >= 2: ret = ret.cast(base_image_type((bs*oy, ox*cout//4, 4))) - ret = ret.sum((-4, -3, -2, -1)) + # the conv! (+ the bias) + ret = x * w + if IMAGE >= 2: + ret = ret.cast(base_image_type((bs * oy, ox * cout // 4, 4))) + ret = ret.sum((-4, -3, -2, -1)) - # undo hack for non multiples of 4 on C.rcout - if added_output_channels != 0: - ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels] - rcout -= added_output_channels - cout = groups * rcout + # undo hack for non multiples of 4 on C.rcout + if added_output_channels != 0: + ret = ret.reshape(bs, oy, ox, groups, rcout)[ + :, :, :, :, :-added_output_channels + ] + rcout -= added_output_channels + cout = groups * rcout + + # NCHW output + ret = ret.reshape(bs, oy, ox, cout).permute(0, 3, 1, 2) + return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1)) - # NCHW output - ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2) - return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1)) # *** images have weird indexing requirements *** from tinygrad.shape.symbolic import Node, AndNode, Variable, NumNode, SumNode, LtNode -def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]: - idx = (idxy // 4) % base_shape[1] - idy = (idxy // (4 * base_shape[1])) - - if valid.min == 0 and isinstance(idxy, SumNode): - nodes = valid.nodes if isinstance(valid, AndNode) else [valid] - val_dict: Dict[Node, Any] = {} - # TODO: is this correct? should it check there's only one variable from each component? - idxy_flat_var = [(i, list(i.vars())[0]) for i in idxy.flat_components if not isinstance(i, NumNode)] - - for node in nodes: - assert isinstance(node, LtNode) - node_flat, node_vars = node.a.flat_components if isinstance(node.a, SumNode) else [node.a], node.vars() - same_sym = [i for (i, var) in idxy_flat_var if var in node_vars] - if len(same_sym) == 0: continue - first, second = sorted(same_sym)[0], sorted(node_flat)[0] - f_b = 1 if isinstance(first, Variable) else first.b - s_b = 1 if isinstance(second, Variable) else second.b - sig = -1 if s_b < 0 else 1 - key_node = sig*node.a - if key_node not in val_dict: val_dict[key_node] = [key_node.min, key_node.max, abs(f_b//s_b)] - val_dict[key_node][(sig + 1)//2] = sig*(node.b - 1) - - fakes = {} - for cnt, (key_node, (mnn, mxn, multip)) in enumerate(val_dict.items()): - if mnn > mxn: return (idx, idy), valid # TODO: why is this happening? - fake_var = Variable("fake_" + str(cnt), mnn, mxn) - fakes[fake_var] = key_node - idxy += multip*(fake_var - key_node) +def to_image_idx( + base_shape: Tuple[int, ...], idxy: Node, valid: Node +) -> Tuple[Tuple[Node, Node], Node]: idx = (idxy // 4) % base_shape[1] - idy = (idxy // (4 * base_shape[1])) + idy = idxy // (4 * base_shape[1]) - fake_rep = {fake: node for fake, node in fakes.items()} + if valid.min == 0 and isinstance(idxy, SumNode): + nodes = valid.nodes if isinstance(valid, AndNode) else [valid] + val_dict: Dict[Node, Any] = {} + # TODO: is this correct? should it check there's only one variable from each component? + idxy_flat_var = [ + (i, list(i.vars())[0]) + for i in idxy.flat_components + if not isinstance(i, NumNode) + ] - idx = idx.substitute(fake_rep) - idy = idy.substitute(fake_rep) + for node in nodes: + assert isinstance(node, LtNode) + node_flat, node_vars = ( + node.a.flat_components if isinstance(node.a, SumNode) else [node.a], + node.vars(), + ) + same_sym = [i for (i, var) in idxy_flat_var if var in node_vars] + if len(same_sym) == 0: + continue + first, second = sorted(same_sym)[0], sorted(node_flat)[0] + f_b = 1 if isinstance(first, Variable) else first.b + s_b = 1 if isinstance(second, Variable) else second.b + sig = -1 if s_b < 0 else 1 + key_node = sig * node.a + if key_node not in val_dict: + val_dict[key_node] = [key_node.min, key_node.max, abs(f_b // s_b)] + val_dict[key_node][(sig + 1) // 2] = sig * (node.b - 1) - idy_vars, idx_vars, ones = set(idy.vars()), set(idx.vars()), [] - for node in nodes: - node_vars = set(node.vars()) - if not node_vars & (idx_vars | idy_vars): continue #There is simplified NumNode which can not go outside the bounds - # NOTE: Why does only idy is problematic? and not the idx - if idy_vars == node_vars or idy_vars & node_vars == set(): ones.append(node) - valid = Variable.ands([i for i in nodes if i not in ones]) + fakes = {} + for cnt, (key_node, (mnn, mxn, multip)) in enumerate(val_dict.items()): + if mnn > mxn: + return (idx, idy), valid # TODO: why is this happening? + fake_var = Variable("fake_" + str(cnt), mnn, mxn) + fakes[fake_var] = key_node + idxy += multip * (fake_var - key_node) - if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid) - return (idx, idy), valid + idx = (idxy // 4) % base_shape[1] + idy = idxy // (4 * base_shape[1]) + + fake_rep = {fake: node for fake, node in fakes.items()} + + idx = idx.substitute(fake_rep) + idy = idy.substitute(fake_rep) + + idy_vars, idx_vars, ones = set(idy.vars()), set(idx.vars()), [] + for node in nodes: + node_vars = set(node.vars()) + if not node_vars & (idx_vars | idy_vars): + continue # There is simplified NumNode which can not go outside the bounds + # NOTE: Why does only idy is problematic? and not the idx + if idy_vars == node_vars or idy_vars & node_vars == set(): + ones.append(node) + valid = Variable.ands([i for i in nodes if i not in ones]) + + if DEBUG >= 5: + print( + "to_image_idx", + base_shape, + idx.min, + idx.max, + idy.min, + idy.max, + idx, + idy, + valid, + ) + return (idx, idy), valid diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index 9be8bb89f..3962a6102 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -3,160 +3,321 @@ import itertools, random, math, time from tinygrad.lazy import vars_from_ast from tinygrad.device import Device, Compiled, Buffer from tinygrad.ops import MemBuffer -from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, all_int, colored, Timing +from tinygrad.helpers import ( + prod, + ImageDType, + flatten, + DEBUG, + CACHELEVEL, + diskcache_get, + diskcache_put, + getenv, + Context, + all_int, + colored, + Timing, +) from tinygrad.codegen.linearizer import Linearizer, UOp from collections import defaultdict from tinygrad.tensor import Tensor from tinygrad.codegen.kernel import Opt, OptOps -actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7]] for axis in range(6)]) -actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)]) -actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)]) -actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)]) -actions += flatten([[Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32]] for axis in range(7)]) + +actions = flatten( + [ + [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0, 2, 3, 4, 7]] + for axis in range(6) + ] +) +actions += flatten( + [[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0, 4]] for axis in range(4)] +) +actions += flatten( + [ + [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2, 3, 4, 8, 13, 16, 29]] + for axis in range(5) + ] +) +actions += flatten( + [ + [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13, 16, 29, 32, 256]] + for axis in range(3) + ] +) +actions += flatten( + [[Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32]] for axis in range(7)] +) actions += [ - Opt(op=OptOps.LOCAL, axis=0, amt=32), - Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8), - Opt(op=OptOps.UPCASTMID, axis=1, amt=4), + Opt(op=OptOps.LOCAL, axis=0, amt=32), + Opt(op=OptOps.GROUP, axis=0, amt=4), + Opt(op=OptOps.GROUP, axis=0, amt=8), + Opt(op=OptOps.GROUP, axis=1, amt=8), + Opt(op=OptOps.UPCASTMID, axis=1, amt=4), ] -if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)] +if getenv("NOLOCALS"): + actions += [Opt(op=OptOps.NOLOCALS)] + # returns time in seconds -def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: - key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} - if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val) +def time_linearizer( + lin: Linearizer, + rawbufs: List[Buffer], + allow_test_size=True, + max_global_size=65536, + cnt=3, + disable_cache=False, + clear_l2=False, +) -> float: + key = { + "ast": str(lin.ast), + "opts": str(lin.applied_opts), + "allow_test_size": allow_test_size, + "max_global_size": max_global_size, + "clear_l2": clear_l2, + "device": Device.DEFAULT, + } + if ( + not disable_cache + and CACHELEVEL >= 2 + and (val := diskcache_get("time_linearizer", key)) is not None + ): + return min(val) - # Set the midpoint value value for var_vals to optimize shapes. - var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)} - try: - lin.linearize() - prg = cast(Compiled, Device[Device.DEFAULT]).to_program(lin) - real_global_size = prg.global_size - if allow_test_size and prg.global_size and all_int(tuple(prg.global_size)): - test_global_size = prg.global_size[:] - while prod(test_global_size) > max_global_size: - for j in range(2,-1,-1): - if test_global_size[j] > 16: - test_global_size[j] //= 2 - break - factor = prod(prg.global_size) / prod(test_global_size) - prg.global_size = test_global_size - #print(real_global_size, test_global_size, factor) - else: - factor = 1 + # Set the midpoint value value for var_vals to optimize shapes. + var_vals = {k: (k.max + k.min) // 2 for k in vars_from_ast(lin.ast)} + try: + lin.linearize() + prg = cast(Compiled, Device[Device.DEFAULT]).to_program(lin) + real_global_size = prg.global_size + if allow_test_size and prg.global_size and all_int(tuple(prg.global_size)): + test_global_size = prg.global_size[:] + while prod(test_global_size) > max_global_size: + for j in range(2, -1, -1): + if test_global_size[j] > 16: + test_global_size[j] //= 2 + break + factor = prod(prg.global_size) / prod(test_global_size) + prg.global_size = test_global_size + # print(real_global_size, test_global_size, factor) + else: + factor = 1 - # TODO: this is copied from prg.__call__ - global_size, local_size = prg.launch_dims(var_vals) - prg.global_size = real_global_size - if global_size is not None and prg.global_size is not None and local_size is None and all_int(tuple(prg.global_size)): - local_size = optimize_local_size(prg.clprg, global_size, rawbufs) - global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)] + # TODO: this is copied from prg.__call__ + global_size, local_size = prg.launch_dims(var_vals) + prg.global_size = real_global_size + if ( + global_size is not None + and prg.global_size is not None + and local_size is None + and all_int(tuple(prg.global_size)) + ): + local_size = optimize_local_size(prg.clprg, global_size, rawbufs) + global_size = [ + g // l if g % l == 0 else g / l for g, l in zip(global_size, local_size) + ] - lra = prg.runtime_args.copy() - if global_size: lra['global_size'] = global_size - if local_size: lra['local_size'] = local_size + lra = prg.runtime_args.copy() + if global_size: + lra["global_size"] = global_size + if local_size: + lra["local_size"] = local_size + + tms = [] + for _ in range(cnt): + if clear_l2: + # TODO: this is too small for many L2 caches + with Context(DEBUG=0): + Tensor.rand(1024, 1024).realize() + tms.append( + prg.clprg( + *[x._buf for x in rawbufs], *var_vals.values(), **lra, wait=True + ) + * factor + ) + except Exception: + if DEBUG >= 4: + import traceback + + traceback.print_exc() + print("FAILED") + print(lin.ast) + print(lin.applied_opts) + tms = [float("inf")] + if CACHELEVEL >= 2: + diskcache_put("time_linearizer", key, tms) + return min(tms) - tms = [] - for _ in range(cnt): - if clear_l2: - # TODO: this is too small for many L2 caches - with Context(DEBUG=0): Tensor.rand(1024,1024).realize() - tms.append(prg.clprg(*[x._buf for x in rawbufs], *var_vals.values(), **lra, wait=True)*factor) - except Exception: - if DEBUG >= 4: - import traceback - traceback.print_exc() - print("FAILED") - print(lin.ast) - print(lin.applied_opts) - tms = [float('inf')] - if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms) - return min(tms) # get (scrap) buffers for timing the linearizer -def bufs_from_lin(lin:Linearizer) -> List[Buffer]: - bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list) - for x in lin.membufs: bufsts[x.idx].append(x) - rawbufs:List[Optional[Buffer]] = [None]*len(bufsts) - for k,lx in bufsts.items(): - rawbufs[k] = Buffer(Device.DEFAULT, prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype) - assert all(r is not None for r in rawbufs) - return cast(List[Buffer], rawbufs) +def bufs_from_lin(lin: Linearizer) -> List[Buffer]: + bufsts: DefaultDict[int, List[MemBuffer]] = defaultdict(list) + for x in lin.membufs: + bufsts[x.idx].append(x) + rawbufs: List[Optional[Buffer]] = [None] * len(bufsts) + for k, lx in bufsts.items(): + rawbufs[k] = Buffer( + Device.DEFAULT, + prod(lx[0].dtype.shape) + if isinstance(lx[0].dtype, ImageDType) + else max(y.st.size() for y in lx), + lx[0].dtype, + ) + assert all(r is not None for r in rawbufs) + return cast(List[Buffer], rawbufs) + # get dictionary of all possible actions -def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]: - acted_lins = {0:lin} if include_0 else {} - for i,a in enumerate(actions): - if a.axis is not None and a.axis >= lin.shape_len: continue - if a.axis is not None and lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue - lin2 = lin.copy() - try: - lin2.apply_opt(a) - up, lcl = 1, 1 - for s,c in zip(lin2.full_shape, lin2.colors()): - if c in {"magenta", "yellow"}: up *= s - if c in {"cyan", "green", "white"}: lcl *= s - if up > 256 or lcl > 256: continue - acted_lins[i+1] = lin2 - except Exception: - pass - return acted_lins +def get_linearizer_actions(lin: Linearizer, include_0=True) -> Dict[int, Linearizer]: + acted_lins = {0: lin} if include_0 else {} + for i, a in enumerate(actions): + if a.axis is not None and a.axis >= lin.shape_len: + continue + if ( + a.axis is not None + and lin.full_shape[a.axis] == a.amt + and Opt(a.op, a.axis, 0) in actions + ): + continue + lin2 = lin.copy() + try: + lin2.apply_opt(a) + up, lcl = 1, 1 + for s, c in zip(lin2.full_shape, lin2.colors()): + if c in {"magenta", "yellow"}: + up *= s + if c in {"cyan", "green", "white"}: + lcl *= s + if up > 256 or lcl > 256: + continue + acted_lins[i + 1] = lin2 + except Exception: + pass + return acted_lins -def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops]) -def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer: - key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT} - if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1: - ret = lin.copy() - for o in val[len(lin.applied_opts):]: ret.apply_opt(o) - return ret +def tuplize_uops(uops: List[UOp]) -> Tuple: + return tuple( + [(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops] + ) - # init the BEAM with the base linearizer - beam: List[Tuple[Linearizer, float]] = [(lin, time_linearizer(lin, rawbufs, allow_test_size=allow_test_size))] - # NOTE: real uops use a weird compare method that's only valid inside a linearizer - seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)} +def beam_search(lin: Linearizer, rawbufs, amt: int, allow_test_size=True) -> Linearizer: + key = { + "ast": str(lin.ast), + "amt": amt, + "allow_test_size": allow_test_size, + "device": Device.DEFAULT, + } + if ( + (val := diskcache_get("beam_search", key)) is not None + and not getenv("IGNORE_BEAM_CACHE") + and CACHELEVEL >= 1 + ): + ret = lin.copy() + for o in val[len(lin.applied_opts) :]: + ret.apply_opt(o) + return ret - exiting, st = False, time.perf_counter() - while not exiting: - with Timing("linearize: ", enabled=DEBUG>=3): - acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) + # init the BEAM with the base linearizer + beam: List[Tuple[Linearizer, float]] = [ + (lin, time_linearizer(lin, rawbufs, allow_test_size=allow_test_size)) + ] - # linearize all - for x in acted_lins: x.linearize() + # NOTE: real uops use a weird compare method that's only valid inside a linearizer + seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)} - # dedup with uops - acted_lins_dedup = [] - for lin in acted_lins: - tuops = tuplize_uops(lin.uops) - if tuops in seen_uops: continue - seen_uops[tuops] = tuple(lin.applied_opts) - acted_lins_dedup.append(lin) + exiting, st = False, time.perf_counter() + while not exiting: + with Timing("linearize: ", enabled=DEBUG >= 3): + acted_lins = flatten( + [ + get_linearizer_actions(lin, include_0=False).values() + for lin, _ in beam + ] + ) - with Timing("compile: ",enabled=DEBUG>=3): - # time linearizers - timed_lins: List[Tuple[Linearizer, float]] = [(v,time_linearizer(v,rawbufs,allow_test_size=allow_test_size)) for v in acted_lins_dedup] - opts = sorted(timed_lins, key=lambda x: x[1]) + # linearize all + for x in acted_lins: + x.linearize() - # done - exiting = len(opts) == 0 or beam[0][1] <= opts[0][1] - if not exiting: beam = opts[:amt] - if DEBUG >= 2: print(f"{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape()) + # dedup with uops + acted_lins_dedup = [] + for lin in acted_lins: + tuops = tuplize_uops(lin.uops) + if tuops in seen_uops: + continue + seen_uops[tuops] = tuple(lin.applied_opts) + acted_lins_dedup.append(lin) - if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts) - if DEBUG >= 3: print(beam[0][0].applied_opts) - return beam[0][0] + with Timing("compile: ", enabled=DEBUG >= 3): + # time linearizers + timed_lins: List[Tuple[Linearizer, float]] = [ + (v, time_linearizer(v, rawbufs, allow_test_size=allow_test_size)) + for v in acted_lins_dedup + ] + opts = sorted(timed_lins, key=lambda x: x[1]) -def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]: - test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs - MAX_WORKGROUP = clprg.max_work_group_size() if hasattr(clprg, 'max_work_group_size') else 1024 - local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size] - local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice - def try_exec(local_size): - try: - return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) - except Exception: - return float('inf') - ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))]) - assert not math.isinf(ret[0]), "all optimize_local_size exec failed" - return ret[1] + # done + exiting = len(opts) == 0 or beam[0][1] <= opts[0][1] + if not exiting: + beam = opts[:amt] + if DEBUG >= 2: + print( + f"{time.perf_counter() - st:7.2f}s:", + colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), + f"from {len(acted_lins):3d} -> {len(opts):3d} actions", + beam[0][0].colored_shape(), + ) + + if CACHELEVEL >= 1: + diskcache_put("beam_search", key, beam[0][0].applied_opts) + if DEBUG >= 3: + print(beam[0][0].applied_opts) + return beam[0][0] + + +def optimize_local_size( + clprg: Callable, global_size: List[int], rawbufs: List[Buffer] +) -> List[int]: + test_rawbuffers = ( + [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] + if rawbufs[0] in rawbufs[1:] + else rawbufs + ) + MAX_WORKGROUP = ( + clprg.max_work_group_size() if hasattr(clprg, "max_work_group_size") else 1024 + ) + local_dims = [ + [ + x + for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) + if x <= sz + ] + for sz in global_size + ] + local_sizes = [ + list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP + ] * 2 # try each valid size twice + + def try_exec(local_size): + try: + return clprg( + *[x._buf for x in test_rawbuffers], + global_size=[ + g // l if g % l == 0 else g / l + for g, l in zip(global_size, local_size) + ], + local_size=local_size, + wait=True, + ) + except Exception: + return float("inf") + + ret = min( + [ + (try_exec(local_size), local_size) + for local_size in random.sample(local_sizes, len(local_sizes)) + ] + ) + assert not math.isinf(ret[0]), "all optimize_local_size exec failed" + return ret[1] diff --git a/tinygrad/graph.py b/tinygrad/graph.py index 3963c7240..19f9538f9 100644 --- a/tinygrad/graph.py +++ b/tinygrad/graph.py @@ -1,7 +1,19 @@ import os, atexit, functools from collections import defaultdict from typing import Dict, List -from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp +from tinygrad.ops import ( + ScheduleItem, + UnaryOps, + BinaryOps, + ReduceOps, + MovementOps, + LoadOps, + BufferOps, + TernaryOps, + Op, + OpType, + LazyOp, +) from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.shape.shapetracker import ShapeTracker @@ -11,108 +23,194 @@ from tinygrad.shape.symbolic import NumNode cnts: Dict[OpType, int] = defaultdict(int) if DEBUG >= 2: - def print_globalcounters(): - if GlobalCounters.time_sum_s == 0: return - print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", - f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") - atexit.register(print_globalcounters) + + def print_globalcounters(): + if GlobalCounters.time_sum_s == 0: + return + print( + f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", + f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms", + ) + + atexit.register(print_globalcounters) if GRAPH: - import networkx as nx - G = nx.DiGraph() - def save_graph_exit(): - for k,v in cnts.items(): print(k, v) - print("saving", G) - nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot') - # -Gnslimit=100 can make it finish, but you won't like results - os.system(f'dot -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg') - atexit.register(save_graph_exit) + import networkx as nx + + G = nx.DiGraph() + + def save_graph_exit(): + for k, v in cnts.items(): + print(k, v) + print("saving", G) + nx.drawing.nx_pydot.write_dot(G, f"{GRAPHPATH}.dot") + # -Gnslimit=100 can make it finish, but you won't like results + os.system(f"dot -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg") + + atexit.register(save_graph_exit) node_count = 0 + + def nm(x): - global node_count - if not hasattr(x, 'node_id'): - setattr(x, 'node_id', node_count) - node_count += 1 - return x.node_id + global node_count + if not hasattr(x, "node_id"): + setattr(x, "node_id", node_count) + node_count += 1 + return x.node_id + def get_sop(op: List[Op]): - op = [x for x in op if x not in BufferOps] - if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1]) - if len(op) <= 6: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1]) - return str(len(op)) + op = [x for x in op if x not in BufferOps] + if len(op) <= 2: + return ".".join([str(y).split(".")[1] for y in op][::-1]) + if len(op) <= 6: + return ".".join([str(y).split(".")[1][0:3] for y in op][::-1]) + return str(len(op)) + def str_dtype(dtyp): - ret = str(dtyp)[7:] - return "" if ret == 'float' else f"\n{ret}" + ret = str(dtyp)[7:] + return "" if ret == "float" else f"\n{ret}" + @functools.lru_cache(None) -def add_st_node(nmx, nmo, label, st:ShapeTracker): - global node_count - inter_node = node_count - node_count += 1 - offset = st.expr_node(NumNode(0))[0] - G.add_node(inter_node, style='filled', fillcolor="#80ff8080", color="black", label=f"{st.shape}\n{st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")) - G.add_edge(nmx, inter_node, color='#00000060') - G.add_edge(inter_node, nmo, label=label, color='#00000060') +def add_st_node(nmx, nmo, label, st: ShapeTracker): + global node_count + inter_node = node_count + node_count += 1 + offset = st.expr_node(NumNode(0))[0] + G.add_node( + inter_node, + style="filled", + fillcolor="#80ff8080", + color="black", + label=f"{st.shape}\n{st.real_strides()}" + + (f"\n{offset}" if offset != 0 else ""), + ) + G.add_edge(nmx, inter_node, color="#00000060") + G.add_edge(inter_node, nmo, label=label, color="#00000060") + + +logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None + -logops = open(getenv("LOGOPS", ""),"a") if getenv("LOGOPS", "") else None def log_schedule_item(si: ScheduleItem): - if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n") - if not DEBUG and not GRAPH: return - if si.ast.op == LoadOps.CONTIGUOUS: setattr(si.out, 'node_id', nm(si.inputs[0].base)) - if si.ast.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return + if logops and si.ast.op not in LoadOps: + logops.write(str(si.ast) + "\n") + if not DEBUG and not GRAPH: + return + if si.ast.op == LoadOps.CONTIGUOUS: + setattr(si.out, "node_id", nm(si.inputs[0].base)) + if si.ast.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: + return - op: List[Op] = [x.op for x in si.ast.get_lazyops()] - oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps, BufferOps] - optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0]) - cnts[optype] += 1 - if GRAPH: - assert si.out.base == si.out, "all outputs based" - top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'} + op: List[Op] = [x.op for x in si.ast.get_lazyops()] + oporder = [ + LoadOps, + TernaryOps, + ReduceOps, + BinaryOps, + UnaryOps, + MovementOps, + BufferOps, + ] + optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0]) + cnts[optype] += 1 + if GRAPH: + assert si.out.base == si.out, "all outputs based" + top_colors = { + LoadOps: "#FFFFa0", + UnaryOps: "#c0c0c0", + ReduceOps: "#8080ff", + BinaryOps: "#c0c0c0", + MovementOps: "#80ff80", + TernaryOps: "#c0c0c0", + BufferOps: "#FF8080", + } - # get inputs for shapetrackers - input_to_st = defaultdict(list) - for lo in si.ast.get_lazyops(): - if lo.op != BufferOps.LOAD: continue - input_to_st[si.inputs[lo.arg.idx-1]].append(lo.arg.st) + # get inputs for shapetrackers + input_to_st = defaultdict(list) + for lo in si.ast.get_lazyops(): + if lo.op != BufferOps.LOAD: + continue + input_to_st[si.inputs[lo.arg.idx - 1]].append(lo.arg.st) - # add them to the graph, potentially with a movement op separating them - for x in input_to_st: - for st in dedup(input_to_st[x]): - if st.contiguous: - G.add_edge(nm(x), nm(si.out), label=get_sop(op), color='#00000060') - else: - add_st_node(nm(x), nm(si.out), get_sop(op), st) - if 'label' not in G.nodes[nm(x)]: - G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(si.out.dtype) + # add them to the graph, potentially with a movement op separating them + for x in input_to_st: + for st in dedup(input_to_st[x]): + if st.contiguous: + G.add_edge(nm(x), nm(si.out), label=get_sop(op), color="#00000060") + else: + add_st_node(nm(x), nm(si.out), get_sop(op), st) + if "label" not in G.nodes[nm(x)]: + G.nodes[nm(x)]["label"] = str(x.shape) + str_dtype(si.out.dtype) - if nm(si.out) not in G.nodes: G.add_node(nm(si.out)) + if nm(si.out) not in G.nodes: + G.add_node(nm(si.out)) + + G.nodes[nm(si.out)]["label"] = ( + ( + str(set(x.shape for x in si.inputs)) + "\n" + str(si.out.shape) + if optype == ReduceOps + else str(si.out.shape) + ) + + str_dtype(si.out.dtype) + + (f"\n{si.ast.op}" if si.ast.op in LoadOps else "") + ) + G.nodes[nm(si.out)]["fillcolor"] = top_colors[optype] + G.nodes[nm(si.out)]["color"] = "black" + G.nodes[nm(si.out)]["style"] = "filled" - G.nodes[nm(si.out)]['label'] = (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps else "") - G.nodes[nm(si.out)]['fillcolor'] = top_colors[optype] - G.nodes[nm(si.out)]['color'] = 'black' - G.nodes[nm(si.out)]['style'] = 'filled' def _tree(lazydata, prefix=""): - if type(lazydata).__name__ == "LazyBuffer": return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ") - if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"] - lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"] - childs = [_tree(c) for c in lazydata.src[:]] - for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]] - return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]] + if type(lazydata).__name__ == "LazyBuffer": + return ( + [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] + if (lazydata.realized) + else _tree(lazydata.op, "LB ") + ) + if len(lazydata.src) == 0: + return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"] + lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"] + childs = [_tree(c) for c in lazydata.src[:]] + for c in childs[:-1]: + lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]] + return lines + [" ┗" + childs[-1][0]] + [" " + l for l in childs[-1][1:]] -def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata))])) -def graph_uops(uops:List[UOp]): - import networkx as nx - colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0", - UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", - UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"} - G = nx.DiGraph() - for u in uops: - if u.uop == UOps.END: continue - G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) - for v in u.vin: G.add_edge(uops.index(v), uops.index(u)) - GRAPHPATH = "/tmp/uops" - nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot') - os.system(f'dot -Grankdir=LR -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg') +def print_tree(lazydata: LazyOp): + print("\n".join([f"{str(i).rjust(3)} {s}" for i, s in enumerate(_tree(lazydata))])) + + +def graph_uops(uops: List[UOp]): + import networkx as nx + + colors = { + UOps.ALU: "#ffffc0", + UOps.LOAD: "#ffc0c0", + UOps.STORE: "#c0ffc0", + UOps.SPECIAL: "#c0c0ff", + UOps.CONST: "#e0e0e0", + UOps.DEFINE_GLOBAL: "#ffe0b0", + UOps.DEFINE_LOCAL: "#ffe0d0", + UOps.DEFINE_ACC: "#f0ffe0", + UOps.LOOP: "#c8a0e0", + UOps.PHI: "#e0ffc0", + UOps.BARRIER: "#ff8080", + UOps.IF: "#c8b0c0", + } + G = nx.DiGraph() + for u in uops: + if u.uop == UOps.END: + continue + G.add_node( + uops.index(u), + label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", + style="filled", + fillcolor=colors.get(u.uop, "#ffffff"), + ) + for v in u.vin: + G.add_edge(uops.index(v), uops.index(u)) + GRAPHPATH = "/tmp/uops" + nx.drawing.nx_pydot.write_dot(G, f"{GRAPHPATH}.dot") + os.system(f"dot -Grankdir=LR -Tsvg {GRAPHPATH}.dot -o {GRAPHPATH}.svg") diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 39fc3c5b0..80345f2aa 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -3,311 +3,702 @@ import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, import numpy as np from urllib import request from tqdm import tqdm -from typing import Dict, Tuple, Union, List, NamedTuple, Final, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable -if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10 - from typing_extensions import TypeGuard +from typing import ( + Dict, + Tuple, + Union, + List, + NamedTuple, + Final, + ClassVar, + Optional, + Iterable, + Any, + TypeVar, + TYPE_CHECKING, + Callable, +) + +if ( + TYPE_CHECKING +): # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10 + from typing_extensions import TypeGuard T = TypeVar("T") U = TypeVar("U") + + # NOTE: it returns int 1 if x is empty regardless of the type of x -def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.__mul__, x, 1) +def prod(x: Iterable[T]) -> Union[T, int]: + return functools.reduce(operator.__mul__, x, 1) + # NOTE: helpers is not allowed to import from anything else in tinygrad OSX = platform.system() == "Darwin" CI = os.getenv("CI", "") != "" -def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order -def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x -def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python -def all_same(items:List[T]): return all(x == items[0] for x in items) -def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t) -def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line -def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s) -def ansilen(s:str): return len(ansistrip(s)) -def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x -def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist] -def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm) -def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst -def round_up(num, amt:int): return (num+amt-1)//amt * amt -def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]: - assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" - return {k:v for d in ds for k,v in d.items()} -def partition(lst:List[T], fxn:Callable[[T],bool]): - a:List[T] = [] - b:List[T] = [] - for s in lst: (a if fxn(s) else b).append(s) - return a,b -def unwrap(x:Optional[T]) -> T: - assert x is not None - return x -def unwrap2(x:Tuple[T,Any]) -> T: - ret, err = x - assert err is None, str(err) - return ret + +def dedup(x: Iterable[T]): + return list(dict.fromkeys(x)) # retains list order + + +def argfix(*x): + return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x + + +def argsort(x): + return type(x)( + sorted(range(len(x)), key=x.__getitem__) + ) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python + + +def all_same(items: List[T]): + return all(x == items[0] for x in items) + + +def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]: + return all(isinstance(s, int) for s in t) + + +def colored(st, color: Optional[str], background=False): + return ( + f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" + if color is not None + else st + ) # replace the termcolor library with one line + + +def ansistrip(s: str): + return re.sub("\x1b\\[(K|.*?m)", "", s) + + +def ansilen(s: str): + return len(ansistrip(s)) + + +def make_pair(x: Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: + return (x,) * cnt if isinstance(x, int) else x + + +def flatten(l: Iterable[Iterable[T]]): + return [item for sublist in l for item in sublist] + + +def fromimport(mod, frm): + return getattr(__import__(mod, fromlist=[frm]), frm) + + +def strip_parens(fst: str): + return ( + fst[1:-1] + if fst[0] == "(" + and fst[-1] == ")" + and fst[1:-1].find("(") <= fst[1:-1].find(")") + else fst + ) + + +def round_up(num, amt: int): + return (num + amt - 1) // amt * amt + + +def merge_dicts(ds: Iterable[Dict[T, U]]) -> Dict[T, U]: + assert len(kvs := set([(k, v) for d in ds for k, v in d.items()])) == len( + set(kv[0] for kv in kvs) + ), f"cannot merge, {kvs} contains different values for the same key" + return {k: v for d in ds for k, v in d.items()} + + +def partition(lst: List[T], fxn: Callable[[T], bool]): + a: List[T] = [] + b: List[T] = [] + for s in lst: + (a if fxn(s) else b).append(s) + return a, b + + +def unwrap(x: Optional[T]) -> T: + assert x is not None + return x + + +def unwrap2(x: Tuple[T, Any]) -> T: + ret, err = x + assert err is None, str(err) + return ret + + def get_child(obj, key): - for k in key.split('.'): - if k.isnumeric(): obj = obj[int(k)] - elif isinstance(obj, dict): obj = obj[k] - else: obj = getattr(obj, k) - return obj + for k in key.split("."): + if k.isnumeric(): + obj = obj[int(k)] + elif isinstance(obj, dict): + obj = obj[k] + else: + obj = getattr(obj, k) + return obj + @functools.lru_cache(maxsize=None) -def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)]) +def to_function_name(s: str): + return "".join( + [ + c if c in (string.ascii_letters + string.digits + "_") else f"{ord(c):02X}" + for c in ansistrip(s) + ] + ) + + @functools.lru_cache(maxsize=None) -def getenv(key:str, default=0): return type(default)(os.getenv(key, default)) -def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix() +def getenv(key: str, default=0): + return type(default)(os.getenv(key, default)) + + +def temp(x: str) -> str: + return (pathlib.Path(tempfile.gettempdir()) / x).as_posix() + class Context(contextlib.ContextDecorator): - stack: ClassVar[List[dict[str, int]]] = [{}] - def __init__(self, **kwargs): self.kwargs = kwargs - def __enter__(self): - Context.stack[-1] = {k:o.value for k,o in ContextVar._cache.items()} # Store current state. - for k,v in self.kwargs.items(): ContextVar._cache[k].value = v # Update to new temporary state. - Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later. - def __exit__(self, *args): - for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value) + stack: ClassVar[List[dict[str, int]]] = [{}] + + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __enter__(self): + Context.stack[-1] = { + k: o.value for k, o in ContextVar._cache.items() + } # Store current state. + for k, v in self.kwargs.items(): + ContextVar._cache[k].value = v # Update to new temporary state. + Context.stack.append( + self.kwargs + ) # Store the temporary state so we know what to undo later. + + def __exit__(self, *args): + for k in Context.stack.pop(): + ContextVar._cache[k].value = Context.stack[-1].get( + k, ContextVar._cache[k].value + ) + class ContextVar: - _cache: ClassVar[Dict[str, ContextVar]] = {} - value: int - def __new__(cls, key, default_value): - if key in ContextVar._cache: return ContextVar._cache[key] - instance = ContextVar._cache[key] = super().__new__(cls) - instance.value = getenv(key, default_value) - return instance - def __bool__(self): return bool(self.value) - def __ge__(self, x): return self.value >= x - def __gt__(self, x): return self.value > x - def __lt__(self, x): return self.value < x + _cache: ClassVar[Dict[str, ContextVar]] = {} + value: int -DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0) + def __new__(cls, key, default_value): + if key in ContextVar._cache: + return ContextVar._cache[key] + instance = ContextVar._cache[key] = super().__new__(cls) + instance.value = getenv(key, default_value) + return instance + + def __bool__(self): + return bool(self.value) + + def __ge__(self, x): + return self.value >= x + + def __gt__(self, x): + return self.value > x + + def __lt__(self, x): + return self.value < x + + +DEBUG, IMAGE, BEAM, NOOPT = ( + ContextVar("DEBUG", 0), + ContextVar("IMAGE", 0), + ContextVar("BEAM", 0), + ContextVar("NOOPT", 0), +) GRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net") + class Timing(contextlib.ContextDecorator): - def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled - def __enter__(self): self.st = time.perf_counter_ns() - def __exit__(self, *exc): - self.et = time.perf_counter_ns() - self.st - if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else "")) + def __init__(self, prefix="", on_exit=None, enabled=True): + self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled + + def __enter__(self): + self.st = time.perf_counter_ns() + + def __exit__(self, *exc): + self.et = time.perf_counter_ns() - self.st + if self.enabled: + print( + f"{self.prefix}{self.et*1e-6:.2f} ms" + + (self.on_exit(self.et) if self.on_exit else "") + ) + class Profiling(contextlib.ContextDecorator): - def __init__(self, enabled=True, sort='cumtime', frac=0.2): self.enabled, self.sort, self.frac = enabled, sort, frac - def __enter__(self): - self.pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6) - if self.enabled: self.pr.enable() - def __exit__(self, *exc): - if self.enabled: - self.pr.disable() - pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort).print_stats(self.frac) + def __init__(self, enabled=True, sort="cumtime", frac=0.2): + self.enabled, self.sort, self.frac = enabled, sort, frac + + def __enter__(self): + self.pr = cProfile.Profile(timer=lambda: int(time.time() * 1e9), timeunit=1e-6) + if self.enabled: + self.pr.enable() + + def __exit__(self, *exc): + if self.enabled: + self.pr.disable() + pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort).print_stats( + self.frac + ) + # **** tinygrad now supports dtypes! ***** + # TODO: migrate this from NamedTuple -> dataclass class DType(NamedTuple): - priority: int # this determines when things get upcasted - itemsize: int - name: str - np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project - sz: int = 1 - def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}" - def vec(self, sz:int): - assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}" - return DType(self.priority, self.itemsize*sz, self.name+str(sz), None, sz) - def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self + priority: int # this determines when things get upcasted + itemsize: int + name: str + np: Optional[ + type + ] # TODO: someday this will be removed with the "remove numpy" project + sz: int = 1 + + def __repr__(self): + return ( + f"dtypes.{INVERSE_DTYPES_DICT[self]}" + if self.sz == 1 + else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}" + ) + + def vec(self, sz: int): + assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}" + return DType(self.priority, self.itemsize * sz, self.name + str(sz), None, sz) + + def scalar(self): + return DTYPES_DICT[self.name[: -len(str(self.sz))]] if self.sz > 1 else self + # dependent typing? class ImageDType(DType): - def __new__(cls, priority, itemsize, name, np, shape): - return super().__new__(cls, priority, itemsize, name, np) - def __init__(self, priority, itemsize, name, np, shape): - self.shape: Tuple[int, ...] = shape # arbitrary arg for the dtype, used in image for the shape - super().__init__() - def __repr__(self): return f"dtypes.{self.name}({self.shape})" - # TODO: fix this to not need these - def __hash__(self): return hash((super().__hash__(), self.shape)) - def __eq__(self, x): return super().__eq__(x) and self.shape == x.shape - def __ne__(self, x): return super().__ne__(x) or self.shape != x.shape + def __new__(cls, priority, itemsize, name, np, shape): + return super().__new__(cls, priority, itemsize, name, np) + + def __init__(self, priority, itemsize, name, np, shape): + self.shape: Tuple[ + int, ... + ] = shape # arbitrary arg for the dtype, used in image for the shape + super().__init__() + + def __repr__(self): + return f"dtypes.{self.name}({self.shape})" + + # TODO: fix this to not need these + def __hash__(self): + return hash((super().__hash__(), self.shape)) + + def __eq__(self, x): + return super().__eq__(x) and self.shape == x.shape + + def __ne__(self, x): + return super().__ne__(x) or self.shape != x.shape + class PtrDType(DType): - def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz) - def __repr__(self): return f"ptr.{super().__repr__()}" + def __new__(cls, dt: DType): + return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz) + + def __repr__(self): + return f"ptr.{super().__repr__()}" + class dtypes: - @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool - def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) - @staticmethod - def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.half.vec(4), dtypes.float.vec(2), dtypes.float.vec(4)) - @staticmethod - def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) - @staticmethod - def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name] - @staticmethod - def fields() -> Dict[str, DType]: return DTYPES_DICT - bool: Final[DType] = DType(0, 1, "bool", np.bool_) - float16: Final[DType] = DType(9, 2, "half", np.float16) - half = float16 - float32: Final[DType] = DType(10, 4, "float", np.float32) - float = float32 - float64: Final[DType] = DType(11, 8, "double", np.float64) - double = float64 - int8: Final[DType] = DType(1, 1, "char", np.int8) - int16: Final[DType] = DType(3, 2, "short", np.int16) - int32: Final[DType] = DType(5, 4, "int", np.int32) - int = int32 - int64: Final[DType] = DType(7, 8, "long", np.int64) - uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8) - uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16) - uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32) - uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64) + @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool + def is_int(x: DType) -> bool: + return x in ( + dtypes.int8, + dtypes.int16, + dtypes.int32, + dtypes.int64, + dtypes.uint8, + dtypes.uint16, + dtypes.uint32, + dtypes.uint64, + ) - # NOTE: bfloat16 isn't supported in numpy - bfloat16: Final[DType] = DType(9, 2, "__bf16", None) + @staticmethod + def is_float(x: DType) -> bool: + return x in ( + dtypes.float16, + dtypes.float32, + dtypes.float64, + dtypes.half.vec(4), + dtypes.float.vec(2), + dtypes.float.vec(4), + ) - # NOTE: these are internal dtypes, should probably check for that - _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None) + @staticmethod + def is_unsigned(x: DType) -> bool: + return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) + + @staticmethod + def from_np(x) -> DType: + return DTYPES_DICT[np.dtype(x).name] + + @staticmethod + def fields() -> Dict[str, DType]: + return DTYPES_DICT + + bool: Final[DType] = DType(0, 1, "bool", np.bool_) + float16: Final[DType] = DType(9, 2, "half", np.float16) + half = float16 + float32: Final[DType] = DType(10, 4, "float", np.float32) + float = float32 + float64: Final[DType] = DType(11, 8, "double", np.float64) + double = float64 + int8: Final[DType] = DType(1, 1, "char", np.int8) + int16: Final[DType] = DType(3, 2, "short", np.int16) + int32: Final[DType] = DType(5, 4, "int", np.int32) + int = int32 + int64: Final[DType] = DType(7, 8, "long", np.int64) + uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8) + uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16) + uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32) + uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64) + + # NOTE: bfloat16 isn't supported in numpy + bfloat16: Final[DType] = DType(9, 2, "__bf16", None) + + # NOTE: these are internal dtypes, should probably check for that + _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None) + + # NOTE: these are image dtypes + @staticmethod + def imageh(shp): + return ImageDType(100, 2, "imageh", np.float16, shp) + + @staticmethod + def imagef(shp): + return ImageDType(100, 4, "imagef", np.float32, shp) - # NOTE: these are image dtypes - @staticmethod - def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp) - @staticmethod - def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp) # HACK: staticmethods are not callable in 3.8 so we have to compare the class -DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and v.__class__ is not staticmethod} -INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()} +DTYPES_DICT = { + k: v + for k, v in dtypes.__dict__.items() + if not k.startswith("__") and not callable(v) and v.__class__ is not staticmethod +} +INVERSE_DTYPES_DICT = {v: k for k, v in DTYPES_DICT.items()} + class GlobalCounters: - global_ops: ClassVar[int] = 0 - global_mem: ClassVar[int] = 0 - time_sum_s: ClassVar[float] = 0.0 - kernel_count: ClassVar[int] = 0 - mem_used: ClassVar[int] = 0 # NOTE: this is not reset - mem_cached: ClassVar[int] = 0 # NOTE: this is not reset - @staticmethod - def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0 + global_ops: ClassVar[int] = 0 + global_mem: ClassVar[int] = 0 + time_sum_s: ClassVar[float] = 0.0 + kernel_count: ClassVar[int] = 0 + mem_used: ClassVar[int] = 0 # NOTE: this is not reset + mem_cached: ClassVar[int] = 0 # NOTE: this is not reset + + @staticmethod + def reset(): + ( + GlobalCounters.global_ops, + GlobalCounters.global_mem, + GlobalCounters.time_sum_s, + GlobalCounters.kernel_count, + ) = (0, 0, 0.0, 0) + # *** universal database cache *** -_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")) -CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db"))) +_cache_dir: str = getenv( + "XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache") +) +CACHEDB: str = getenv( + "CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")) +) CACHELEVEL = getenv("CACHELEVEL", 2) VERSION = 10 _db_connection = None -def db_connection(): - global _db_connection - if _db_connection is None: - os.makedirs(CACHEDB.rsplit(os.sep, 1)[0], exist_ok=True) - _db_connection = sqlite3.connect(CACHEDB) - if DEBUG >= 7: _db_connection.set_trace_callback(print) - return _db_connection -def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any: - if CACHELEVEL == 0: return None - if isinstance(key, (str,int)): key = {"key": key} - conn = db_connection() - cur = conn.cursor() - try: - res = cur.execute(f"SELECT val FROM {table}_{VERSION} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values())) - except sqlite3.OperationalError: - return None # table doesn't exist - if (val:=res.fetchone()) is not None: return pickle.loads(val[0]) - return None + +def db_connection(): + global _db_connection + if _db_connection is None: + os.makedirs(CACHEDB.rsplit(os.sep, 1)[0], exist_ok=True) + _db_connection = sqlite3.connect(CACHEDB) + if DEBUG >= 7: + _db_connection.set_trace_callback(print) + return _db_connection + + +def diskcache_get(table: str, key: Union[Dict, str, int]) -> Any: + if CACHELEVEL == 0: + return None + if isinstance(key, (str, int)): + key = {"key": key} + conn = db_connection() + cur = conn.cursor() + try: + res = cur.execute( + f"SELECT val FROM {table}_{VERSION} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", + tuple(key.values()), + ) + except sqlite3.OperationalError: + return None # table doesn't exist + if (val := res.fetchone()) is not None: + return pickle.loads(val[0]) + return None + _db_tables = set() -def diskcache_put(table:str, key:Union[Dict, str, int], val:Any): - if CACHELEVEL == 0: return val - if isinstance(key, (str,int)): key = {"key": key} - conn = db_connection() - cur = conn.cursor() - if table not in _db_tables: - TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"} - ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys()) - cur.execute(f"CREATE TABLE IF NOT EXISTS {table}_{VERSION} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))") - _db_tables.add(table) - cur.execute(f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) - conn.commit() - cur.close() - return val + + +def diskcache_put(table: str, key: Union[Dict, str, int], val: Any): + if CACHELEVEL == 0: + return val + if isinstance(key, (str, int)): + key = {"key": key} + conn = db_connection() + cur = conn.cursor() + if table not in _db_tables: + TYPES = { + str: "text", + bool: "integer", + int: "integer", + float: "numeric", + bytes: "blob", + } + ltypes = ", ".join(f"{k} {TYPES[type(key[k])]}" for k in key.keys()) + cur.execute( + f"CREATE TABLE IF NOT EXISTS {table}_{VERSION} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))" + ) + _db_tables.add(table) + cur.execute( + f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", + tuple(key.values()) + (pickle.dumps(val),), + ) + conn.commit() + cur.close() + return val + def diskcache(func): - def wrapper(*args, **kwargs) -> bytes: - table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest() - if (ret:=diskcache_get(table, key)): return ret - return diskcache_put(table, key, func(*args, **kwargs)) - setattr(wrapper, "__wrapped__", func) - return wrapper + def wrapper(*args, **kwargs) -> bytes: + table, key = ( + f"cache_{func.__name__}", + hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest(), + ) + if ret := diskcache_get(table, key): + return ret + return diskcache_put(table, key, func(*args, **kwargs)) + + setattr(wrapper, "__wrapped__", func) + return wrapper + # *** http support *** -def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path: - if url.startswith("/") or url.startswith("."): return pathlib.Path(url) - fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) - if not fp.is_file() or not allow_caching: - with request.urlopen(url, timeout=10) as r: - assert r.status == 200 - total_length = int(r.headers.get('content-length', 0)) - progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=url) - (path := fp.parent).mkdir(parents=True, exist_ok=True) - with tempfile.NamedTemporaryFile(dir=path, delete=False) as f: - while chunk := r.read(16384): progress_bar.update(f.write(chunk)) - f.close() - if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}") - pathlib.Path(f.name).rename(fp) - return fp + +def fetch( + url: str, + name: Optional[Union[pathlib.Path, str]] = None, + allow_caching=not getenv("DISABLE_HTTP_CACHE"), +) -> pathlib.Path: + if url.startswith("/") or url.startswith("."): + return pathlib.Path(url) + fp = ( + pathlib.Path(name) + if name is not None and (isinstance(name, pathlib.Path) or "/" in name) + else pathlib.Path(_cache_dir) + / "tinygrad" + / "downloads" + / (name if name else hashlib.md5(url.encode("utf-8")).hexdigest()) + ) + if not fp.is_file() or not allow_caching: + with request.urlopen(url, timeout=10) as r: + assert r.status == 200 + total_length = int(r.headers.get("content-length", 0)) + progress_bar = tqdm(total=total_length, unit="B", unit_scale=True, desc=url) + (path := fp.parent).mkdir(parents=True, exist_ok=True) + with tempfile.NamedTemporaryFile(dir=path, delete=False) as f: + while chunk := r.read(16384): + progress_bar.update(f.write(chunk)) + f.close() + if (file_size := os.stat(f.name).st_size) < total_length: + raise RuntimeError( + f"fetch size incomplete, {file_size} < {total_length}" + ) + pathlib.Path(f.name).rename(fp) + return fp + # *** Exec helpers + def cpu_time_execution(cb, enable): - if enable: st = time.perf_counter() - cb() - if enable: return time.perf_counter()-st + if enable: + st = time.perf_counter() + cb() + if enable: + return time.perf_counter() - st + # *** ctypes helpers -def from_mv(mv, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type)) -def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) + +def from_mv(mv, to_type=ctypes.c_char): + return ctypes.cast( + ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type) + ) + + +def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): + return (ctypes.POINTER(to_type) * len(options))( + *[ + ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) + for o in options + ] + ) + + @functools.lru_cache(maxsize=None) def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]): - class CStruct(ctypes.Structure): - _pack_, _fields_ = 1, fields - return CStruct -def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1] -def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1] -def flat_mv(mv:memoryview): - if len(mv) == 0: return mv - return mv.cast("B", shape=(mv.nbytes,)) + class CStruct(ctypes.Structure): + _pack_, _fields_ = 1, fields + + return CStruct + + +def init_c_var(ctypes_var, creat_cb): + return (creat_cb(ctypes_var), ctypes_var)[1] + + +def get_bytes(arg, get_sz, get_str, check) -> bytes: + return ( + sz := init_c_var( + ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))) + ), + ctypes.string_at( + init_c_var( + ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x)) + ), + size=sz.value, + ), + )[1] + + +def flat_mv(mv: memoryview): + if len(mv) == 0: + return mv + return mv.cast("B", shape=(mv.nbytes,)) + # *** Helpers for CUDA-like APIs. + def pretty_ptx(s): - # all expressions match `` and replace it with `color()` - s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers - s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types - s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions - s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers - s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space - s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives - return s + # all expressions match `` and replace it with `color()` + s = re.sub( + r"([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])", + lambda m: m[1] + colored(m[2], "blue") + m[3], + s, + flags=re.M, + ) # identifiers + s = re.sub( + r"(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])", + lambda m: m[1] + colored(m[2], "green") + m[3], + s, + flags=re.M, + ) # types + s = re.sub( + r"^(\s*)([\w]+)(.*?;$)", + lambda m: m[1] + colored(m[2], "yellow") + m[3], + s, + flags=re.M, + ) # instructions + s = re.sub( + r"([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])", + lambda m: m[1] + colored(m[2], "yellow") + m[3], + s, + flags=re.M, + ) # numbers + s = re.sub( + r"(\.)(param|reg|global)", + lambda m: m[1] + colored(m[2], "magenta"), + s, + flags=re.M, + ) # space + s = re.sub( + r"(\.)(version|target|address_size|visible|entry)", + lambda m: m[1] + colored(m[2], "magenta"), + s, + flags=re.M, + ) # derivatives + return s -def compile_cuda_style(prg, compile_options, prog_t, create_prog, compile_prog, get_code, get_code_size, get_log, get_log_size, check) -> bytes: - check(create_prog(ctypes.byref(prog := prog_t()), prg.encode(), "".encode(), 0, None, None)) - status = compile_prog(prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options])) - if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, get_log_size, get_log, check).decode()}") - return get_bytes(prog, get_code_size, get_code, check) +def compile_cuda_style( + prg, + compile_options, + prog_t, + create_prog, + compile_prog, + get_code, + get_code_size, + get_log, + get_log_size, + check, +) -> bytes: + check( + create_prog( + ctypes.byref(prog := prog_t()), + prg.encode(), + "".encode(), + 0, + None, + None, + ) + ) + status = compile_prog( + prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options]) + ) -def encode_args_cuda_style(bufs, vals, device_ptr_t, marks) -> Tuple[ctypes.Array, ctypes.Structure]: - c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t) for i in range(len(bufs))] + [(f'f{i}', ctypes.c_int) for i in range(len(bufs), len(bufs)+len(vals))]))(*bufs, *vals) - return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args + if status != 0: + raise RuntimeError( + f"compile failed: {get_bytes(prog, get_log_size, get_log, check).decode()}" + ) + return get_bytes(prog, get_code_size, get_code, check) -def time_execution_cuda_style(cb, ev_t, evcreate, evrecord, evsync, evdestroy, evtime, enable=False) -> Optional[float]: - if not enable: return cb() - evs = [init_c_var(ev_t(), lambda x: evcreate(ctypes.byref(x), 0)) for _ in range(2)] - evrecord(evs[0], None) - cb() - evrecord(evs[1], None) - evsync(evs[1]) - evtime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1]) - for ev in evs: evdestroy(ev) - return ret.value * 1e-3 + +def encode_args_cuda_style( + bufs, vals, device_ptr_t, marks +) -> Tuple[ctypes.Array, ctypes.Structure]: + c_args = init_c_struct_t( + tuple( + [(f"f{i}", device_ptr_t) for i in range(len(bufs))] + + [(f"f{i}", ctypes.c_int) for i in range(len(bufs), len(bufs) + len(vals))] + ) + )(*bufs, *vals) + return (ctypes.c_void_p * 5)( + ctypes.c_void_p(marks[0]), + ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), + ctypes.c_void_p(marks[1]), + ctypes.cast( + ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p + ), + ctypes.c_void_p(marks[2]), + ), c_args + + +def time_execution_cuda_style( + cb, ev_t, evcreate, evrecord, evsync, evdestroy, evtime, enable=False +) -> Optional[float]: + if not enable: + return cb() + evs = [init_c_var(ev_t(), lambda x: evcreate(ctypes.byref(x), 0)) for _ in range(2)] + evrecord(evs[0], None) + cb() + evrecord(evs[1], None) + evsync(evs[1]) + evtime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1]) + for ev in evs: + evdestroy(ev) + return ret.value * 1e-3 diff --git a/tinygrad/jit.py b/tinygrad/jit.py index 1e2337a20..e2296df29 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -9,122 +9,245 @@ from tinygrad.shape.symbolic import Variable, NumNode, Node from weakref import ref, WeakKeyDictionary from dataclasses import dataclass + @dataclass(frozen=True) class JitItem: - prg: JITRunner # or a graph executor like MetalGraph - rawbufs: List[Optional[Buffer]] + prg: JITRunner # or a graph executor like MetalGraph + rawbufs: List[Optional[Buffer]] + def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[Node, Node]: - return functools.reduce(operator.__add__, [ji.prg.op_estimate for ji in jit_cache], NumNode(0)), functools.reduce(operator.__add__, [ji.prg.mem_estimate for ji in jit_cache], NumNode(0)) -def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]: - input_replace: Dict[Tuple[int, int], int] = {} - for j,ji in enumerate(jit_cache): - for i,a in enumerate(ji.rawbufs): - if a in input_rawbuffers: - input_replace[(j,i)] = input_rawbuffers.index(a) - assert len(set(input_replace.values())) == len(input_rawbuffers), "some input tensors not found" - return input_replace + return functools.reduce( + operator.__add__, [ji.prg.op_estimate for ji in jit_cache], NumNode(0) + ), functools.reduce( + operator.__add__, [ji.prg.mem_estimate for ji in jit_cache], NumNode(0) + ) + + +def get_input_replace( + jit_cache: List[JitItem], input_rawbuffers: List[Buffer] +) -> Dict[Tuple[int, int], int]: + input_replace: Dict[Tuple[int, int], int] = {} + for j, ji in enumerate(jit_cache): + for i, a in enumerate(ji.rawbufs): + if a in input_rawbuffers: + input_replace[(j, i)] = input_rawbuffers.index(a) + assert len(set(input_replace.values())) == len( + input_rawbuffers + ), "some input tensors not found" + return input_replace + + def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]: - return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(tuple(ji.prg.global_size))) or (ji.prg.local_size and not all_int(tuple(ji.prg.local_size))))] + return [ + j + for j, ji in enumerate(jit_cache) + if isinstance(ji.prg, CompiledASTRunner) + and ( + (ji.prg.global_size and not all_int(tuple(ji.prg.global_size))) + or (ji.prg.local_size and not all_int(tuple(ji.prg.local_size))) + ) + ] + + def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]: - return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars] + return [ + j + for j, ji in enumerate(jit_cache) + if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars + ] + + +class GraphException(Exception): + pass + + +ReturnType = TypeVar("ReturnType") -class GraphException(Exception): pass -ReturnType = TypeVar('ReturnType') class TinyJit(Generic[ReturnType]): - def __init__(self, fxn:Callable[..., ReturnType]): - self.fxn = fxn - self.reset() + def __init__(self, fxn: Callable[..., ReturnType]): + self.fxn = fxn + self.reset() - def reset(self): - self.jit_cache: List[JitItem] = [] - self.input_replace: Dict[Tuple[int, int], int] = {} - self.cnt: int = 0 - self.ret: Optional[ReturnType] = None - self.expected_vals: Optional[Tuple[Variable, ...]] = None - self.expected_name_sts_dtype: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType], ...]] = None + def reset(self): + self.jit_cache: List[JitItem] = [] + self.input_replace: Dict[Tuple[int, int], int] = {} + self.cnt: int = 0 + self.ret: Optional[ReturnType] = None + self.expected_vals: Optional[Tuple[Variable, ...]] = None + self.expected_name_sts_dtype: Optional[ + Tuple[Tuple[Union[int, str], ShapeTracker, DType], ...] + ] = None - # add support for instance methods - def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) + # add support for instance methods + def __get__(self, obj, objtype): + return functools.partial(self.__call__, obj) - def __call__(self, *args, **kwargs) -> ReturnType: - # all inputs are realized - input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor} - expected_name_sts_dtype = tuple([(k, v.lazydata.st.unbind(), v.dtype) for k,v in input_tensors.items()]) + def __call__(self, *args, **kwargs) -> ReturnType: + # all inputs are realized + input_tensors: Dict[Union[int, str], Tensor] = { + cast(Union[int, str], k): v.realize() + for k, v in itertools.chain(enumerate(args), kwargs.items()) + if v.__class__ is Tensor + } + expected_name_sts_dtype = tuple( + [(k, v.lazydata.st.unbind(), v.dtype) for k, v in input_tensors.items()] + ) - # get rawbuffers - input_rawbuffers: List[Buffer] = [cast(Buffer, v.lazydata.realized) for v in input_tensors.values()] - assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT" + # get rawbuffers + input_rawbuffers: List[Buffer] = [ + cast(Buffer, v.lazydata.realized) for v in input_tensors.values() + ] + assert len(set(input_rawbuffers)) == len( + input_rawbuffers + ), "duplicate inputs to JIT" - # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global - var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]) - expected_vals = tuple(var_vals.keys()) + # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global + var_vals: Dict[Variable, int] = merge_dicts( + [arg.lazydata.st.var_vals for arg in input_tensors.values()] + + [ + dict( + x.unbind() + for x in itertools.chain(args, kwargs.values()) + if isinstance(x, Variable) + ) + ] + ) + expected_vals = tuple(var_vals.keys()) - if self.cnt >= 2: - # jit exec - assert self.expected_vals == expected_vals, "mismatch of var_vals" - assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}" - for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx] - for ji in self.jit_cache: ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True) - elif self.cnt == 1: - # jit capture - self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype - CacheCollector.start(var_vals) - self.ret = self.fxn(*args, **kwargs) - self.jit_cache = CacheCollector.finish() - assert len(self.jit_cache) != 0, "didn't JIT anything!" + if self.cnt >= 2: + # jit exec + assert self.expected_vals == expected_vals, "mismatch of var_vals" + assert ( + self.expected_name_sts_dtype == expected_name_sts_dtype + ), f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}" + for (j, i), input_idx in self.input_replace.items(): + self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx] + for ji in self.jit_cache: + ji.prg( + cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG >= 2, jit=True + ) + elif self.cnt == 1: + # jit capture + self.expected_vals, self.expected_name_sts_dtype = ( + expected_vals, + expected_name_sts_dtype, + ) + CacheCollector.start(var_vals) + self.ret = self.fxn(*args, **kwargs) + self.jit_cache = CacheCollector.finish() + assert len(self.jit_cache) != 0, "didn't JIT anything!" - # if your Device supports it, condense the items into a graph executor - if (make_graph := Device[Device.DEFAULT].graph) and getenv("JIT") != 2: - try: - if DEBUG >= 1: print(f"JIT GRAPHing {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs") - self.jit_cache = [JitItem(make_graph(self.jit_cache, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))] - except GraphException as e: - if DEBUG >= 1: print(f"graph create failed {e}") - else: - if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs") + # if your Device supports it, condense the items into a graph executor + if (make_graph := Device[Device.DEFAULT].graph) and getenv("JIT") != 2: + try: + if DEBUG >= 1: + print( + f"JIT GRAPHing {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs" + ) + self.jit_cache = [ + JitItem( + make_graph(self.jit_cache, input_rawbuffers, var_vals), + cast(List[Optional[Buffer]], input_rawbuffers), + ) + ] + except GraphException as e: + if DEBUG >= 1: + print(f"graph create failed {e}") + else: + if DEBUG >= 1: + print( + f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs" + ) - self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers) - elif self.cnt == 0: - # jit ignore - self.ret = self.fxn(*args, **kwargs) + self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers) + elif self.cnt == 0: + # jit ignore + self.ret = self.fxn(*args, **kwargs) - # clear jit inputs - for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None + # clear jit inputs + for j, i in self.input_replace.keys(): + self.jit_cache[j].rawbufs[i] = None + + self.cnt += 1 + return cast(ReturnType, self.ret) - self.cnt += 1 - return cast(ReturnType, self.ret) class PlaceHolder: - def __init__(self, buf:Buffer): self.size, self.dtype, self.device, self.ref, self.bufid = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf) - def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid) - def __hash__(self): return hash(self.to_tuple()) - def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple() - def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer: - ret = self.ref() - if ret: return ret - if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype) - return buffer_cache[self] + def __init__(self, buf: Buffer): + self.size, self.dtype, self.device, self.ref, self.bufid = ( + buf.size, + buf.dtype, + buf.device, + ref(buf), + id(buf._buf), + ) + + def to_tuple(self): + return (self.size, self.dtype, self.device, self.bufid) + + def __hash__(self): + return hash(self.to_tuple()) + + def __eq__(self, x): + return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple() + + def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer: + ret = self.ref() + if ret: + return ret + if self not in buffer_cache: + buffer_cache[self] = Buffer(self.device, self.size, self.dtype) + return buffer_cache[self] + class _CacheCollector: - def __init__(self): - self.cache: Optional[List[Tuple[JITRunner, List[Union[Buffer, PlaceHolder]]]]] = None + def __init__(self): + self.cache: Optional[ + List[Tuple[JITRunner, List[Union[Buffer, PlaceHolder]]]] + ] = None - def start(self, var_vals:Optional[Dict[Variable, int]]=None): - self.cache = [] - self.placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary() - self.var_vals = var_vals if var_vals is not None else {} + def start(self, var_vals: Optional[Dict[Variable, int]] = None): + self.cache = [] + self.placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary() + self.var_vals = var_vals if var_vals is not None else {} + + def add(self, prg, rawbufs, var_vals): + if self.cache is None: + return + for k, v in var_vals.items(): + assert ( + k in self.var_vals and self.var_vals[k] == v + ), f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}" + self.placeholders[rawbufs[0]] = PlaceHolder( + rawbufs[0] + ) # NOTE: this is making an assumption that 0 is special + self.cache.append( + ( + prg, + [ + self.placeholders.get(x, x) if isinstance(x, Buffer) else x + for x in rawbufs + ], + ) + ) + + def finish(self) -> List[JitItem]: + if self.cache is None: + return [] + buffer_cache: Dict[PlaceHolder, Buffer] = {} + saved_cache, self.cache = self.cache, None + return [ + JitItem( + prg, + [ + x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x + for x in pl + ], + ) + for prg, pl in saved_cache + ] - def add(self, prg, rawbufs, var_vals): - if self.cache is None: return - for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}" - self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0]) # NOTE: this is making an assumption that 0 is special - self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs])) - def finish(self) -> List[JitItem]: - if self.cache is None: return [] - buffer_cache: Dict[PlaceHolder, Buffer] = {} - saved_cache, self.cache = self.cache, None - return [JitItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache] CacheCollector = _CacheCollector() diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index d7e54db32..b75da11ce 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -4,8 +4,33 @@ from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapp from weakref import ref, WeakSet, WeakValueDictionary import numpy as np -from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts, all_int, ImageDType, DEBUG -from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps, get_lazyop_info +from tinygrad.helpers import ( + prod, + getenv, + DType, + dtypes, + flatten, + dedup, + merge_dicts, + all_int, + ImageDType, + DEBUG, +) +from tinygrad.ops import ( + ScheduleItem, + UnaryOps, + BinaryOps, + TernaryOps, + ReduceOps, + MovementOps, + LoadOps, + OpType, + LazyOp, + MemBuffer, + ConstBuffer, + BufferOps, + get_lazyop_info, +) from tinygrad.shape.shapetracker import ShapeTracker, get_contraction from tinygrad.shape.symbolic import Variable, sint from tinygrad.device import Buffer @@ -17,348 +42,740 @@ OPT = getenv("OPT", 2) LAZYCACHE = getenv("LAZYCACHE", 1) # TODO: movement ops that only change shape are really nops. treat them as such -REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1 -MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2 -PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3 -PUSH_RESHAPES = OPT>=4 +( + REMOVE_MOVEMENT_NOPS, + MERGE_ELEMENTWISE_INTO_REDUCE, + SHUFFLE_MOVEMENT_OPS, + MERGE_ELEMENTWISE_OPS, +) = (OPT >= 1, OPT >= 1, OPT >= 1, OPT >= 1) +MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT >= 2, OPT >= 2 +PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT >= 3, OPT >= 3 +PUSH_RESHAPES = OPT >= 4 # **** ast fixing functions **** -def _ast_reduceops(op:LazyOp) -> LazyOp: - # TODO: this can also corealize a binary op after the reduce, not just before - src = op.src[0] - if not src.realized: - assert isinstance(src.op, LazyOp), "if not src.realized, then src.op must be a LazyOp" - if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1: src = src.op - return LazyOp(op.op, (src,), op.arg) + +def _ast_reduceops(op: LazyOp) -> LazyOp: + # TODO: this can also corealize a binary op after the reduce, not just before + src = op.src[0] + if not src.realized: + assert isinstance( + src.op, LazyOp + ), "if not src.realized, then src.op must be a LazyOp" + if ( + MERGE_ELEMENTWISE_INTO_REDUCE + and src.optype is BinaryOps + and len(src.children) <= 1 + ): + src = src.op + return LazyOp(op.op, (src,), op.arg) + # this supports late merging an upstream Reduce op and even an Elementwise op above that -def _ast_binaryops(op:LazyOp, shape:Tuple[sint, ...]) -> LazyOp: - real_srcs: Dict[LazyBuffer, Optional[Union[LazyOp, LazyBuffer]]] = {x:None for x in op.buffers} - # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how - # TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd - psrcs = [(buf,root) for buf in op.buffers if len(buf.children) <= 1 and (root:=get_movementroot_contiguous(buf)).optype == ReduceOps and not root.realized and prod(root.shape) == prod(buf.shape) and len(root.children) <= 1] - intermediate_shape = shape - if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and psrcs: - # NOTE: right now we can't handle multiple, as we'd have to check for loop - buf,root = psrcs[0] - top = _ast_reduceops(root.op) - real_srcs[buf] = top - real_srcs.update({x:x for x in top.buffers}) # the reduce op buffers are not modified +def _ast_binaryops(op: LazyOp, shape: Tuple[sint, ...]) -> LazyOp: + real_srcs: Dict[LazyBuffer, Optional[Union[LazyOp, LazyBuffer]]] = { + x: None for x in op.buffers + } + # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how + # TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd + psrcs = [ + (buf, root) + for buf in op.buffers + if len(buf.children) <= 1 + and (root := get_movementroot_contiguous(buf)).optype == ReduceOps + and not root.realized + and prod(root.shape) == prod(buf.shape) + and len(root.children) <= 1 + ] + intermediate_shape = shape + if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and psrcs: + # NOTE: right now we can't handle multiple, as we'd have to check for loop + buf, root = psrcs[0] + top = _ast_reduceops(root.op) + real_srcs[buf] = top + real_srcs.update( + {x: x for x in top.buffers} + ) # the reduce op buffers are not modified - # if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs - if buf.shape != root.shape: - intermediate_shape = root.shape - assert buf.shape == shape, f"shape mismatch {buf.shape} != {shape}" + # if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs + if buf.shape != root.shape: + intermediate_shape = root.shape + assert buf.shape == shape, f"shape mismatch {buf.shape} != {shape}" - # reshape all the late ops into the output shape - # NOTE: these RESHAPEs will return self if they don't change the shape - for buf,src in real_srcs.items(): - if src is None: real_srcs[buf] = buf.reshape(intermediate_shape) - # NOTE: cast the type to remove the Optional - ast = op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer]], real_srcs)) - return LazyOp(MovementOps.RESHAPE, (ast, ), shape) if intermediate_shape != shape else ast + # reshape all the late ops into the output shape + # NOTE: these RESHAPEs will return self if they don't change the shape + for buf, src in real_srcs.items(): + if src is None: + real_srcs[buf] = buf.reshape(intermediate_shape) + # NOTE: cast the type to remove the Optional + ast = op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer]], real_srcs)) + return ( + LazyOp(MovementOps.RESHAPE, (ast,), shape) + if intermediate_shape != shape + else ast + ) + + +def _replace_bufferops(op: LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]: + replacements: Dict[LazyBuffer, LazyOp] = {} + base_bufs = dedup([x.base for x in op.buffers if not x.is_unrealized_const()]) + for x in op.buffers: + st = x.st.simplify().unbind() + if x.base in base_bufs: + replacements[x] = LazyOp( + BufferOps.LOAD, (), MemBuffer(base_bufs.index(x.base) + 1, x.dtype, st) + ) + elif not x.realized and x.base.op.op == LoadOps.CONST: + replacements[x] = LazyOp( + BufferOps.CONST, (), ConstBuffer(float(x.base.op.arg), x.dtype, st) + ) + else: + raise NotImplementedError(f"not handled {x}") + return ( + op.src[0] if op.op in {MovementOps.RESHAPE, LoadOps.CONTIGUOUS} else op + ).map_buffers(replacements), base_bufs -def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]: - replacements:Dict[LazyBuffer, LazyOp] = {} - base_bufs = dedup([x.base for x in op.buffers if not x.is_unrealized_const()]) - for x in op.buffers: - st = x.st.simplify().unbind() - if x.base in base_bufs: - replacements[x] = LazyOp(BufferOps.LOAD, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st)) - elif not x.realized and x.base.op.op == LoadOps.CONST: - replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(float(x.base.op.arg), x.dtype, st)) - else: - raise NotImplementedError(f"not handled {x}") - return (op.src[0] if op.op in {MovementOps.RESHAPE, LoadOps.CONTIGUOUS} else op).map_buffers(replacements), base_bufs # **** lazy operations **** -def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root -def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x) + +def get_movementroot(root: LazyBuffer, allow_contiguous=False) -> LazyBuffer: + return ( + get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) + if not root.realized + and ( + root.optype == MovementOps + or ( + root.op.op == LoadOps.CONTIGUOUS + and allow_contiguous + and root.op.src[0].st.contiguous + ) + ) + else root + ) + + +def get_movementroot_contiguous(x: LazyBuffer) -> LazyBuffer: + return ( + get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) + if not x.realized and x.op.op == LoadOps.CONTIGUOUS + else ( + get_movementroot(x, True) + if x.optype == MovementOps and x.st.contiguous + else x + ) + ) + # NOTE: this is the canonical order -def vars_from_ast(ast:LazyOp) -> List[Variable]: return sorted(set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()), key=lambda x: str(x.expr)) +def vars_from_ast(ast: LazyOp) -> List[Variable]: + return sorted( + set.union( + *[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set() + ), + key=lambda x: str(x.expr), + ) + lazycache: WeakValueDictionary = WeakValueDictionary() -def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None): - # rewrite 0 size into a CONST - if 0 in st.shape: return LazyBuffer(device, ShapeTracker.from_shape(st.shape), LoadOps, LazyOp(LoadOps.CONST, tuple(), 0.0), dtype) - # fromcpu aren't cached - if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.CUSTOM, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base) - # wop is the deduping key. i feel this used to compare more deeply - wop = (device, dtype, optype, ref(op), ref(base) if base else None) - if wop in lazycache: - for x in op.buffers: x.children.add(lazycache[wop]) - return lazycache[wop] +def create_lazybuffer( + device: str, + st: ShapeTracker, + optype: OpType, + op: LazyOp, + dtype: DType, + base: Optional[LazyBuffer] = None, +): + # rewrite 0 size into a CONST + if 0 in st.shape: + return LazyBuffer( + device, + ShapeTracker.from_shape(st.shape), + LoadOps, + LazyOp(LoadOps.CONST, tuple(), 0.0), + dtype, + ) + + # fromcpu aren't cached + if not LAZYCACHE or ( + optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.CUSTOM, LoadOps.CONST} + ): + return LazyBuffer(device, st, optype, op, dtype, base=base) + + # wop is the deduping key. i feel this used to compare more deeply + wop = (device, dtype, optype, ref(op), ref(base) if base else None) + if wop in lazycache: + for x in op.buffers: + x.children.add(lazycache[wop]) + return lazycache[wop] + + lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, base=base) + return ret - lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, base=base) - return ret class LazyBuffer: - __deletable__ = ('op',) - def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, src:Optional[Buffer]=None, base:Optional[LazyBuffer]=None): - self.device, self.st, self.shape, self.optype, self._dtype, self._realized = device, st, st.shape, optype, dtype, src - self.output_buffer: Optional[Buffer] = None # TODO: do we really need this? or can we just use realized - # TODO: does children have to be a ref count instead of a set? can a Buffer be a double child? - self.children: WeakSet[LazyBuffer] = WeakSet() - self.views: WeakSet[LazyBuffer] = WeakSet() - # NOTE: op should be read only after construction of LazyBuffer. it is now with schedule - if op is not None: - self.op = op - for x in op.buffers: x.children.add(self) - assert optype != MovementOps or (base is not None and base.optype != MovementOps), "MovementOps must be based" - self._base = base - if base: base.views.add(self) - else: assert st.contiguous, "unbased LazyBuffers must be contiguous" + __deletable__ = ("op",) - @property - def base(self): return self._base if self._base is not None else self + def __init__( + self, + device: str, + st: ShapeTracker, + optype: OpType, + op: Optional[LazyOp], + dtype: DType, + src: Optional[Buffer] = None, + base: Optional[LazyBuffer] = None, + ): + self.device, self.st, self.shape, self.optype, self._dtype, self._realized = ( + device, + st, + st.shape, + optype, + dtype, + src, + ) + self.output_buffer: Optional[ + Buffer + ] = None # TODO: do we really need this? or can we just use realized + # TODO: does children have to be a ref count instead of a set? can a Buffer be a double child? + self.children: WeakSet[LazyBuffer] = WeakSet() + self.views: WeakSet[LazyBuffer] = WeakSet() + # NOTE: op should be read only after construction of LazyBuffer. it is now with schedule + if op is not None: + self.op = op + for x in op.buffers: + x.children.add(self) + assert optype != MovementOps or ( + base is not None and base.optype != MovementOps + ), "MovementOps must be based" + self._base = base + if base: + base.views.add(self) + else: + assert st.contiguous, "unbased LazyBuffers must be contiguous" - def is_unrealized_const(self): return not self.realized and self.base.op.op == LoadOps.CONST - def is_unrealized_contiguous_const(self): return self.is_unrealized_const() and self.st.contiguous + @property + def base(self): + return self._base if self._base is not None else self - @property - def realized(self): return self.base._realized - @realized.setter - def realized(self, val:Buffer): - assert self._base is None, "no setting realized of based LazyBuffers" - self._realized = val - @property - def dtype(self): return self.base._dtype - @dtype.setter - def dtype(self, val:DType): - assert self._base is None, "no setting dtype of based LazyBuffers" - self._dtype = val + def is_unrealized_const(self): + return not self.realized and self.base.op.op == LoadOps.CONST - def __repr__(self): return f"" + def is_unrealized_contiguous_const(self): + return self.is_unrealized_const() and self.st.contiguous - def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {} + @property + def realized(self): + return self.base._realized - @property - def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,) - def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]): return real_srcs.get(self, self) - def get_lazyops(self) -> List[LazyOp]: return [] + @realized.setter + def realized(self, val: Buffer): + assert self._base is None, "no setting realized of based LazyBuffers" + self._realized = val - # *** scheduling *** + @property + def dtype(self): + return self.base._dtype - def schedule(self, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]: - if seen is None: seen = set() - if self in seen or self.realized or self.is_unrealized_const(): return [] - seen.add(self) - if self.base is not self: return self.base.schedule(seen) + @dtype.setter + def dtype(self, val: DType): + assert self._base is None, "no setting dtype of based LazyBuffers" + self._dtype = val - op = self.op - if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape) - elif self.optype is ReduceOps: op = _ast_reduceops(op) + def __repr__(self): + return f"" - # schedule the past - ret:List[ScheduleItem] = [] - for x in op.buffers: ret += x.schedule(seen) + def _device_extra_args(self) -> Dict[str, str]: + return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {} - var_vals = merge_dicts([self.st.var_vals] + [buf.st.var_vals for buf in op.buffers]) + @property + def buffers(self) -> Tuple[LazyBuffer, ...]: + return (self,) - op, base_bufs = _replace_bufferops(op) + def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]): + return real_srcs.get(self, self) - # check if we can reuse the output buffer - # if it's aliased, don't use it - # TODO: this is pretty wrong actually, who knows where else this buffer is used? - # TODO: what if an assign is required? this silently is wrong - # NOTE: this has been moved to schedule, as this is only an issue if buffers are already realized - if self.output_buffer is not None: - for i,a in enumerate(base_bufs): - # TODO: if this is contiguous it's fine - if a.realized == self.output_buffer: - if any(not x.arg.st.contiguous for x in op.get_lazyops() if x.op == BufferOps.LOAD and x.arg.idx == i+1): - self.output_buffer = None - break + def get_lazyops(self) -> List[LazyOp]: + return [] - if op.op not in LoadOps: - # add the store - info = get_lazyop_info(op) - assert info.dtype == self.dtype or isinstance(self.dtype, ImageDType), f"dtype mismatch {info.dtype=} != {self.dtype=}" + # *** scheduling *** - if isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())): - if DEBUG >= 3: print(f"forcing image {self.dtype} to float32") - self.dtype = dtypes.float32 # NOTE; this is what makes the dtype above not match - op = LazyOp(UnaryOps.CAST, (op, ), (dtypes.float32, False)) + def schedule(self, seen: Optional[Set[LazyBuffer]] = None) -> List[ScheduleItem]: + if seen is None: + seen = set() + if self in seen or self.realized or self.is_unrealized_const(): + return [] + seen.add(self) + if self.base is not self: + return self.base.schedule(seen) - # TODO: why doesn't this match? - #assert info.shape == self.shape, f"shape mismatch {info.shape=} != {self.shape=}" - op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, self.dtype, ShapeTracker.from_shape(info.shape))) - else: - # check loadop validity of bufferops - for i,s in enumerate(op.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}" + op = self.op + if self.optype is BinaryOps: + op = _ast_binaryops(op, self.shape) + elif self.optype is ReduceOps: + op = _ast_reduceops(op) - return ret + [ScheduleItem(op, self, tuple(base_bufs), {k:var_vals[k] for k in vars_from_ast(op)})] + # schedule the past + ret: List[ScheduleItem] = [] + for x in op.buffers: + ret += x.schedule(seen) - # *** creation/special ops *** + var_vals = merge_dicts( + [self.st.var_vals] + [buf.st.var_vals for buf in op.buffers] + ) - @staticmethod - def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Optional[LazyBuffer]=None) -> LazyBuffer: - return create_lazybuffer(device, ShapeTracker.from_shape(shape), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype) + op, base_bufs = _replace_bufferops(op) - # create a constant with the shape and dtype of self - def const(self, val:Union[float, int]) -> LazyBuffer: - # NOTE: dtypes.from_np(self.dtype.np) to deal with image types - return LazyBuffer.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape) + # check if we can reuse the output buffer + # if it's aliased, don't use it + # TODO: this is pretty wrong actually, who knows where else this buffer is used? + # TODO: what if an assign is required? this silently is wrong + # NOTE: this has been moved to schedule, as this is only an issue if buffers are already realized + if self.output_buffer is not None: + for i, a in enumerate(base_bufs): + # TODO: if this is contiguous it's fine + if a.realized == self.output_buffer: + if any( + not x.arg.st.contiguous + for x in op.get_lazyops() + if x.op == BufferOps.LOAD and x.arg.idx == i + 1 + ): + self.output_buffer = None + break - def copy_to_device(self, device:str) -> LazyBuffer: - # back off a FROM if it's a double FROM - if not self.realized and self.op.op == LoadOps.FROM and cast(LazyBuffer, self.op.src[0]).device == device: return cast(LazyBuffer, self.op.src[0]) - return LazyBuffer.loadop(LoadOps.FROM, self.shape, self.dtype, device, src=self.contiguous()) + if op.op not in LoadOps: + # add the store + info = get_lazyop_info(op) + assert info.dtype == self.dtype or isinstance( + self.dtype, ImageDType + ), f"dtype mismatch {info.dtype=} != {self.dtype=}" - def contiguous(self:LazyBuffer) -> LazyBuffer: - if not self.realized and self.op.op in LoadOps and self.op.op != LoadOps.CONST: return self # all LoadOps are already contiguous (except CONST) - if self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const(): - # this will turn into nothing, it's based and a copy - # TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops - return create_lazybuffer(self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, base=self.base) - return LazyBuffer.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self) + if isinstance(self.dtype, ImageDType) and ( + prod(self.shape) != prod(self.dtype.shape) + or not any(self.shape[x] % 4 == 0 for x in self.st.unit_stride_axes()) + ): + if DEBUG >= 3: + print(f"forcing image {self.dtype} to float32") + self.dtype = ( + dtypes.float32 + ) # NOTE; this is what makes the dtype above not match + op = LazyOp(UnaryOps.CAST, (op,), (dtypes.float32, False)) - @staticmethod - def fromCPU(x: np.ndarray) -> LazyBuffer: - return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), Buffer("CPU", prod(x.shape), dtypes.from_np(x.dtype), x.flatten())) + # TODO: why doesn't this match? + # assert info.shape == self.shape, f"shape mismatch {info.shape=} != {self.shape=}" + op = LazyOp( + BufferOps.STORE, + (op,), + MemBuffer(0, self.dtype, ShapeTracker.from_shape(info.shape)), + ) + else: + # check loadop validity of bufferops + for i, s in enumerate(op.src): + assert ( + isinstance(s, LazyOp) + and s.op == BufferOps.LOAD + and s.arg.idx == i + 1 + and s.arg.st.contiguous + ), f"bad LoadOps src {i}: {s}" - def cast(self, dtype:DType, bitcast:bool=False): - return self.e(UnaryOps.CAST, arg=(dtype, bitcast)) + return ret + [ + ScheduleItem( + op, self, tuple(base_bufs), {k: var_vals[k] for k in vars_from_ast(op)} + ) + ] - # *** elementwise ops *** + # *** creation/special ops *** - def e(self:LazyBuffer, op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer: - # srcs includes self - srcs = (self,)+srcs + @staticmethod + def loadop( + op, + shape: Tuple[sint, ...], + dtype: DType, + device: str, + arg=None, + src: Optional[LazyBuffer] = None, + ) -> LazyBuffer: + return create_lazybuffer( + device, + ShapeTracker.from_shape(shape), + LoadOps, + LazyOp(op, tuple() if src is None else (src,), arg), + dtype, + ) - # if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops - if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs) + # create a constant with the shape and dtype of self + def const(self, val: Union[float, int]) -> LazyBuffer: + # NOTE: dtypes.from_np(self.dtype.np) to deal with image types + return ( + LazyBuffer.loadop( + LoadOps.CONST, + tuple(), + dtypes.from_np(self.dtype.np), + self.device, + arg=val, + ) + .reshape((1,) * len(self.shape)) + .expand(self.shape) + ) - # get outputs now - out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(Tuple[DType, bool], arg)[0] + def copy_to_device(self, device: str) -> LazyBuffer: + # back off a FROM if it's a double FROM + if ( + not self.realized + and self.op.op == LoadOps.FROM + and cast(LazyBuffer, self.op.src[0]).device == device + ): + return cast(LazyBuffer, self.op.src[0]) + return LazyBuffer.loadop( + LoadOps.FROM, self.shape, self.dtype, device, src=self.contiguous() + ) - # push all contiguous to the end of BinaryOps - if PUSH_CONTIGUOUS and any(not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs): - new_srcs: List[LazyBuffer] = [] - for x in srcs: - if not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1: - x.op.src[0].children.discard(x) - x = cast(LazyBuffer, x.op.src[0]) + def contiguous(self: LazyBuffer) -> LazyBuffer: + if not self.realized and self.op.op in LoadOps and self.op.op != LoadOps.CONST: + return self # all LoadOps are already contiguous (except CONST) + if ( + self.st.contiguous + and self.st.size() == self.base.st.size() + and not self.is_unrealized_const() + ): + # this will turn into nothing, it's based and a copy + # TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops + return create_lazybuffer( + self.device, + ShapeTracker.from_shape(tuple(self.shape)), + LoadOps, + LazyOp(LoadOps.CONTIGUOUS, (self,), None), + self.dtype, + base=self.base, + ) + return LazyBuffer.loadop( + LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self + ) + + @staticmethod + def fromCPU(x: np.ndarray) -> LazyBuffer: + return LazyBuffer( + "CPU", + ShapeTracker.from_shape(x.shape), + LoadOps, + None, + dtypes.from_np(x.dtype), + Buffer("CPU", prod(x.shape), dtypes.from_np(x.dtype), x.flatten()), + ) + + def cast(self, dtype: DType, bitcast: bool = False): + return self.e(UnaryOps.CAST, arg=(dtype, bitcast)) + + # *** elementwise ops *** + + def e( + self: LazyBuffer, + op: Union[UnaryOps, BinaryOps, TernaryOps], + *srcs: LazyBuffer, + arg: Optional[Any] = None, + ) -> LazyBuffer: + # srcs includes self + srcs = (self,) + srcs + + # if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops + if SHUFFLE_MOVEMENT_OPS: + srcs = _push_movement_ops(srcs) + + # get outputs now + out_device, out_shape, out_dtype = ( + srcs[0].device, + srcs[0].shape, + max([x.dtype for x in srcs]) + if op != UnaryOps.CAST + else cast(Tuple[DType, bool], arg)[0], + ) + + # push all contiguous to the end of BinaryOps + if PUSH_CONTIGUOUS and any( + not x.realized + and x.op.op == LoadOps.CONTIGUOUS + and len(x.op.src[0].children) <= 1 + for x in srcs + ): + new_srcs: List[LazyBuffer] = [] + for x in srcs: + if ( + not x.realized + and x.op.op == LoadOps.CONTIGUOUS + and len(x.op.src[0].children) <= 1 + ): + x.op.src[0].children.discard(x) + x = cast(LazyBuffer, x.op.src[0]) + new_srcs.append(x) + return new_srcs[0].e(op, *new_srcs[1:], arg=arg).contiguous() + + if MERGE_ELEMENTWISE_OPS: + # remove the buffers from any (childless) BinaryOps that feed into this + _srcs = tuple( + [ + x.op + if x.optype == BinaryOps and not x.children and not x.realized + else x + for x in srcs + ] + ) + # TODO: needs general merge limiting + if ( + out_device != "WEBGPU" + or len( + dedup( + [ + x.base + for _src in _srcs + for x in _src.buffers + if not x.is_unrealized_const() + ] + ) + ) + < 7 + ): + srcs = _srcs # type: ignore + + return create_lazybuffer( + out_device, + ShapeTracker.from_shape(out_shape), + BinaryOps, + LazyOp(op, srcs, arg), + out_dtype, + ) + + # *** reduce ops *** + + def _reduce_op( + self: LazyBuffer, op: ReduceOps, new_shape: Tuple[sint, ...] + ) -> LazyBuffer: + if self.shape == tuple(new_shape): + return self + srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,) + unbound_new_shape = tuple( + s.unbind()[0] if not isinstance(s, int) else s for s in new_shape + ) + return create_lazybuffer( + self.device, + ShapeTracker.from_shape(new_shape), + ReduceOps, + LazyOp(op, srcs, unbound_new_shape), + self.dtype, + ) + + def r(self: LazyBuffer, op: ReduceOps, new_shape: Tuple[sint, ...]) -> LazyBuffer: + # TODO: can we split symbolic shape if the reduce axis is not symbolic? + if ( + not all_int(self.shape) + or (0 in self.shape) + or prod(self.shape) // prod(new_shape) + < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768) + ): + return self._reduce_op(op, new_shape) + heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old)) / (stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore + if divisor < 16 or heuristic < 0.1: + return self._reduce_op(op, new_shape) + + # choose largest divisor (>=16) to split on, penalize large strides + def splitted_shape(dim_aft_div): + return ( + self.shape[:dim_to_split] + + (self.shape[dim_to_split] // divisor,) + + dim_aft_div + + self.shape[dim_to_split + 1 :] + ) + + return ( + self.reshape(splitted_shape((divisor,))) + ._reduce_op(op, splitted_shape((1,))) + .reshape(splitted_shape(())) + ._reduce_op(op, new_shape) + ) + + # *** movement ops *** + + def reshape(self: LazyBuffer, arg: Tuple[sint, ...]) -> LazyBuffer: + if self.shape == arg: + return self + if not self.realized and self.op.op == MovementOps.RESHAPE: + assert isinstance(self.op.src[0], LazyBuffer) + self.op.src[0].children.discard( + self + ) # NOTE: this is only required in reshape and when pushing permutes, why?? + return self.op.src[0].reshape(arg) + return self._movement_op(self.st.reshape(arg), MovementOps.RESHAPE, arg) + + def pad(self: LazyBuffer, arg: Tuple[Tuple[int, int], ...]) -> LazyBuffer: + if all(b == 0 and e == 0 for b, e in arg): + return self + if not self.realized and self.op.op == MovementOps.PAD: + return self.op.src[0].pad( + tuple( + [(b1 + b2, e1 + e2) for (b1, e1), (b2, e2) in zip(self.op.arg, arg)] + ) + ) + return self._movement_op(self.st.pad(arg), MovementOps.PAD, arg) + + def expand(self: LazyBuffer, arg: Tuple[sint, ...]) -> LazyBuffer: + if self.shape == arg: + return self + if not self.realized and self.op.op == MovementOps.EXPAND: + return self.op.src[0].expand(arg) + return self._movement_op(self.st.expand(arg), MovementOps.EXPAND, arg) + + def permute(self: LazyBuffer, arg: Tuple[int, ...]) -> LazyBuffer: + if arg == tuple(range(len(self.shape))): + return self + if not self.realized and self.op.op == MovementOps.PERMUTE: + return self.op.src[0].permute(tuple([self.op.arg[i] for i in arg])) + if SHUFFLE_MOVEMENT_OPS and not self.realized: + if PUSH_PERMUTES and self.optype == ReduceOps: + # reduceops have one buffer input, permute it + narg = tuple([self.op.arg[a] for a in arg]) + src, rop = self.op.src[0], self.op.op + src.children.discard(self) + del self # TODO: why doesn't this delete remove it from the children + return src.permute(arg).r(cast(ReduceOps, rop), narg) + + # move permutes before expands (always, this is safe) + if self.op.op == MovementOps.EXPAND: + return ( + self.op.src[0] + .permute(arg) + .expand(tuple([self.op.arg[a] for a in arg])) + ) + + # move permutes before reshapes if we can + if ( + PUSH_PERMUTES + and self.op.op == MovementOps.RESHAPE + and isinstance(self.op.src[0], LazyBuffer) + ): + if shape_idx_groups := get_contraction( + self.op.src[0].shape, self.shape + ): + self.op.src[0].children.discard( + self + ) # NOTE: this is only required in reshape and when pushing permutes, why?? + return ( + self.op.src[0] + .permute(tuple(flatten(shape_idx_groups[i] for i in arg))) + .reshape(self.st.permute(arg).shape) + ) + return self._movement_op(self.st.permute(arg), MovementOps.PERMUTE, arg) + + def shrink(self: LazyBuffer, arg: Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: + if all(b - a == s for s, (a, b) in zip(self.shape, arg)): + return self + if not self.realized and self.op.op == MovementOps.SHRINK: + return self.op.src[0].shrink( + tuple( + [(b1 + b2, b1 + e2) for (b1, _), (b2, e2) in zip(self.op.arg, arg)] + ) + ) + return self._movement_op(self.st.shrink(arg), MovementOps.SHRINK, arg) + + def stride(self: LazyBuffer, arg: Tuple[int, ...]) -> LazyBuffer: + if all(a == 1 for a in arg): + return self + if not self.realized and self.op.op == MovementOps.STRIDE: + return self.op.src[0].stride( + tuple(a1 * a2 for a1, a2 in zip(arg, self.op.arg)) + ) + return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg) + + def _movement_op( + self, + st: ShapeTracker, + op: MovementOps, + arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]], + ) -> LazyBuffer: + if ( + SHUFFLE_MOVEMENT_OPS + and not self.realized + and self.optype == BinaryOps + and not self.children + ): + if op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or ( + op == MovementOps.RESHAPE and (self.op.op in UnaryOps or PUSH_RESHAPES) + ): + return self.op.replace_with_movement_ops([(op, arg)]) + if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous: + # MovementOps aren't stacked any more, they each have one parent, find the root + if ( + (root := get_movementroot(self)) != self + and root.st.contiguous + and prod(st.shape) == prod(root.shape) + ): + return root.reshape(st.shape) + return create_lazybuffer( + self.device, + st, + MovementOps, + LazyOp(op, (self,), arg), + self.dtype, + base=self.base, + ) + + def replace_with_movement_ops( + self: LazyBuffer, ops: List[Tuple[MovementOps, Any]] + ) -> LazyBuffer: + y = self + for op, arg in ops: + y = MOVEMENT_OPS_DISPATCHER[op](y, arg) + return y + + +UNSAFE_PAD_OPS = { + BinaryOps.DIV, + BinaryOps.CMPLT, + UnaryOps.LOG2, + UnaryOps.EXP2, + UnaryOps.RECIP, +} + + +def _push_movement_ops(srcs: Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]: + new_srcs = [] + for x in srcs: + mops: List[Tuple[MovementOps, Any]] = [] + bx = x + # backwalk all the movement ops. don't push PAD or EXPAND + while ( + not bx.realized + and bx.optype is MovementOps + and bx.op.op is not MovementOps.EXPAND + and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD) + and len(bx.children) <= 1 + ): + assert isinstance(bx.op.op, MovementOps) and isinstance( + bx.op.src[0], LazyBuffer + ) + mops.append((bx.op.op, bx.op.arg)) + bx = bx.op.src[0] + # NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0 + if ( + mops + and not bx.realized + and bx.optype is BinaryOps + and len(bx.children) <= 1 + and ( + all(y[0] is not MovementOps.PAD for y in mops) + or all(y.op not in UNSAFE_PAD_OPS for y in bx.op.get_lazyops()) + ) + ): + x = bx.op.replace_with_movement_ops(mops[::-1]) new_srcs.append(x) - return new_srcs[0].e(op, *new_srcs[1:], arg=arg).contiguous() + return tuple(new_srcs) - if MERGE_ELEMENTWISE_OPS: - # remove the buffers from any (childless) BinaryOps that feed into this - _srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) - # TODO: needs general merge limiting - if out_device != "WEBGPU" or len(dedup([x.base for _src in _srcs for x in _src.buffers if not x.is_unrealized_const()])) < 7: srcs = _srcs # type: ignore - - return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype) - - # *** reduce ops *** - - def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer: - if self.shape == tuple(new_shape): return self - srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,) - unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape) - return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype) - - def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer: - # TODO: can we split symbolic shape if the reduce axis is not symbolic? - if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape) - heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore - if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape) - # choose largest divisor (>=16) to split on, penalize large strides - def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:] - return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape) - - # *** movement ops *** - - def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer: - if self.shape == arg: return self - if not self.realized and self.op.op == MovementOps.RESHAPE: - assert isinstance(self.op.src[0], LazyBuffer) - self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why?? - return self.op.src[0].reshape(arg) - return self._movement_op(self.st.reshape(arg), MovementOps.RESHAPE, arg) - - def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer: - if all(b == 0 and e == 0 for b,e in arg): return self - if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)])) - return self._movement_op(self.st.pad(arg), MovementOps.PAD, arg) - - def expand(self: LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer: - if self.shape == arg: return self - if not self.realized and self.op.op == MovementOps.EXPAND: return self.op.src[0].expand(arg) - return self._movement_op(self.st.expand(arg), MovementOps.EXPAND, arg) - - def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: - if arg == tuple(range(len(self.shape))): return self - if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute(tuple([self.op.arg[i] for i in arg])) - if SHUFFLE_MOVEMENT_OPS and not self.realized: - if PUSH_PERMUTES and self.optype == ReduceOps: - # reduceops have one buffer input, permute it - narg = tuple([self.op.arg[a] for a in arg]) - src, rop = self.op.src[0], self.op.op - src.children.discard(self) - del self # TODO: why doesn't this delete remove it from the children - return src.permute(arg).r(cast(ReduceOps, rop), narg) - - # move permutes before expands (always, this is safe) - if self.op.op == MovementOps.EXPAND: - return self.op.src[0].permute(arg).expand(tuple([self.op.arg[a] for a in arg])) - - # move permutes before reshapes if we can - if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer): - if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape): - self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why?? - return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(self.st.permute(arg).shape) - return self._movement_op(self.st.permute(arg), MovementOps.PERMUTE, arg) - - def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: - if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self - if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)])) - return self._movement_op(self.st.shrink(arg), MovementOps.SHRINK, arg) - - def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: - if all(a == 1 for a in arg): return self - if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(a1*a2 for a1,a2 in zip(arg, self.op.arg))) - return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg) - - def _movement_op(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer: - if SHUFFLE_MOVEMENT_OPS and not self.realized and self.optype == BinaryOps and not self.children: - if op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and (self.op.op in UnaryOps or PUSH_RESHAPES)): - return self.op.replace_with_movement_ops([(op, arg)]) - if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous: - # MovementOps aren't stacked any more, they each have one parent, find the root - if (root:=get_movementroot(self)) != self and root.st.contiguous and prod(st.shape) == prod(root.shape): - return root.reshape(st.shape) - return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, base=self.base) - - def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer: - y = self - for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg) - return y - -UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP} - -def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]: - new_srcs = [] - for x in srcs: - mops: List[Tuple[MovementOps, Any]] = [] - bx = x - # backwalk all the movement ops. don't push PAD or EXPAND - while not bx.realized and bx.optype is MovementOps and bx.op.op is not MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD) and len(bx.children) <= 1: - assert isinstance(bx.op.op, MovementOps) and isinstance(bx.op.src[0], LazyBuffer) - mops.append((bx.op.op, bx.op.arg)) - bx = bx.op.src[0] - # NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0 - if mops and not bx.realized and bx.optype is BinaryOps and len(bx.children) <= 1 and (all(y[0] is not MovementOps.PAD for y in mops) or all(y.op not in UNSAFE_PAD_OPS for y in bx.op.get_lazyops())): - x = bx.op.replace_with_movement_ops(mops[::-1]) - new_srcs.append(x) - return tuple(new_srcs) MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = { - MovementOps.RESHAPE: LazyBuffer.reshape, MovementOps.EXPAND: LazyBuffer.expand, MovementOps.SHRINK: LazyBuffer.shrink, - MovementOps.PERMUTE: LazyBuffer.permute, MovementOps.PAD: LazyBuffer.pad, MovementOps.STRIDE: LazyBuffer.stride, + MovementOps.RESHAPE: LazyBuffer.reshape, + MovementOps.EXPAND: LazyBuffer.expand, + MovementOps.SHRINK: LazyBuffer.shrink, + MovementOps.PERMUTE: LazyBuffer.permute, + MovementOps.PAD: LazyBuffer.pad, + MovementOps.STRIDE: LazyBuffer.stride, } diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index ee0766fbe..b113efb64 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -6,206 +6,299 @@ from tinygrad.tensor import Function from tinygrad.lazy import LazyBuffer from tinygrad.shape.symbolic import sint + class Contiguous(Function): - def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous() - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output + def forward(self, x: LazyBuffer) -> LazyBuffer: + return x.contiguous() + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return grad_output + class ContiguousBackward(Function): - def forward(self, x:LazyBuffer) -> LazyBuffer: return x - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous() + def forward(self, x: LazyBuffer) -> LazyBuffer: + return x + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return grad_output.contiguous() + class Cast(Function): - def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer: - self.input_dtype, self.bitcast = x.dtype, bitcast - return x.cast(dtype, bitcast) + def forward(self, x: LazyBuffer, dtype: DType, bitcast: bool = False) -> LazyBuffer: + self.input_dtype, self.bitcast = x.dtype, bitcast + return x.cast(dtype, bitcast) + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return grad_output.cast(self.input_dtype, self.bitcast) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.cast(self.input_dtype, self.bitcast) # ************* unary ops ************* + class Zero(Function): - def forward(self, x:LazyBuffer) -> LazyBuffer: return x.const(0) - def backward(self, grad:LazyBuffer) -> LazyBuffer: return grad.const(0) + def forward(self, x: LazyBuffer) -> LazyBuffer: + return x.const(0) + + def backward(self, grad: LazyBuffer) -> LazyBuffer: + return grad.const(0) + class Neg(Function): - def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG) - def backward(self, grad:LazyBuffer) -> LazyBuffer: return grad.e(UnaryOps.NEG) + def forward(self, x: LazyBuffer) -> LazyBuffer: + return x.e(UnaryOps.NEG) + + def backward(self, grad: LazyBuffer) -> LazyBuffer: + return grad.e(UnaryOps.NEG) + class Sin(Function): - def forward(self, x:LazyBuffer) -> LazyBuffer: - self.x = x - return x.e(UnaryOps.SIN) + def forward(self, x: LazyBuffer) -> LazyBuffer: + self.x = x + return x.e(UnaryOps.SIN) + + def backward(self, grad: LazyBuffer) -> LazyBuffer: + return ( + self.x.const(math.pi / 2) + .e(BinaryOps.SUB, self.x) + .e(UnaryOps.SIN) + .e(BinaryOps.MUL, grad) + ) - def backward(self, grad:LazyBuffer) -> LazyBuffer: - return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad) # NOTE: maximum(x, 0) behaves differently where x=0 class Relu(Function): - def forward(self, x:LazyBuffer) -> LazyBuffer: - self.ret = x.e(BinaryOps.MAX, x.const(0)) - return self.ret + def forward(self, x: LazyBuffer) -> LazyBuffer: + self.ret = x.e(BinaryOps.MAX, x.const(0)) + return self.ret + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return ( + self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output) + ) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output) class Log(Function): - def forward(self, x:LazyBuffer) -> LazyBuffer: - self.x = x - return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2))) + def forward(self, x: LazyBuffer) -> LazyBuffer: + self.x = x + return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2))) + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return grad_output.e(BinaryOps.DIV, self.x) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.e(BinaryOps.DIV, self.x) class Exp(Function): - def forward(self, x:LazyBuffer) -> LazyBuffer: - self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2) - return self.ret + def forward(self, x: LazyBuffer) -> LazyBuffer: + self.ret = x.e(BinaryOps.MUL, x.const(1 / math.log(2))).e(UnaryOps.EXP2) + return self.ret + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return self.ret.e(BinaryOps.MUL, grad_output) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return self.ret.e(BinaryOps.MUL, grad_output) class Sqrt(Function): - def forward(self, x:LazyBuffer) -> LazyBuffer: - self.ret = x.e(UnaryOps.SQRT) - return self.ret + def forward(self, x: LazyBuffer) -> LazyBuffer: + self.ret = x.e(UnaryOps.SQRT) + return self.ret + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return grad_output.e( + BinaryOps.DIV, self.ret.e(BinaryOps.MUL, self.ret.const(2)) + ) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.e(BinaryOps.DIV, self.ret.e(BinaryOps.MUL, self.ret.const(2))) # NOTE: the implicit derivative of sigmoid is not stable # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e # TODO: have the backend automatically find this class Sigmoid(Function): - def forward(self, x:LazyBuffer) -> LazyBuffer: - self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2))) - return self.ret + def forward(self, x: LazyBuffer) -> LazyBuffer: + self.ret = x.const(1).e( + BinaryOps.DIV, + x.const(1).e( + BinaryOps.ADD, + x.e(BinaryOps.MUL, x.const(-1 / math.log(2))).e(UnaryOps.EXP2), + ), + ) + return self.ret + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return self.ret.e( + BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret) + ).e(BinaryOps.MUL, grad_output) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output) # ************* binary ops ************* + class Less(Function): - def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - return x.e(BinaryOps.CMPLT, y) + def forward(self, x: LazyBuffer, y: LazyBuffer) -> LazyBuffer: + return x.e(BinaryOps.CMPLT, y) + class Add(Function): - def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - return x.e(BinaryOps.ADD, y) + def forward(self, x: LazyBuffer, y: LazyBuffer) -> LazyBuffer: + return x.e(BinaryOps.ADD, y) + + def backward( + self, grad_output: LazyBuffer + ) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: + return ( + grad_output if self.needs_input_grad[0] else None, + grad_output if self.needs_input_grad[1] else None, + ) - def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: - return grad_output if self.needs_input_grad[0] else None, \ - grad_output if self.needs_input_grad[1] else None class Sub(Function): - def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - return x.e(BinaryOps.SUB, y) + def forward(self, x: LazyBuffer, y: LazyBuffer) -> LazyBuffer: + return x.e(BinaryOps.SUB, y) + + def backward( + self, grad_output: LazyBuffer + ) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: + return ( + grad_output if self.needs_input_grad[0] else None, + grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None, + ) - def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: - return grad_output if self.needs_input_grad[0] else None, \ - grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None class Mul(Function): - def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - self.x, self.y = x, y - return x.e(BinaryOps.MUL, y) + def forward(self, x: LazyBuffer, y: LazyBuffer) -> LazyBuffer: + self.x, self.y = x, y + return x.e(BinaryOps.MUL, y) + + def backward( + self, grad_output: LazyBuffer + ) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: + return ( + self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, + self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None, + ) - def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: - return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \ - self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None class Div(Function): - def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - self.x, self.y = x, y - return x.e(BinaryOps.DIV, y) + def forward(self, x: LazyBuffer, y: LazyBuffer) -> LazyBuffer: + self.x, self.y = x, y + return x.e(BinaryOps.DIV, y) + + def backward( + self, grad_output: LazyBuffer + ) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: + return ( + grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, + grad_output.e(UnaryOps.NEG) + .e(BinaryOps.MUL, self.x) + .e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) + if self.needs_input_grad[1] + else None, + ) - def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: - return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \ - grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # ************* ternary ops ************* -class Where(Function): - def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer: - self.x = x - return x.e(TernaryOps.WHERE, y, z) - def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]: - return None, \ - self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \ - self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None +class Where(Function): + def forward(self, x: LazyBuffer, y: LazyBuffer, z: LazyBuffer) -> LazyBuffer: + self.x = x + return x.e(TernaryOps.WHERE, y, z) + + def backward( + self, grad_output: LazyBuffer + ) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]: + return ( + None, + self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) + if self.needs_input_grad[1] + else None, + self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) + if self.needs_input_grad[2] + else None, + ) + # ************* reduce ops ************* -class Sum(Function): - def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer: - self.input_shape = x.shape - return x.r(ReduceOps.SUM, new_shape) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.expand(self.input_shape) +class Sum(Function): + def forward(self, x: LazyBuffer, new_shape: Tuple[int, ...]) -> LazyBuffer: + self.input_shape = x.shape + return x.r(ReduceOps.SUM, new_shape) + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return grad_output.expand(self.input_shape) + class Max(Function): - def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer: - self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape) - return self.ret + def forward(self, x: LazyBuffer, new_shape: Tuple[int, ...]) -> LazyBuffer: + self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape) + return self.ret + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + # 1s in locations where the max was chosen (can be two locations) + max_is_1s = self.x.const(1.0).e( + BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)) + ) + div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape) + return max_is_1s.e(BinaryOps.DIV, div).e( + BinaryOps.MUL, grad_output.expand(self.x.shape) + ) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - # 1s in locations where the max was chosen (can be two locations) - max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape))) - div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape) - return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape)) # ************* movement ops ************* + # NOTE: this is sum in reverse class Expand(Function): - def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer: - self.input_shape = x.shape - return x.expand(shape) + def forward(self, x: LazyBuffer, shape: Tuple[int, ...]) -> LazyBuffer: + self.input_shape = x.shape + return x.expand(shape) + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return grad_output.r(ReduceOps.SUM, self.input_shape) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.r(ReduceOps.SUM, self.input_shape) class Reshape(Function): - def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer: - self.input_shape = x.shape - return x.reshape(shape) + def forward(self, x: LazyBuffer, shape: Tuple[int, ...]) -> LazyBuffer: + self.input_shape = x.shape + return x.reshape(shape) + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return grad_output.reshape(self.input_shape) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.reshape(self.input_shape) class Permute(Function): - def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer: - self.input_order = order - return x.permute(order) + def forward(self, x: LazyBuffer, order: Tuple[int, ...]) -> LazyBuffer: + self.input_order = order + return x.permute(order) + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return grad_output.permute(argsort(self.input_order)) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.permute(argsort(self.input_order)) class Pad(Function): - def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer: - self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)]) - return x.pad(arg) + def forward(self, x: LazyBuffer, arg: Tuple[Tuple[int, int], ...]) -> LazyBuffer: + self.narg = tuple([(p[0], s + p[0]) for s, p in zip(x.shape, arg)]) + return x.pad(arg) + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return grad_output.shrink(self.narg) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.shrink(self.narg) class Shrink(Function): - def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: - self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)]) - return x.shrink(arg) + def forward(self, x: LazyBuffer, arg: Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: + self.narg = tuple([(p[0], s - p[1]) for s, p in zip(x.shape, arg)]) + return x.shrink(arg) + + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + assert all( + isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg + ), "symbolic shrink does not support backward" + # need this cast because mypy cannot narrow the type even with assert + return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg)) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - assert all(isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg), "symbolic shrink does not support backward" - # need this cast because mypy cannot narrow the type even with assert - return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg)) class Flip(Function): - def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer: - self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))]) - return x.stride(self.arg) + def forward(self, x: LazyBuffer, axis: Tuple[int, ...]) -> LazyBuffer: + self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))]) + return x.stride(self.arg) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.stride(self.arg) + def backward(self, grad_output: LazyBuffer) -> LazyBuffer: + return grad_output.stride(self.arg) diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index 79eda9681..b566434ad 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -4,126 +4,305 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import prod, all_int from tinygrad.nn import optim, state # noqa: F401 + class BatchNorm2d: - def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1): - self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum + def __init__( + self, sz: int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1 + ): + self.eps, self.track_running_stats, self.momentum = ( + eps, + track_running_stats, + momentum, + ) - if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz) - else: self.weight, self.bias = None, None + if affine: + self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz) + else: + self.weight, self.bias = None, None - self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False) - self.num_batches_tracked = Tensor.zeros(1, requires_grad=False) + self.running_mean, self.running_var = Tensor.zeros( + sz, requires_grad=False + ), Tensor.ones(sz, requires_grad=False) + self.num_batches_tracked = Tensor.zeros(1, requires_grad=False) - def __call__(self, x:Tensor): - if Tensor.training: - # This requires two full memory accesses to x - # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh - # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm - batch_mean = x.mean(axis=(0,2,3)) - y = (x - batch_mean.reshape(shape=[1, -1, 1, 1])) - batch_var = (y*y).mean(axis=(0,2,3)) - batch_invstd = batch_var.add(self.eps).pow(-0.5) + def __call__(self, x: Tensor): + if Tensor.training: + # This requires two full memory accesses to x + # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh + # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm + batch_mean = x.mean(axis=(0, 2, 3)) + y = x - batch_mean.reshape(shape=[1, -1, 1, 1]) + batch_var = (y * y).mean(axis=(0, 2, 3)) + batch_invstd = batch_var.add(self.eps).pow(-0.5) - # NOTE: wow, this is done all throughout training in most PyTorch models - if self.track_running_stats: - self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean.detach()) - self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape) - y.shape[1]) * batch_var.detach() ) - self.num_batches_tracked += 1 - else: - batch_mean = self.running_mean - # NOTE: this can be precomputed for static inference. we expand it here so it fuses - batch_invstd = self.running_var.reshape(1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt() + # NOTE: wow, this is done all throughout training in most PyTorch models + if self.track_running_stats: + self.running_mean.assign( + (1 - self.momentum) * self.running_mean + + self.momentum * batch_mean.detach() + ) + self.running_var.assign( + (1 - self.momentum) * self.running_var + + self.momentum + * prod(y.shape) + / (prod(y.shape) - y.shape[1]) + * batch_var.detach() + ) + self.num_batches_tracked += 1 + else: + batch_mean = self.running_mean + # NOTE: this can be precomputed for static inference. we expand it here so it fuses + batch_invstd = ( + self.running_var.reshape(1, -1, 1, 1) + .expand(x.shape) + .add(self.eps) + .rsqrt() + ) + + return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd) - return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd) # TODO: these Conv lines are terrible -def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): - return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias) +def Conv1d( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, +): + return Conv2d( + in_channels, + out_channels, + (kernel_size,), + stride, + padding, + dilation, + groups, + bias, + ) + class Conv2d: - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): - self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) - self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups - self.weight = self.initialize_weight(out_channels, in_channels, groups) - assert all_int(self.weight.shape), "does not support symbolic shape" - bound = 1 / math.sqrt(prod(self.weight.shape[1:])) - self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + self.kernel_size = ( + (kernel_size, kernel_size) + if isinstance(kernel_size, int) + else tuple(kernel_size) + ) + self.stride, self.padding, self.dilation, self.groups = ( + stride, + padding, + dilation, + groups, + ) + self.weight = self.initialize_weight(out_channels, in_channels, groups) + assert all_int(self.weight.shape), "does not support symbolic shape" + bound = 1 / math.sqrt(prod(self.weight.shape[1:])) + self.bias = ( + Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None + ) - def __call__(self, x:Tensor): - return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups) + def __call__(self, x: Tensor): + return x.conv2d( + self.weight, + self.bias, + padding=self.padding, + stride=self.stride, + dilation=self.dilation, + groups=self.groups, + ) - def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5)) + def initialize_weight(self, out_channels, in_channels, groups): + return Tensor.kaiming_uniform( + out_channels, in_channels // groups, *self.kernel_size, a=math.sqrt(5) + ) + + +def ConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + dilation=1, + groups=1, + bias=True, +): + return ConvTranspose2d( + in_channels, + out_channels, + (kernel_size,), + stride, + padding, + output_padding, + dilation, + groups, + bias, + ) -def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True): - return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias) class ConvTranspose2d(Conv2d): - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True): - super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) - self.output_padding = output_padding + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + dilation=1, + groups=1, + bias=True, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + self.output_padding = output_padding - def __call__(self, x:Tensor): - return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups) + def __call__(self, x: Tensor): + return x.conv_transpose2d( + self.weight, + self.bias, + padding=self.padding, + output_padding=self.output_padding, + stride=self.stride, + dilation=self.dilation, + groups=self.groups, + ) + + def initialize_weight(self, out_channels, in_channels, groups): + return Tensor.kaiming_uniform( + in_channels, out_channels // groups, *self.kernel_size, a=math.sqrt(5) + ) - def initialize_weight(self, out_channels, in_channels, groups): return Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5)) class Linear: - def __init__(self, in_features, out_features, bias=True): - self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5)) - # TODO: remove this once we can represent Tensor with int shape in typing - assert isinstance(self.weight.shape[1], int), "does not support symbolic shape" - bound = 1 / math.sqrt(self.weight.shape[1]) - self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None + def __init__(self, in_features, out_features, bias=True): + self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5)) + # TODO: remove this once we can represent Tensor with int shape in typing + assert isinstance(self.weight.shape[1], int), "does not support symbolic shape" + bound = 1 / math.sqrt(self.weight.shape[1]) + self.bias = ( + Tensor.uniform(out_features, low=-bound, high=bound) if bias else None + ) + + def __call__(self, x: Tensor): + return x.linear(self.weight.transpose(), self.bias) - def __call__(self, x:Tensor): - return x.linear(self.weight.transpose(), self.bias) class GroupNorm: - def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True): - self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps - self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None - self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None + def __init__( + self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True + ): + self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps + self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None + self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None - def __call__(self, x:Tensor): - # reshape for layernorm to work as group norm - # subtract mean and divide stddev - x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape) + def __call__(self, x: Tensor): + # reshape for layernorm to work as group norm + # subtract mean and divide stddev + x = ( + x.reshape(x.shape[0], self.num_groups, -1) + .layernorm(eps=self.eps) + .reshape(x.shape) + ) + + if self.weight is None or self.bias is None: + return x + # elementwise_affine on channels + return x * self.weight.reshape( + 1, -1, *[1] * (len(x.shape) - 2) + ) + self.bias.reshape(1, -1, *[1] * (len(x.shape) - 2)) - if self.weight is None or self.bias is None: return x - # elementwise_affine on channels - return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2)) class InstanceNorm: - def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True): - self.num_features, self.eps = num_features, eps - self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None - self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None + def __init__(self, num_features: int, eps: float = 1e-5, affine: bool = True): + self.num_features, self.eps = num_features, eps + self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None + self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None + + def __call__(self, x: Tensor): + x = ( + x.reshape(x.shape[0], self.num_features, -1) + .layernorm(eps=self.eps) + .reshape(x.shape) + ) + if self.weight is None or self.bias is None: + return x + return x * self.weight.reshape( + 1, -1, *[1] * (len(x.shape) - 2) + ) + self.bias.reshape(1, -1, *[1] * (len(x.shape) - 2)) - def __call__(self, x:Tensor): - x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape) - if self.weight is None or self.bias is None: return x - return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2)) class LayerNorm: - def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True): - self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape) - self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine - self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None) + def __init__( + self, + normalized_shape: Union[int, Tuple[int, ...]], + eps: float = 1e-5, + elementwise_affine: bool = True, + ): + self.normalized_shape = ( + (normalized_shape,) + if isinstance(normalized_shape, int) + else tuple(normalized_shape) + ) + self.axis, self.eps, self.elementwise_affine = ( + tuple(-1 - i for i in range(len(self.normalized_shape))), + eps, + elementwise_affine, + ) + self.weight, self.bias = ( + (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) + if elementwise_affine + else (None, None) + ) + + def __call__(self, x: Tensor): + assert ( + self.normalized_shape == x.shape[-len(self.normalized_shape) :] + ), f"last dimensions of {x.shape} must match {self.normalized_shape}" + x = x.layernorm(eps=self.eps, axis=self.axis) + if not self.elementwise_affine: + return x + return x * self.weight + self.bias - def __call__(self, x:Tensor): - assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}" - x = x.layernorm(eps=self.eps, axis=self.axis) - if not self.elementwise_affine: return x - return x * self.weight + self.bias class LayerNorm2d(LayerNorm): - def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + def __call__(self, x): + return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + class Embedding: - def __init__(self, vocab_size:int, embed_size:int): - self.vocab_size = vocab_size - self.weight = Tensor.glorot_uniform(vocab_size, embed_size) + def __init__(self, vocab_size: int, embed_size: int): + self.vocab_size = vocab_size + self.weight = Tensor.glorot_uniform(vocab_size, embed_size) - def __call__(self, idx:Tensor) -> Tensor: - if not hasattr(self, 'vocab_counter'): self.vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size) - return (self.vocab_counter == idx.unsqueeze(2)).expand(*idx.shape, self.vocab_size) @ self.weight + def __call__(self, idx: Tensor) -> Tensor: + if not hasattr(self, "vocab_counter"): + self.vocab_counter = Tensor.arange( + self.vocab_size, requires_grad=False + ).reshape(1, 1, self.vocab_size) + return (self.vocab_counter == idx.unsqueeze(2)).expand( + *idx.shape, self.vocab_size + ) @ self.weight diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py index 52acb4627..11a3c63bc 100644 --- a/tinygrad/nn/optim.py +++ b/tinygrad/nn/optim.py @@ -3,68 +3,122 @@ from typing import List from tinygrad.helpers import dedup from tinygrad.tensor import Tensor + class Optimizer: - def __init__(self, params: List[Tensor], lr: float): - # if it's None, but being put into an optimizer, set it to True - for x in params: - if x.requires_grad is None: x.requires_grad = True + def __init__(self, params: List[Tensor], lr: float): + # if it's None, but being put into an optimizer, set it to True + for x in params: + if x.requires_grad is None: + x.requires_grad = True - self.params: List[Tensor] = dedup([x for x in params if x.requires_grad]) - assert len(self.params) != 0, "optimizer must have at least one param" - self.device = self.params[0].device - self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized - self.lr = Tensor([lr], requires_grad=False, device=self.device).contiguous() + self.params: List[Tensor] = dedup([x for x in params if x.requires_grad]) + assert len(self.params) != 0, "optimizer must have at least one param" + self.device = self.params[0].device + self.buffers: List[Tensor] = dedup( + [x for x in params if not x.requires_grad] + ) # buffers are still realized + self.lr = Tensor([lr], requires_grad=False, device=self.device).contiguous() - def zero_grad(self): - for param in self.params: param.grad = None + def zero_grad(self): + for param in self.params: + param.grad = None + + def realize(self, extra=None): + # NOTE: in extra is too late for most of the params due to issues with assign + Tensor.corealize( + extra + self.params + self.buffers + if extra is not None + else self.params + self.buffers + ) - def realize(self, extra=None): - # NOTE: in extra is too late for most of the params due to issues with assign - Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers) class SGD(Optimizer): - def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False): - super().__init__(params, lr) - self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov - self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else [] + def __init__( + self, + params: List[Tensor], + lr=0.001, + momentum=0, + weight_decay=0.0, + nesterov=False, + ): + super().__init__(params, lr) + self.momentum, self.wd, self.nesterov = momentum, weight_decay, nesterov + self.b = ( + [ + Tensor.zeros(*t.shape, device=t.device, requires_grad=False) + for t in self.params + ] + if self.momentum + else [] + ) + + # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html + def step(self) -> None: + for i, t in enumerate(self.params): + assert t.grad is not None + g = t.grad.realize() + self.wd * t.detach() + if self.momentum: + self.b[i].assign( + self.momentum * self.b[i] + g + ).realize() # NOTE: self.b[i] is zero on the first run, no if required + g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i] + t.assign(t.detach() - g * self.lr) + self.realize(self.b) - # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html - def step(self) -> None: - for i, t in enumerate(self.params): - assert t.grad is not None - g = t.grad.realize() + self.wd * t.detach() - if self.momentum: - self.b[i].assign(self.momentum * self.b[i] + g).realize() # NOTE: self.b[i] is zero on the first run, no if required - g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i] - t.assign(t.detach() - g * self.lr) - self.realize(self.b) # LAMB is essentially just the trust ratio part of LARS applied to Adam/W so if we just set the trust ratio to 1.0 its just Adam/W. -def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01): return LAMB(params, lr, b1, b2, eps, wd, adam=True) -def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAMB(params, lr, b1, b2, eps, 0.0, adam=True) +def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01): + return LAMB(params, lr, b1, b2, eps, wd, adam=True) + + +def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): + return LAMB(params, lr, b1, b2, eps, 0.0, adam=True) + class LAMB(Optimizer): - def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False): - super().__init__(params, lr) - self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], requires_grad=False).realize() - self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] - self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] + def __init__( + self, + params: List[Tensor], + lr=0.001, + b1=0.9, + b2=0.999, + eps=1e-6, + wd=0.0, + adam=False, + ): + super().__init__(params, lr) + self.b1, self.b2, self.eps, self.wd, self.adam, self.t = ( + b1, + b2, + eps, + wd, + adam, + Tensor([0], requires_grad=False).realize(), + ) + self.m = [ + Tensor.zeros(*t.shape, device=t.device, requires_grad=False) + for t in self.params + ] + self.v = [ + Tensor.zeros(*t.shape, device=t.device, requires_grad=False) + for t in self.params + ] - def step(self) -> None: - self.t.assign(self.t + 1).realize() - for i, t in enumerate(self.params): - assert t.grad is not None - g = t.grad.realize() - self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize() - self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize() - m_hat = self.m[i] / (1.0 - self.b1**self.t) - v_hat = self.v[i] / (1.0 - self.b2**self.t) - up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach() - if not self.adam: - r1 = t.detach().square().sum().sqrt() - r2 = up.square().sum().sqrt() - r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0) - else: - r = 1.0 - t.assign(t.detach() - self.lr * r * up) - self.realize([self.t] + self.m + self.v) + def step(self) -> None: + self.t.assign(self.t + 1).realize() + for i, t in enumerate(self.params): + assert t.grad is not None + g = t.grad.realize() + self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize() + self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize() + m_hat = self.m[i] / (1.0 - self.b1**self.t) + v_hat = self.v[i] / (1.0 - self.b2**self.t) + up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach() + if not self.adam: + r1 = t.detach().square().sum().sqrt() + r2 = up.square().sum().sqrt() + r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0) + else: + r = 1.0 + t.assign(t.detach() - self.lr * r * up) + self.realize([self.t] + self.m + self.v) diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index a1adf348a..95cfab4cb 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -2,142 +2,288 @@ import os, json, pathlib, zipfile, pickle, tarfile, struct from tqdm import tqdm from typing import Dict, Union, List, Optional, Any, Tuple from tinygrad.tensor import Tensor -from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI, unwrap +from tinygrad.helpers import ( + dtypes, + prod, + argsort, + DEBUG, + Timing, + GlobalCounters, + CI, + unwrap, +) from tinygrad.shape.view import strides_for_shape from tinygrad import Device -safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64} -inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()} +safe_dtypes = { + "F16": dtypes.float16, + "F32": dtypes.float32, + "U8": dtypes.uint8, + "I8": dtypes.int8, + "I32": dtypes.int32, + "I64": dtypes.int64, +} +inverse_safe_dtypes = {v: k for k, v in safe_dtypes.items()} -def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]: - t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}") - json_len = t[0:1].cast(dtypes.int64).numpy()[0] - return (t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())) -def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]: - t, json_len, metadata = safe_load_metadata(fn) - return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"} +def safe_load_metadata(fn: Union[Tensor, str]) -> Tuple[Tensor, int, Any]: + t = ( + fn + if isinstance(fn, Tensor) + else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}") + ) + json_len = t[0:1].cast(dtypes.int64).numpy()[0] + return (t, json_len, json.loads(t[8 : 8 + json_len].numpy().tobytes())) + + +def safe_load(fn: Union[Tensor, str]) -> Dict[str, Tensor]: + t, json_len, metadata = safe_load_metadata(fn) + return { + k: t[8 + json_len + v["data_offsets"][0] :] + .cast(safe_dtypes[v["dtype"]])[: prod(v["shape"])] + .reshape(v["shape"]) + for k, v in metadata.items() + if k != "__metadata__" + } + + +def safe_save( + tensors: Dict[str, Tensor], fn: str, metadata: Optional[Dict[str, Any]] = None +): + headers, offset = {}, 0 + if metadata: + headers["__metadata__"] = metadata + for k, v in tensors.items(): + headers[k] = { + "dtype": inverse_safe_dtypes[v.dtype], + "shape": list(v.shape), + "data_offsets": [offset, offset + v.nbytes()], + } + offset += v.nbytes() + j = json.dumps(headers, separators=(",", ":")) + j += "\x20" * ((8 - len(j) % 8) % 8) + pathlib.Path(fn).unlink(missing_ok=True) + t = Tensor.empty(8 + len(j) + offset, dtype=dtypes.uint8, device=f"disk:{fn}") + t[0:1].cast(dtypes.int64).assign([len(j)]) + t[8 : 8 + len(j)].assign( + Tensor(list(j.encode("utf-8")), dtype=dtypes.uint8, device="cpu") + ) + for k, v in safe_load(t).items(): + v.assign(tensors[k]) -def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None): - headers, offset = {}, 0 - if metadata: headers['__metadata__'] = metadata - for k,v in tensors.items(): - headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]} - offset += v.nbytes() - j = json.dumps(headers, separators=(',', ':')) - j += "\x20"*((8-len(j)%8)%8) - pathlib.Path(fn).unlink(missing_ok=True) - t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}") - t[0:1].cast(dtypes.int64).assign([len(j)]) - t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8, device="cpu")) - for k,v in safe_load(t).items(): v.assign(tensors[k]) # state dict from collections import OrderedDict -def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]: - if isinstance(obj, tensor_type): return {prefix.strip('.'):obj} - if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple - if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type) - if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type) - state_dict = {} - if isinstance(obj, (list, tuple)): - for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type)) - elif isinstance(obj, dict): - for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type)) - return state_dict -def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values()) + + +def get_state_dict(obj, prefix: str = "", tensor_type=Tensor) -> Dict[str, Tensor]: + if isinstance(obj, tensor_type): + return {prefix.strip("."): obj} + if hasattr(obj, "_asdict"): + return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple + if isinstance(obj, OrderedDict): + return get_state_dict(dict(obj), prefix, tensor_type) + if hasattr(obj, "__dict__"): + return get_state_dict(obj.__dict__, prefix, tensor_type) + state_dict = {} + if isinstance(obj, (list, tuple)): + for i, x in enumerate(obj): + state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type)) + elif isinstance(obj, dict): + for k, v in obj.items(): + state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type)) + return state_dict + + +def get_parameters(obj) -> List[Tensor]: + return list(get_state_dict(obj).values()) + def load_state_dict(model, state_dict, strict=True, verbose=True): - with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"): - model_state_dict = get_state_dict(model) - if DEBUG >= 1 and len(state_dict) > len(model_state_dict): print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys()))) - for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)): - t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}") - if k not in state_dict and not strict: - if DEBUG >= 1: print(f"WARNING: not loading {k}") - continue - v.assign(state_dict[k].to(v.device)).realize() + with Timing( + "loaded weights in ", + lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s", + ): + model_state_dict = get_state_dict(model) + if DEBUG >= 1 and len(state_dict) > len(model_state_dict): + print( + "WARNING: unused weights in state_dict", + sorted(list(state_dict.keys() - model_state_dict.keys())), + ) + for k, v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)): + t.set_description( + f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}" + ) + if k not in state_dict and not strict: + if DEBUG >= 1: + print(f"WARNING: not loading {k}") + continue + v.assign(state_dict[k].to(v.device)).realize() + # torch support! -def torch_load(fn:str): - t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}") - offsets: Dict[Union[str, int], int] = {} - lens: Dict[Union[str, int], int] = {} - def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None): - #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata) - lens[storage[2]] = storage[4] * storage[1].itemsize - if storage[2] not in offsets: return None - byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize - ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1]) - # convert bfloat16 -> float16 using LLVM for Llama 2 - # upstream LLaMA also does this conversion: - # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95 - # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support - if storage[1] == dtypes.bfloat16: - ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).to(Device.DEFAULT).half() - #ret = ret.to("LLVM").half().to(Device.DEFAULT) +def torch_load(fn: str): + t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}") - # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk - shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1] - permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])] - if tuple(permute_indexes) != tuple(range(len(permute_indexes))): - intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)]) - assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides" - if DEBUG >= 3: print(f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}") - # TODO: find a nice way to support all shapetracker on disktensors - ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes) + offsets: Dict[Union[str, int], int] = {} + lens: Dict[Union[str, int], int] = {} - return ret.reshape(size) + def _rebuild_tensor_v2( + storage, + storage_offset, + size, + stride, + requires_grad=None, + backward_hooks=None, + metadata=None, + ): + # print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata) + lens[storage[2]] = storage[4] * storage[1].itemsize + if storage[2] not in offsets: + return None + byte_offset = offsets[storage[2]] + storage_offset * storage[1].itemsize + ret = t[byte_offset : byte_offset + prod(size)].cast(storage[1]) + # convert bfloat16 -> float16 using LLVM for Llama 2 + # upstream LLaMA also does this conversion: + # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95 + # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support + if storage[1] == dtypes.bfloat16: + ret = ( + ret.bitcast(dtypes.uint16) + .to("CPU") + .cast(dtypes.uint32) + .mul(1 << 16) + .bitcast(dtypes.float32) + .to(Device.DEFAULT) + .half() + ) + # ret = ret.to("LLVM").half().to(Device.DEFAULT) - class Parameter: - def __setstate__(self, state): self.tensor = state[0] + # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk + shape_strides = [(s, st) for s, st in zip(size, stride) if s != 1] + permute_indexes = [ + len(shape_strides) - 1 - y for y in argsort([x[1] for x in shape_strides]) + ] + if tuple(permute_indexes) != tuple(range(len(permute_indexes))): + intermediate_shape = tuple( + [shape_strides[x][0] for x in argsort(permute_indexes)] + ) + assert tuple( + [shape_strides[i][1] for i in argsort(permute_indexes)] + ) == strides_for_shape(intermediate_shape), "nonpermutable strides" + if DEBUG >= 3: + print( + f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}" + ) + # TODO: find a nice way to support all shapetracker on disktensors + ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes) - deserialized_objects: Dict[str, Any] = {} - intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, - "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter} - whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed - class Dummy: pass - class TorchPickle(pickle.Unpickler): - def find_class(self, module, name): - module_root = module.split(".")[0] - if module_root not in whitelist: - if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}") - return Dummy - return intercept[name] if module_root == "torch" else super().find_class(module, name) - def persistent_load(self, pid): return deserialized_objects[pid] if pid in deserialized_objects else pid + return ret.reshape(size) - if tuple(t[0:2].numpy()) == (0x50, 0x4b): - myzip = zipfile.ZipFile(fn, 'r') - base_name = myzip.namelist()[0].split('/', 1)[0] - for n in myzip.namelist(): - if n.startswith(f'{base_name}/data/'): - with myzip.open(n) as myfile: - offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore - with myzip.open(f'{base_name}/data.pkl') as myfile: - return TorchPickle(myfile).load() - elif bytes(t[0:0xe].numpy()) == b"././@PaxHeader": # TODO: is this how you detect a tarfile? - with tarfile.open(fn, "r") as tar: - storages_offset = tar.getmember('storages').offset_data - f = unwrap(tar.extractfile('storages')) - for i in range(TorchPickle(f).load()): # num_storages - (key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('= 2: + print(f"WARNING: returning Dummy for {module} {name}") + return Dummy + return ( + intercept[name] + if module_root == "torch" + else super().find_class(module, name) + ) + + def persistent_load(self, pid): + return deserialized_objects[pid] if pid in deserialized_objects else pid + + if tuple(t[0:2].numpy()) == (0x50, 0x4B): + myzip = zipfile.ZipFile(fn, "r") + base_name = myzip.namelist()[0].split("/", 1)[0] + for n in myzip.namelist(): + if n.startswith(f"{base_name}/data/"): + with myzip.open(n) as myfile: + offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore + with myzip.open(f"{base_name}/data.pkl") as myfile: + return TorchPickle(myfile).load() + elif ( + bytes(t[0:0xE].numpy()) == b"././@PaxHeader" + ): # TODO: is this how you detect a tarfile? + with tarfile.open(fn, "r") as tar: + storages_offset = tar.getmember("storages").offset_data + f = unwrap(tar.extractfile("storages")) + for i in range(TorchPickle(f).load()): # num_storages + (key, _, storage_type), sz = ( + TorchPickle(f).load(), + struct.unpack(" Tuple[LazyBuffer, ...]: return tuple(dedup(sum([x.buffers for x in self.src], ()))) - @functools.cached_property - def hash(self): return hash((self.op,self.src, self.arg)) - def __hash__(self): return self.hash + op: Op + src: Tuple[Union[LazyOp, LazyBuffer], ...] + arg: Any = None - def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) if y not in real_srcs else real_srcs[y] for y in self.src]), self.arg) - def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()] + def __repr__(self): + return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})" - def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer': - assert isinstance(self.op, (UnaryOps, BinaryOps, TernaryOps)) - srcs = [z.replace_with_movement_ops(ops) for z in self.src] - return srcs[0].e(self.op, *srcs[1:], arg=self.arg) + @functools.cached_property + def buffers(self) -> Tuple[LazyBuffer, ...]: + return tuple(dedup(sum([x.buffers for x in self.src], ()))) - @property - def st(self): raise NotImplementedError - @property - def realized(self): raise NotImplementedError - @property - def children(self): raise NotImplementedError + @functools.cached_property + def hash(self): + return hash((self.op, self.src, self.arg)) + + def __hash__(self): + return self.hash + + def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]) -> LazyOp: + return LazyOp( + self.op, + tuple( + [ + y.map_buffers(real_srcs) if y not in real_srcs else real_srcs[y] + for y in self.src + ] + ), + self.arg, + ) + + def get_lazyops(self) -> List[LazyOp]: + return [self] + [item for x in self.src for item in x.get_lazyops()] + + def replace_with_movement_ops( + self: LazyOp, ops: List[Tuple[MovementOps, Tuple[Any, ...]]] + ) -> "LazyBuffer": + assert isinstance(self.op, (UnaryOps, BinaryOps, TernaryOps)) + srcs = [z.replace_with_movement_ops(ops) for z in self.src] + return srcs[0].e(self.op, *srcs[1:], arg=self.arg) + + @property + def st(self): + raise NotImplementedError + + @property + def realized(self): + raise NotImplementedError + + @property + def children(self): + raise NotImplementedError + + # movement ops + def reshape(self, _): + raise NotImplementedError + + def pad(self, _): + raise NotImplementedError + + def expand(self, _): + raise NotImplementedError + + def permute(self, _): + raise NotImplementedError + + def shrink(self, _): + raise NotImplementedError + + def stride(self, _): + raise NotImplementedError - # movement ops - def reshape(self, _): raise NotImplementedError - def pad(self, _): raise NotImplementedError - def expand(self, _): raise NotImplementedError - def permute(self, _): raise NotImplementedError - def shrink(self, _): raise NotImplementedError - def stride(self, _): raise NotImplementedError # **************** independent FlopCounter **************** + @dataclass class FlopCounter: - shape: Tuple[int, ...] - dtype: DType - flops: int - mem: Dict[int, int] - @property - def mem_estimate(self): return sum(self.mem.values()) - def consume_flops(self): - self.flops, ret = 0, self.flops - return ret + shape: Tuple[int, ...] + dtype: DType + flops: int + mem: Dict[int, int] + + @property + def mem_estimate(self): + return sum(self.mem.values()) + + def consume_flops(self): + self.flops, ret = 0, self.flops + return ret + InterpretedFlopCounter: Dict[Op, Callable] = { - BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}), BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}), - BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.size()}), UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops - **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, - **{op:lambda self,y: FlopCounter(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, - **{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, - TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, max(y.dtype, z.dtype), self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} + BufferOps.LOAD: lambda arg: FlopCounter( + arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize * arg.st.size()} + ), + BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}), + BufferOps.STORE: lambda self, arg: FlopCounter( + arg.st.shape, + arg.dtype, + self.consume_flops(), + {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.size()}, + ), + UnaryOps.CAST: lambda self, arg: FlopCounter( + self.shape, arg[0], self.consume_flops(), self.mem + ), # cast uses no flops + **{ + op: lambda self: FlopCounter( + self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem + ) + for op in UnaryOps + if op != UnaryOps.CAST + }, + **{ + op: lambda self, y: FlopCounter( + self.shape, + max(self.dtype, y.dtype), + self.consume_flops() + y.consume_flops() + prod(self.shape), + {**self.mem, **y.mem}, + ) + for op in BinaryOps + }, + **{ + op: lambda self, new_shape: FlopCounter( + new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem + ) + for op in ReduceOps + }, + TernaryOps.WHERE: lambda self, y, z: FlopCounter( + self.shape, + max(y.dtype, z.dtype), + self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), + {**self.mem, **y.mem, **z.mem}, + ), +} + @functools.lru_cache(None) -def get_lazyop_info(ast:LazyOp) -> FlopCounter: - @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs - def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else []))) - return run_ast(ast) +def get_lazyop_info(ast: LazyOp) -> FlopCounter: + @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs + def run_ast(ast): + return InterpretedFlopCounter[ast.op]( + *( + [run_ast(x) for x in ast.src] + + ([ast.arg] if ast.arg is not None else []) + ) + ) + + return run_ast(ast) diff --git a/tinygrad/realize.py b/tinygrad/realize.py index e43ac139a..b0c96f3a9 100644 --- a/tinygrad/realize.py +++ b/tinygrad/realize.py @@ -5,33 +5,61 @@ from tinygrad.graph import log_schedule_item, print_tree from tinygrad.helpers import prod from tinygrad.shape.symbolic import Variable + class CustomOp(JITRunner): - def __init__(self, fxn): - self.fxn = fxn - super().__init__() - def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs) + def __init__(self, fxn): + self.fxn = fxn + super().__init__() -def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]: - assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.FROM, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}" - if si.ast.op is LoadOps.EMPTY: return None - if si.ast.op is LoadOps.FROM: return BufferCopy - if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg) - return Device[si.out.device].get_runner(si.ast) + def __call__( + self, + rawbufs: List[Buffer], + var_vals: Dict[Variable, int], + wait=False, + jit=False, + ): + self.fxn(*rawbufs) -def run_schedule(schedule:List[ScheduleItem], disable_logging=False): - while len(schedule): - si = schedule.pop(0) - if not disable_logging: log_schedule_item(si) - assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized" - # get the program - prg = lower_schedule_item(si) +def lower_schedule_item(si: ScheduleItem) -> Optional[JITRunner]: + assert ( + all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.FROM + ), f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}" + if si.ast.op is LoadOps.EMPTY: + return None + if si.ast.op is LoadOps.FROM: + return BufferCopy + if si.ast.op is LoadOps.CUSTOM: + return CustomOp(si.ast.arg) + return Device[si.out.device].get_runner(si.ast) - # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape - si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \ - Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype) - del si.out.op - for v in si.out.views: del v.op - # run the function (put it in JIT) - if prg: prg.exec([si.out.realized] + [x.realized for x in si.inputs], si.var_vals) +def run_schedule(schedule: List[ScheduleItem], disable_logging=False): + while len(schedule): + si = schedule.pop(0) + if not disable_logging: + log_schedule_item(si) + assert all( + x.realized for x in si.inputs + ), "can't run schedule, some inputs aren't realized" + + # get the program + prg = lower_schedule_item(si) + + # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape + si.out.realized = ( + si.out.output_buffer + if si.out.output_buffer is not None + else Buffer( + si.out.device, + prod((s if isinstance(s, int) else s.max for s in si.out.shape)), + si.out.dtype, + ) + ) + del si.out.op + for v in si.out.views: + del v.op + + # run the function (put it in JIT) + if prg: + prg.exec([si.out.realized] + [x.realized for x in si.inputs], si.var_vals) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index a89aba760..9f1151acf 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -5,260 +5,436 @@ from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps from tinygrad.helpers import ImageDType, dtypes, prod, DType, strip_parens + class CStyleLanguage(NamedTuple): - size_prefix: str = "int" - generic_var_prefix: str = "" - kernel_prefix: str = "" - buffer_prefix: str = "" - buffer_suffix: str = "" - smem_align: str = "" - smem_prefix: str = "" - smem_prefix_for_cast: bool = True - arg_int_prefix: str = "" - barrier: str = "" - xid: List[str] = [] - gid: List[str] = [] - lid: List[str] = [] - global_max: List[int] = [] - local_max: List[int] = [] - extra_args: List[str] = [] - float4: Optional[str] = None - half_prekernel: Optional[str] = None - uses_vload: bool = False - external_local_bufs: bool = False - uses_ptr_arithmetic: bool = False - launch_bounds: bool = False - code_for_op: Dict = { - UnaryOps.NEG: lambda x,dtype: f"(-{x})" if dtype != dtypes.bool else f"(!{x})", - UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", - UnaryOps.SIN: lambda x,dtype: f"sin({x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})", - BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", - BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", - BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})", - BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})", - TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})" - } + size_prefix: str = "int" + generic_var_prefix: str = "" + kernel_prefix: str = "" + buffer_prefix: str = "" + buffer_suffix: str = "" + smem_align: str = "" + smem_prefix: str = "" + smem_prefix_for_cast: bool = True + arg_int_prefix: str = "" + barrier: str = "" + xid: List[str] = [] + gid: List[str] = [] + lid: List[str] = [] + global_max: List[int] = [] + local_max: List[int] = [] + extra_args: List[str] = [] + float4: Optional[str] = None + half_prekernel: Optional[str] = None + uses_vload: bool = False + external_local_bufs: bool = False + uses_ptr_arithmetic: bool = False + launch_bounds: bool = False + code_for_op: Dict = { + UnaryOps.NEG: lambda x, dtype: f"(-{x})" if dtype != dtypes.bool else f"(!{x})", + UnaryOps.EXP2: lambda x, dtype: f"exp2({x})", + UnaryOps.LOG2: lambda x, dtype: f"log2({x})", + UnaryOps.SIN: lambda x, dtype: f"sin({x})", + UnaryOps.SQRT: lambda x, dtype: f"sqrt({x})", + BinaryOps.ADD: lambda a, b, dtype: f"({a}+{b})", + BinaryOps.SUB: lambda a, b, dtype: f"({a}-{b})", + BinaryOps.MUL: lambda a, b, dtype: f"({a}*{b})", + BinaryOps.DIV: lambda a, b, dtype: f"({a}/{b})", + BinaryOps.MAX: lambda a, b, dtype: f"max({a},{b})", + BinaryOps.MOD: lambda a, b, dtype: f"({a}%{b})", + BinaryOps.CMPLT: lambda a, b, dtype: f"({a}<{b})", + TernaryOps.MULACC: lambda a, b, c, dtype: f"(({a}*{b})+{c})", + TernaryOps.WHERE: lambda a, b, c, dtype: f"({a}!=0?{b}:{c})", + } - # returns a str expression of the casted xs with the given type - def render_cast(self, x:List[str], var_dtype:DType) -> str: - if len(x) == 1: return f"({var_dtype.name})({x[0]})" - assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}" - assert self.float4 is not None, "vectorized cast is not supported on this platform" - return f"{self.float4.replace('float4', var_dtype.name)}({','.join(x)})" + # returns a str expression of the casted xs with the given type + def render_cast(self, x: List[str], var_dtype: DType) -> str: + if len(x) == 1: + return f"({var_dtype.name})({x[0]})" + assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}" + assert ( + self.float4 is not None + ), "vectorized cast is not supported on this platform" + return f"{self.float4.replace('float4', var_dtype.name)}({','.join(x)})" - # returns a str expression of the const with the given type - def render_const(self, x:Union[float,int,bool], var_dtype) -> str: - if math.isnan(x): val = "NAN" - elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY" - else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}" if dtypes.is_int(var_dtype) else f"{bool(x)}".lower() - return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val + # returns a str expression of the const with the given type + def render_const(self, x: Union[float, int, bool], var_dtype) -> str: + if math.isnan(x): + val = "NAN" + elif math.isinf(x): + val = ("-" if x < 0 else "") + "INFINITY" + else: + val = ( + f"{float(x)}f" + if dtypes.is_float(var_dtype) + else f"{int(x)}" + if dtypes.is_int(var_dtype) + else f"{bool(x)}".lower() + ) + return ( + self.render_cast([val] * var_dtype.sz, var_dtype) + if var_dtype.sz > 1 + or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] + else val + ) - # returns a str expression of the loaded value with the output type - def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str: - if isinstance(buf_dtype, ImageDType): - assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}" - return f"read_imagef({buf_name}, smp, {idx})" - if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16: - return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})" - if output_dtype.sz > 1: - out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" - else: - out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]" + # returns a str expression of the loaded value with the output type + def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str: + if isinstance(buf_dtype, ImageDType): + assert output_dtype == dtypes.float.vec( + 4 + ), f"images must be float4, getting {output_dtype}" + return f"read_imagef({buf_name}, smp, {idx})" + if ( + self.uses_vload + and buf_dtype.scalar() == dtypes.float16 + and output_dtype.scalar() != dtypes.float16 + ): + return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})" + if output_dtype.sz > 1: + out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" + else: + out_val = ( + f"*({buf_name}+{idx})" + if self.uses_ptr_arithmetic + else f"{buf_name}[{idx}]" + ) - return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val + return ( + self.render_cast([out_val], output_dtype) + if output_dtype != buf_dtype + else out_val + ) - def render_local(self, name:str, size:int): - return self.smem_align + self.smem_prefix + f"float {name}[{size}];" + def render_local(self, name: str, size: int): + return self.smem_align + self.smem_prefix + f"float {name}[{size}];" - def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str: - return f"for (int {expr} = {_min}; {expr} < {_max}; ++{expr}) {{" + def render_for( + self, expr: str, _min: Union[int, str], _max: Union[int, str] + ) -> str: + return f"for (int {expr} = {_min}; {expr} < {_max}; ++{expr}) {{" - def render_if(self, cond: str): - return f"if ({cond}) {{" + def render_if(self, cond: str): + return f"if ({cond}) {{" - def render_conditional(self, cond: str, x:str, y:str) -> str: - return f"({cond})?({x}):{y}" + def render_conditional(self, cond: str, x: str, y: str) -> str: + return f"({cond})?({x}):{y}" - def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str: - tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" - buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else - self.arg_int_prefix if dtype == dtypes._arg_int32 else - ("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)] - prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] + - [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] + - [") {\n" + tmp] + ['\n'.join(kernel), "\n}"]) - if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg]) - return prg + def render_kernel( + self, + function_name: str, + kernel: List[str], + bufs: List[Tuple[str, DType]], + local_size: List[int], + prekernel: List[str], + ) -> str: + tmp = ( + "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" + if any(isinstance(dtype, ImageDType) for _, dtype in bufs) + else "" + ) + buftypes = [ + ( + name, + f"{'read_only' if i > 0 else 'write_only'} image2d_t" + if dtype.name.startswith("image") + else self.arg_int_prefix + if dtype == dtypes._arg_int32 + else ("const " if i > 0 else "") + + self.buffer_prefix + + dtype.name + + "*" + + self.buffer_suffix, + ) + for i, (name, dtype) in enumerate(bufs) + ] + prg = "".join( + [ + f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(", + ] + + [", ".join([f"{t} {name}" for name, t in buftypes] + self.extra_args)] + + [") {\n" + tmp] + + ["\n".join(kernel), "\n}"] + ) + if self.half_prekernel and any(dtype == dtypes.float16 for _, dtype in bufs): + prg = "".join([f"{self.half_prekernel}", "\n", prg]) + return prg - # returns a str statement that does the store - def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str: - if isinstance(buf_dtype, ImageDType): - assert var_dtype == dtypes.float.vec(4), "images must be float4" - return f"write_imagef({buf_name}, {idx}, {var_name});" - if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16: - return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});" - if var_dtype.sz > 1: - return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" - return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};" + # returns a str statement that does the store + def render_store( + self, + buf_name: str, + buf_dtype: DType, + var_name: str, + var_dtype: DType, + idx: str, + local=False, + ) -> str: + if isinstance(buf_dtype, ImageDType): + assert var_dtype == dtypes.float.vec(4), "images must be float4" + return f"write_imagef({buf_name}, {idx}, {var_name});" + if ( + self.uses_vload + and buf_dtype.scalar() == dtypes.float16 + and var_dtype.scalar() != dtypes.float16 + ): + return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});" + if var_dtype.sz > 1: + return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" + return ( + f"*({buf_name}+{idx}) = {var_name};" + if self.uses_ptr_arithmetic + else f"{buf_name}[{idx}] = {var_name};" + ) -def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: - local_size: List[int] = [] - kernel,prekernel,bufs = [],[],[] - #pend_close = None - depth = 1 - def kk(s): kernel.append(" "*depth+s) - c: DefaultDict[str, int] = defaultdict(int) - r: Dict[UOp, str] = {} - def ssa(u, prefix="t"): - nonlocal c, r - c[prefix] += 1 - r[u]=f"{prefix}{c[prefix]-1}" - return r[u] +def uops_to_cstyle( + lang: CStyleLanguage, function_name: str, uops: List[UOp] +) -> Tuple[str, Dict]: + local_size: List[int] = [] + kernel, prekernel, bufs = [], [], [] + # pend_close = None + depth = 1 - child_count: DefaultDict[UOp, int] = defaultdict(int) - for ru in uops: - for v in ru.vin: - child_count[v] += 1 + def kk(s): + kernel.append(" " * depth + s) - for u in uops: - uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg - if uop == UOps.LOOP: - kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]])) - depth += 1 - elif uop == UOps.IF: - kk(lang.render_if(r[vin[0]])) - depth += 1 - elif uop == UOps.BARRIER: - kk(lang.barrier) - elif uop == UOps.END: - depth -= 1 - kk("}") - elif uop == UOps.WMMA: - if args[0] == "METAL": - assert dtype == dtypes.float.vec(2), "output dtype of METAL TC is _float2" - # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2)) - output = ssa(u, 'wmma') - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {output};") - kk("{ simdgroup_float8x8 a,b,c;") - kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};") - kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};") - kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};") - kk("simdgroup_multiply_accumulate(c, a, b, c);") - kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}") - elif args[0] == "HIP": - assert dtype == dtypes.float.vec(8), "output dtype of HIP TC is _float8" - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") - else: - raise NotImplementedError(f"WMMA not implemented for {args}") - elif uop == UOps.ALU: - assert dtype is not None - # remove parens if ALU types are the same. TODO: can do more here - if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}: - val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype) - else: - val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype]) - assert child_count[u] != 0, f"childless ALU op found {u}" - if (child_count[u] <= 1 or dtypes.is_int(dtype)) and args != BinaryOps.MAX: # fix index rendering issue. fix clang nested max macro issue - r[u] = val - else: - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};") - elif uop == UOps.DEFINE_ACC: - assert dtype is not None - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};") - elif uop == UOps.SPECIAL: - xid = lang.gid if args[1].startswith("g") else (lang.xid if args[1].startswith("i") else lang.lid) - kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */") - if args[1].startswith("l"): local_size.append(args[2]) - r[u] = args[1] - elif uop == UOps.CONST: - r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})" - elif uop == UOps.LOAD: - assert dtype is not None - val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL) - if len(vin) > 3: val = lang.render_conditional(r[vin[2]], val, r[vin[3]]) - kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};") - elif uop == UOps.PHI: - kk(f"{r[vin[0]]} = {r[vin[1]]};") - r[u] = r[vin[0]] - elif uop == UOps.STORE: - assert vin[0].dtype is not None and vin[2].dtype is not None - if len(vin) > 3: kk(lang.render_if(r[vin[3]])) - kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)) - if len(vin) > 3: kk("}") - elif uop == UOps.CAST and dtype is not None: - val = lang.render_cast([r[x] for x in vin], dtype) - if child_count[u] <= 1: r[u] = val - else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};") - elif uop == UOps.DEFINE_LOCAL: - if lang.external_local_bufs: - prekernel.append(lang.render_local(args[0], args[1])) - else: - kk(lang.render_local(args[0], args[1])) - r[u] = args[0] - elif uop == UOps.DEFINE_GLOBAL: - bufs.append(args) - r[u] = args[0] - elif uop == UOps.GEP: - if cast(DType, vin[0].dtype).sz > 4: - r[u] = f"({r[vin[0]]})[{args}]" # this is correct for HIP - else: - r[u] = f"({r[vin[0]]}).{'xyzw'[args]}" - else: - raise RuntimeError(f"failed to render {uop}") + c: DefaultDict[str, int] = defaultdict(int) + r: Dict[UOp, str] = {} + + def ssa(u, prefix="t"): + nonlocal c, r + c[prefix] += 1 + r[u] = f"{prefix}{c[prefix]-1}" + return r[u] + + child_count: DefaultDict[UOp, int] = defaultdict(int) + for ru in uops: + for v in ru.vin: + child_count[v] += 1 + + for u in uops: + uop, dtype, vin, args = u.uop, u.dtype, u.vin, u.arg + if uop == UOps.LOOP: + kk(lang.render_for(ssa(u, "ridx"), r[vin[0]], r[vin[1]])) + depth += 1 + elif uop == UOps.IF: + kk(lang.render_if(r[vin[0]])) + depth += 1 + elif uop == UOps.BARRIER: + kk(lang.barrier) + elif uop == UOps.END: + depth -= 1 + kk("}") + elif uop == UOps.WMMA: + if args[0] == "METAL": + assert dtype == dtypes.float.vec( + 2 + ), "output dtype of METAL TC is _float2" + # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2)) + output = ssa(u, "wmma") + kk( + f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {output};" + ) + kk("{ simdgroup_float8x8 a,b,c;") + kk( + f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};" + ) + kk( + f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};" + ) + kk( + f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};" + ) + kk("simdgroup_multiply_accumulate(c, a, b, c);") + kk( + f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}" + ) + elif args[0] == "HIP": + assert dtype == dtypes.float.vec(8), "output dtype of HIP TC is _float8" + kk( + f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});" + ) + else: + raise NotImplementedError(f"WMMA not implemented for {args}") + elif uop == UOps.ALU: + assert dtype is not None + # remove parens if ALU types are the same. TODO: can do more here + if ( + vin[0].uop == UOps.ALU + and vin[0].arg == args + and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL} + ): + val = lang.code_for_op[args]( + strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype + ) + else: + val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype]) + assert child_count[u] != 0, f"childless ALU op found {u}" + if ( + child_count[u] <= 1 or dtypes.is_int(dtype) + ) and args != BinaryOps.MAX: # fix index rendering issue. fix clang nested max macro issue + r[u] = val + else: + kk( + f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};" + ) + elif uop == UOps.DEFINE_ACC: + assert dtype is not None + kk( + f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};" + ) + elif uop == UOps.SPECIAL: + xid = ( + lang.gid + if args[1].startswith("g") + else (lang.xid if args[1].startswith("i") else lang.lid) + ) + kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */") + if args[1].startswith("l"): + local_size.append(args[2]) + r[u] = args[1] + elif uop == UOps.CONST: + r[u] = ( + lang.render_const(args, dtype) + if args >= 0 + else f"({lang.render_const(args, dtype)})" + ) + elif uop == UOps.LOAD: + assert dtype is not None + val = lang.render_load( + dtype, + r[vin[0]], + vin[0].dtype, + strip_parens(r[vin[1]]), + vin[0].uop == UOps.DEFINE_LOCAL, + ) + if len(vin) > 3: + val = lang.render_conditional(r[vin[2]], val, r[vin[3]]) + kk( + f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'val')} = {val};" + ) + elif uop == UOps.PHI: + kk(f"{r[vin[0]]} = {r[vin[1]]};") + r[u] = r[vin[0]] + elif uop == UOps.STORE: + assert vin[0].dtype is not None and vin[2].dtype is not None + if len(vin) > 3: + kk(lang.render_if(r[vin[3]])) + kk( + lang.render_store( + r[vin[0]], + vin[0].dtype, + r[vin[2]], + vin[2].dtype, + strip_parens(r[vin[1]]), + vin[0].uop == UOps.DEFINE_LOCAL, + ) + ) + if len(vin) > 3: + kk("}") + elif uop == UOps.CAST and dtype is not None: + val = lang.render_cast([r[x] for x in vin], dtype) + if child_count[u] <= 1: + r[u] = val + else: + kk( + f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};" + ) + elif uop == UOps.DEFINE_LOCAL: + if lang.external_local_bufs: + prekernel.append(lang.render_local(args[0], args[1])) + else: + kk(lang.render_local(args[0], args[1])) + r[u] = args[0] + elif uop == UOps.DEFINE_GLOBAL: + bufs.append(args) + r[u] = args[0] + elif uop == UOps.GEP: + if cast(DType, vin[0].dtype).sz > 4: + r[u] = f"({r[vin[0]]})[{args}]" # this is correct for HIP + else: + r[u] = f"({r[vin[0]]}).{'xyzw'[args]}" + else: + raise RuntimeError(f"failed to render {uop}") + + return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {} - return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {} class OpenCLLanguage(CStyleLanguage): - kernel_prefix = "__kernel " - buffer_prefix = "__global " - smem_align = "__attribute__ ((aligned (16))) " - smem_prefix = "__local " - arg_int_prefix = "const int" - half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable" - barrier = "barrier(CLK_LOCAL_MEM_FENCE);" - float4 = "(float4)" - gid = [f'get_group_id({i})' for i in range(3)] - lid = [f'get_local_id({i})' for i in range(3)] - xid = [f'get_global_id({i})' for i in range(3)] - uses_vload = True - # NOTE: mad is used so the loads aren't reordered into the math on 845 - code_for_op = {**CStyleLanguage().code_for_op, TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})"} + kernel_prefix = "__kernel " + buffer_prefix = "__global " + smem_align = "__attribute__ ((aligned (16))) " + smem_prefix = "__local " + arg_int_prefix = "const int" + half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable" + barrier = "barrier(CLK_LOCAL_MEM_FENCE);" + float4 = "(float4)" + gid = [f"get_group_id({i})" for i in range(3)] + lid = [f"get_local_id({i})" for i in range(3)] + xid = [f"get_global_id({i})" for i in range(3)] + uses_vload = True + # NOTE: mad is used so the loads aren't reordered into the math on 845 + code_for_op = { + **CStyleLanguage().code_for_op, + TernaryOps.MULACC: lambda a, b, c, dtype: f"mad({a},{b},{c})", + } + + OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage()) + class MetalLanguage(CStyleLanguage): - kernel_prefix = "#include \nusing namespace metal;\nkernel " - buffer_prefix = "device " - smem_prefix = "threadgroup " - arg_int_prefix = "constant int&" - barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);" - float4 = "float4" - uses_ptr_arithmetic=True - gid = [f"gid.{chr(120+i)}" for i in range(3)] - lid = [f"lid.{chr(120+i)}" for i in range(3)] - extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'] + kernel_prefix = "#include \nusing namespace metal;\nkernel " + buffer_prefix = "device " + smem_prefix = "threadgroup " + arg_int_prefix = "constant int&" + barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);" + float4 = "float4" + uses_ptr_arithmetic = True + gid = [f"gid.{chr(120+i)}" for i in range(3)] + lid = [f"lid.{chr(120+i)}" for i in range(3)] + extra_args = [ + "uint3 gid [[threadgroup_position_in_grid]]", + "uint3 lid [[thread_position_in_threadgroup]]", + ] + + MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage()) + class CUDALanguage(CStyleLanguage): - kernel_prefix = "#define INFINITY (__int_as_float(0x7f800000))\n#define NAN (__int_as_float(0x7fffffff))\nextern \"C\" __global__ " - smem_prefix = "__shared__ " - smem_prefix_for_cast = False - arg_int_prefix = "const int" - barrier = "__syncthreads();" - float4 = "make_float4" - gid = [f'blockIdx.{chr(120+i)}' for i in range(3)] - lid = [f'threadIdx.{chr(120+i)}' for i in range(3)] - xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)] - code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"__hmax({a},{b})"} - half_prekernel = """ + kernel_prefix = '#define INFINITY (__int_as_float(0x7f800000))\n#define NAN (__int_as_float(0x7fffffff))\nextern "C" __global__ ' + smem_prefix = "__shared__ " + smem_prefix_for_cast = False + arg_int_prefix = "const int" + barrier = "__syncthreads();" + float4 = "make_float4" + gid = [f"blockIdx.{chr(120+i)}" for i in range(3)] + lid = [f"threadIdx.{chr(120+i)}" for i in range(3)] + xid = [ + f"(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})" + for i in range(3) + ] + code_for_op = { + **CStyleLanguage().code_for_op, + BinaryOps.MAX: lambda a, b, dtype: f"max({a},{b})" + if dtype != dtypes.half + else f"__hmax({a},{b})", + } + half_prekernel = """ #include struct half4 { half x, y, z, w; }; __device__ half4 make_half4(half x, half y, half z, half w) { half4 ret; ret.x = x; ret.y = y; ret.z = z; ret.w = w; return ret; } """ + + CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage()) + class HIPLanguage(CStyleLanguage): - kernel_prefix = "#include \n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))" + """ + kernel_prefix = ( + '#include \n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(""))' + + """ __device__ float4 max(float4 x, float4 y) { return float4(max(x.x, y.x), max(x.y, y.y), max(x.z, y.z), max(x.w, y.w)); } __device__ float4 pow(float x, float4 y) { return float4(pow(x, y.x), pow(x, y.y), pow(x, y.z), pow(x, y.w)); } __device__ float4 pow(float4 x, float4 y) { return float4(pow(x.x, y.x), pow(x.y, y.y), pow(x.z, y.z), pow(x.w, y.w)); } @@ -268,15 +444,18 @@ class HIPLanguage(CStyleLanguage): typedef float float8 __attribute__((ext_vector_type(8))); __device__ float8 make_float8(float x, float y, float z, float w, float a, float b, float c, float d) { return {x, y, z, w, a, b, c, d}; } extern "C" __global__ """ - launch_bounds = True - smem_prefix = "__shared__ " - smem_prefix_for_cast=False - barrier = "__syncthreads();" - float4 = "make_float4" - uses_vload=True - uses_ptr_arithmetic=True - arg_int_prefix = "const int" - half_prekernel = "#include \n" + """ + ) + launch_bounds = True + smem_prefix = "__shared__ " + smem_prefix_for_cast = False + barrier = "__syncthreads();" + float4 = "make_float4" + uses_vload = True + uses_ptr_arithmetic = True + arg_int_prefix = "const int" + half_prekernel = ( + "#include \n" + + """ typedef union { struct { half x, y, z, w; } __attribute__((aligned(8))); half data[4]; } half4; __device__ half4 make_half4(half x, half y, half z, half w) { return {x, y, z, w}; } typedef union { struct { half x, y, z, w, a, b, c, d; } __attribute__((aligned(16))); half data[8]; } half8; __device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { return {x, y, z, w, a, b, c, d}; } typedef _Float16 half16 __attribute__((ext_vector_type(16))); __device__ half16 make_half16(half x, half y, half z, half w, half a, half b, half c, half d, half e, half f, half g, half h, half i, half j, half k, half l) { return {x, y, z, w, a, b, c, d, e, f, g, h, i, j, k, l}; } @@ -307,52 +486,110 @@ __device__ half operator*(const unsigned short &a, const half &b) { return __hmu __device__ half operator/(const unsigned short &a, const half &b) { return __hdiv((half)(a), b); } __device__ bool operator<(const unsigned short &a, const half &b) { return __hlt((half)(a), b); } """ - gid = [f'blockIdx.{chr(120+i)}' for i in range(3)] - lid = [f'threadIdx.{chr(120+i)}' for i in range(3)] - xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)] - code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"hmax({a},{b})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})" if dtype != dtypes.half else f"(half)({a}!=0?{b}:{c})"} + ) + gid = [f"blockIdx.{chr(120+i)}" for i in range(3)] + lid = [f"threadIdx.{chr(120+i)}" for i in range(3)] + xid = [ + f"(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})" + for i in range(3) + ] + code_for_op = { + **CStyleLanguage().code_for_op, + BinaryOps.MAX: lambda a, b, dtype: f"max({a},{b})" + if dtype != dtypes.half + else f"hmax({a},{b})", + TernaryOps.WHERE: lambda a, b, c, dtype: f"({a}!=0?{b}:{c})" + if dtype != dtypes.half + else f"(half)({a}!=0?{b}:{c})", + } + + HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage()) + # TODO: how much of this can be merged with above? class WGSLLanguage(CStyleLanguage): - gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)] - lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)] - size_prefix = "let" - barrier="workgroupBarrier();" - generic_var_prefix = "var " - external_local_bufs = True - code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y,dtype: f"f32({x}<{y})", TernaryOps.MULACC: lambda x,y,z,dtype: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a}!=0.)" } - type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool"} + gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)] + lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)] + size_prefix = "let" + barrier = "workgroupBarrier();" + generic_var_prefix = "var " + external_local_bufs = True + code_for_op = { + **CStyleLanguage().code_for_op, + BinaryOps.CMPLT: lambda x, y, dtype: f"f32({x}<{y})", + TernaryOps.MULACC: lambda x, y, z, dtype: f"fma({x},{y},{z})", + TernaryOps.WHERE: lambda a, b, c, dtype: f"select({c},{b},{a}!=0.)", + } + type_map = { + dtypes.float: "f32", + dtypes.half: "f16", + dtypes.int32: "i32", + dtypes.uint32: "u32", + dtypes.bool: "bool", + } - def render_local(self, name: str, size: int): - return f"var {name}: array;" + def render_local(self, name: str, size: int): + return f"var {name}: array;" - def render_const(self, x:Union[float,int], var_dtype) -> str: - if math.isnan(x): return "nan()" - elif math.isinf(x): return ("-" if x < 0 else "") + "0x1.fffffep+127f" - return f"({super().render_const(x, var_dtype)})" + def render_const(self, x: Union[float, int], var_dtype) -> str: + if math.isnan(x): + return "nan()" + elif math.isinf(x): + return ("-" if x < 0 else "") + "0x1.fffffep+127f" + return f"({super().render_const(x, var_dtype)})" - def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str: - local_size = local_size[::-1] if local_size else [1] - bind_it = iter(range(len(bufs))) - prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast(bits); }\n" - prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) var {name}: array<{self.type_map[dtype]}>;" for name,dtype in bufs]) - prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3, @builtin(local_invocation_id) lindex: vec3) {{\n" + "\n".join(kernel) + "\n}" - return prg + def render_kernel( + self, + function_name: str, + kernel: List[str], + bufs: List[Tuple[str, DType]], + local_size: List[int], + prekernel: List[str], + ) -> str: + local_size = local_size[::-1] if local_size else [1] + bind_it = iter(range(len(bufs))) + prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast(bits); }\n" + prg += "\n".join( + prekernel + + [ + f"@group(0) @binding({next(bind_it)}) var {name}: array<{self.type_map[dtype]}>;" + for name, dtype in bufs + ] + ) + prg += ( + f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3, @builtin(local_invocation_id) lindex: vec3) {{\n" + + "\n".join(kernel) + + "\n}" + ) + return prg - def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str: - return f"for(var {expr} = {_min}; {expr} < {_max}; {expr}++) {{" + def render_for( + self, expr: str, _min: Union[int, str], _max: Union[int, str] + ) -> str: + return f"for(var {expr} = {_min}; {expr} < {_max}; {expr}++) {{" - def render_if(self, cond: str): - return f"if (bool({cond})) {{" + def render_if(self, cond: str): + return f"if (bool({cond})) {{" - def render_conditional(self, cond:str, x:str, y:str) -> str: - return f"select(f32({y}), {x}, bool({cond}))" + def render_conditional(self, cond: str, x: str, y: str) -> str: + return f"select(f32({y}), {x}, bool({cond}))" + + def render_cast(self, x: List[str], var_dtype: DType) -> str: + if self.type_map[var_dtype]: + return f"{self.type_map[var_dtype]}({x[0]})" + raise NotImplementedError(f"no cast for {var_dtype}") + + def render_store( + self, + buf_name: str, + buf_dtype: DType, + var_name: str, + var_dtype: DType, + idx, + local=False, + ) -> str: + return f"{buf_name}[{idx}] = {self.render_cast([var_name], buf_dtype) if var_dtype != buf_dtype else var_name};" - def render_cast(self, x:List[str], var_dtype:DType) -> str: - if self.type_map[var_dtype]: return f"{self.type_map[var_dtype]}({x[0]})" - raise NotImplementedError(f"no cast for {var_dtype}") - def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str: - return f"{buf_name}[{idx}] = {self.render_cast([var_name], buf_dtype) if var_dtype != buf_dtype else var_name};" WGSLRenderer = functools.partial(uops_to_cstyle, WGSLLanguage()) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 99801ece5..e8cf64b3c 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -4,151 +4,291 @@ from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.helpers import dtypes from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps -LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf +LLVM_FAST_MATH_FLAGS = ( + "nsz", + "arcp", + "contract", + "afn", + "reassoc", +) # All from fast math, but nnan and ninf + + +def is_bool(t: ir.Type): + return isinstance(t, ir.IntType) and t.width == 1 + -def is_bool(t:ir.Type): return isinstance(t, ir.IntType) and t.width == 1 code_for_op: Final[Dict[Op, Callable]] = { - UnaryOps.NEG: lambda builder,x: builder.xor(x, ir.Constant(ir.IntType(1), 1)) if is_bool(x.type) else builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS), - UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS), - UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS), - UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS), - UnaryOps.SQRT: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sqrt', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS), - BinaryOps.ADD: lambda builder,x,y: builder.add(x,y) if isinstance(x.type, ir.IntType) else builder.fadd(x,y, flags=LLVM_FAST_MATH_FLAGS), - BinaryOps.SUB: lambda builder,x,y: builder.sub(x,y) if isinstance(x.type, ir.IntType) else builder.fsub(x,y, flags=LLVM_FAST_MATH_FLAGS), - BinaryOps.MUL: lambda builder,x,y: builder.mul(x,y) if isinstance(x.type, ir.IntType) else builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), - BinaryOps.DIV: lambda builder,x,y: builder.sdiv(x,y) if isinstance(x.type, ir.IntType) else builder.fdiv(x,y, flags=LLVM_FAST_MATH_FLAGS), - BinaryOps.CMPLT: lambda builder,x,y: builder.icmp_unsigned("<", x, y) if is_bool(x.type) else builder.icmp_signed("<", x, y) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS), - BinaryOps.MAX: lambda builder,x,y: builder.select(builder.icmp_unsigned(">", x, y) if is_bool(x.type) else builder.icmp_signed(">", x, y) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), x, y), - BinaryOps.MOD: lambda builder,x,y: builder.urem(x,y) if is_bool(x.type) else builder.srem(x,y) if isinstance(x.type, ir.IntType) else builder.frem(x,y), - TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS), - TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.trunc(x, ir.IntType(1)) if isinstance(x.type, ir.IntType) else builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS), y, z - ), + UnaryOps.NEG: lambda builder, x: builder.xor(x, ir.Constant(ir.IntType(1), 1)) + if is_bool(x.type) + else builder.neg(x) + if isinstance(x.type, ir.IntType) + else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS), + UnaryOps.EXP2: lambda builder, x: builder.call( + builder._block.module.declare_intrinsic("llvm.exp2", [ir.FloatType()]), + [x], + fastmath=LLVM_FAST_MATH_FLAGS, + ), + UnaryOps.LOG2: lambda builder, x: builder.call( + builder._block.module.declare_intrinsic("llvm.log2", [ir.FloatType()]), + [x], + fastmath=LLVM_FAST_MATH_FLAGS, + ), + UnaryOps.SIN: lambda builder, x: builder.call( + builder._block.module.declare_intrinsic("llvm.sin", [ir.FloatType()]), + [x], + fastmath=LLVM_FAST_MATH_FLAGS, + ), + UnaryOps.SQRT: lambda builder, x: builder.call( + builder._block.module.declare_intrinsic("llvm.sqrt", [ir.FloatType()]), + [x], + fastmath=LLVM_FAST_MATH_FLAGS, + ), + BinaryOps.ADD: lambda builder, x, y: builder.add(x, y) + if isinstance(x.type, ir.IntType) + else builder.fadd(x, y, flags=LLVM_FAST_MATH_FLAGS), + BinaryOps.SUB: lambda builder, x, y: builder.sub(x, y) + if isinstance(x.type, ir.IntType) + else builder.fsub(x, y, flags=LLVM_FAST_MATH_FLAGS), + BinaryOps.MUL: lambda builder, x, y: builder.mul(x, y) + if isinstance(x.type, ir.IntType) + else builder.fmul(x, y, flags=LLVM_FAST_MATH_FLAGS), + BinaryOps.DIV: lambda builder, x, y: builder.sdiv(x, y) + if isinstance(x.type, ir.IntType) + else builder.fdiv(x, y, flags=LLVM_FAST_MATH_FLAGS), + BinaryOps.CMPLT: lambda builder, x, y: builder.icmp_unsigned("<", x, y) + if is_bool(x.type) + else builder.icmp_signed("<", x, y) + if isinstance(x.type, ir.IntType) + else builder.fcmp_unordered("<", x, y, flags=LLVM_FAST_MATH_FLAGS), + BinaryOps.MAX: lambda builder, x, y: builder.select( + builder.icmp_unsigned(">", x, y) + if is_bool(x.type) + else builder.icmp_signed(">", x, y) + if isinstance(x.type, ir.IntType) + else builder.fcmp_unordered(">", x, y, flags=LLVM_FAST_MATH_FLAGS), + x, + y, + ), + BinaryOps.MOD: lambda builder, x, y: builder.urem(x, y) + if is_bool(x.type) + else builder.srem(x, y) + if isinstance(x.type, ir.IntType) + else builder.frem(x, y), + TernaryOps.MULACC: lambda builder, x, y, z: builder.fadd( + builder.fmul(x, y, flags=LLVM_FAST_MATH_FLAGS), z, flags=LLVM_FAST_MATH_FLAGS + ), + TernaryOps.WHERE: lambda builder, x, y, z: builder.select( + builder.trunc(x, ir.IntType(1)) + if isinstance(x.type, ir.IntType) + else builder.fcmp_unordered( + "!=", x, ir.Constant(ir.FloatType(), 0), flags=LLVM_FAST_MATH_FLAGS + ), + y, + z, + ), +} + +dtype_to_llvm_dtype = { + dtypes.float64: ir.DoubleType(), + dtypes.float16: ir.HalfType(), + dtypes.bfloat16: ir.IntType(16), + dtypes.float32: ir.FloatType(), + dtypes.int8: ir.IntType(8), + dtypes.uint8: ir.IntType(8), + dtypes.bool: ir.IntType(1), + dtypes.int64: ir.IntType(64), + dtypes.int32: ir.IntType(32), + dtypes._arg_int32: ir.IntType(32), + dtypes.int16: ir.IntType(16), + dtypes.uint16: ir.IntType(16), + dtypes.uint32: ir.IntType(32), + dtypes.uint64: ir.IntType(64), } -dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16), dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)} def cast(bb, val, input_type, output_type): - if input_type == output_type: return val + if input_type == output_type: + return val - if dtypes.is_float(input_type): - if dtypes.is_float(output_type): - if output_type.itemsize > input_type.itemsize: return bb[-1].fpext(val, dtype_to_llvm_dtype[output_type]) - return bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type]) - if dtypes.is_int(output_type): - if dtypes.is_unsigned(output_type): return bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type]) - return bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type]) - if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0)) + if dtypes.is_float(input_type): + if dtypes.is_float(output_type): + if output_type.itemsize > input_type.itemsize: + return bb[-1].fpext(val, dtype_to_llvm_dtype[output_type]) + return bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type]) + if dtypes.is_int(output_type): + if dtypes.is_unsigned(output_type): + return bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type]) + return bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type]) + if output_type == dtypes.bool: + return bb[-1].fcmp_unordered( + "!=", + cast(bb, val, input_type, dtypes.float32), + ir.Constant(ir.FloatType(), 0), + ) - if dtypes.is_unsigned(input_type) or input_type == dtypes.bool: - if output_type == dtypes.float16: - val = bb[-1].uitofp(val, ir.FloatType()) - return bb[-1].fptrunc(val, ir.HalfType()) - if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type]) - if dtypes.is_int(output_type): - if input_type.itemsize > output_type.itemsize: return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type]) - return bb[-1].zext(val, dtype_to_llvm_dtype[output_type]) - if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0)) + if dtypes.is_unsigned(input_type) or input_type == dtypes.bool: + if output_type == dtypes.float16: + val = bb[-1].uitofp(val, ir.FloatType()) + return bb[-1].fptrunc(val, ir.HalfType()) + if dtypes.is_float(output_type): + return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type]) + if dtypes.is_int(output_type): + if input_type.itemsize > output_type.itemsize: + return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type]) + return bb[-1].zext(val, dtype_to_llvm_dtype[output_type]) + if output_type == dtypes.bool: + return bb[-1].icmp_unsigned("!=", val, ir.Constant(val.type, 0)) - if dtypes.is_int(input_type): - if output_type == dtypes.float16: - val = bb[-1].sitofp(val, ir.FloatType()) - return bb[-1].fptrunc(val, ir.HalfType()) - if dtypes.is_float(output_type): return bb[-1].sitofp(val, dtype_to_llvm_dtype[output_type]) - if dtypes.is_int(output_type): - if input_type.itemsize > output_type.itemsize: return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type]) - return bb[-1].sext(val, dtype_to_llvm_dtype[output_type]) - if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0)) + if dtypes.is_int(input_type): + if output_type == dtypes.float16: + val = bb[-1].sitofp(val, ir.FloatType()) + return bb[-1].fptrunc(val, ir.HalfType()) + if dtypes.is_float(output_type): + return bb[-1].sitofp(val, dtype_to_llvm_dtype[output_type]) + if dtypes.is_int(output_type): + if input_type.itemsize > output_type.itemsize: + return bb[-1].trunc(val, dtype_to_llvm_dtype[output_type]) + return bb[-1].sext(val, dtype_to_llvm_dtype[output_type]) + if output_type == dtypes.bool: + return bb[-1].icmp_signed("!=", val, ir.Constant(val.type, 0)) - raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented") + raise NotImplementedError( + f"cast from {input_type} -> {output_type} not implemented" + ) -def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args) -def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: - # all llvm stuff goes into a module - module = ir.Module(name=__file__) +def const(args, dtype): + return ir.Constant( + dtype_to_llvm_dtype[dtype], + int(args) + if dtypes.is_int(dtype) + else bool(args) + if dtype == dtypes.bool + else args, + ) - # extract global buffers - buf_to_dtype = {u.arg[0]:u.arg[1] for u in uops if u.uop == UOps.DEFINE_GLOBAL} - buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())} - # create llvm function - func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()] - func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name) - for a in func.args: - if a.type.is_pointer: a.add_attribute("noalias") +def uops_to_llvm_ir(function_name: str, uops: List[UOp]) -> Tuple[str, Dict]: + # all llvm stuff goes into a module + module = ir.Module(name=__file__) - # add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations - func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"'])) - func.attributes.add('"no-nans-fp-math"="true"') + # extract global buffers + buf_to_dtype = {u.arg[0]: u.arg[1] for u in uops if u.uop == UOps.DEFINE_GLOBAL} + buf_index = {x: i for i, x in enumerate(buf_to_dtype.keys())} - bb = [ir.IRBuilder(func.append_basic_block("entry"))] - loop_blocks: List = [] - reduce_phis: List = [] - # TODO: newvar probably shouldn't be optional - lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type + # create llvm function + func_dtypes = [ + (dtype_to_llvm_dtype[dtype], dtype) for dtype in buf_to_dtype.values() + ] + func = ir.Function( + module, + ir.FunctionType( + ir.VoidType(), + [x.as_pointer() if dt != dtypes._arg_int32 else x for x, dt in func_dtypes], + ), + name=function_name, + ) + for a in func.args: + if a.type.is_pointer: + a.add_attribute("noalias") - for bufname,dtype in buf_to_dtype.items(): - if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32)) + # add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations + func.attributes._known = func.attributes._known.union( + frozenset(['"no-nans-fp-math"="true"']) + ) + func.attributes.add('"no-nans-fp-math"="true"') - for u in uops: - uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg - if uop == UOps.LOOP: - bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}"))) - bb[-2].branch(bb[-1]._block) + bb = [ir.IRBuilder(func.append_basic_block("entry"))] + loop_blocks: List = [] + reduce_phis: List = [] + # TODO: newvar probably shouldn't be optional + lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type - phis = [] - for rp in reduce_phis: - incoming = lvars[rp] - lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype]) - lvars[rp].add_incoming(incoming, bb[-2]._block) - phis.append((rp, lvars[rp])) + for bufname, dtype in buf_to_dtype.items(): + if dtype == dtypes._arg_int32: + lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32)) - lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}") - lvars[u].add_incoming(lvars[vin[0]], bb[-2]._block) - loop_blocks.append((bb[-1], phis)) - if uop == UOps.END: - block, phis = loop_blocks.pop() - idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1)) - lvars[vin[0]].add_incoming(idx_p1, bb[-1]._block) - for n,phi in phis: phi.add_incoming(lvars[n], bb[-1]._block) - bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}"))) - bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block) - if uop == UOps.DEFINE_GLOBAL: - lvars[u] = func.args[buf_index[args[0]]] - if uop == UOps.DEFINE_ACC: - lvars[u] = const(args, dtype) - reduce_phis.append(u) - if uop == UOps.SPECIAL: - lvars[u] = lvars[args.expr] - if uop == UOps.CONST: - lvars[u] = const(args, dtype) - if uop == UOps.LOAD: - assert dtype is not None - if len(vin) > 2: - gate = bb[-1].trunc(lvars[vin[2]], ir.IntType(1)) - aug_idx = bb[-1].select(gate, lvars[vin[1]], ir.Constant(ir.IntType(32), 0)) - val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True)) - val = cast(bb, val, vin[0].dtype, dtype) - val = bb[-1].select(gate, val, lvars[vin[3]]) - else: - val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True)) - val = cast(bb, val, vin[0].dtype, dtype) - lvars[u] = val - if uop == UOps.PHI: - lvars[u] = lvars[vin[1]] - # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC - backward = vin[0] - while backward.uop == UOps.PHI: backward = backward.vin[0] - lvars[backward] = lvars[u] - if uop == UOps.STORE: - element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype) - def store_op(): bb[-1].store(element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True)) - if len(vin) > 3: - with bb[-1].if_then(bb[-1].trunc(lvars[vin[3]], ir.IntType(1))): store_op() - else: store_op() - if uop == UOps.ALU: - lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin]) - if uop == UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype) + for u in uops: + uop, dtype, vin, args = u.uop, u.dtype, u.vin, u.arg + if uop == UOps.LOOP: + bb.append( + ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")) + ) + bb[-2].branch(bb[-1]._block) - bb[-1].ret_void() - return str(module), {} + phis = [] + for rp in reduce_phis: + incoming = lvars[rp] + lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype]) + lvars[rp].add_incoming(incoming, bb[-2]._block) + phis.append((rp, lvars[rp])) + + lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}") + lvars[u].add_incoming(lvars[vin[0]], bb[-2]._block) + loop_blocks.append((bb[-1], phis)) + if uop == UOps.END: + block, phis = loop_blocks.pop() + idx_p1 = bb[-1].add(lvars[vin[0]], ir.Constant(ir.IntType(32), 1)) + lvars[vin[0]].add_incoming(idx_p1, bb[-1]._block) + for n, phi in phis: + phi.add_incoming(lvars[n], bb[-1]._block) + bb.append( + ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")) + ) + bb[-2].cbranch( + bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), + block._block, + bb[-1]._block, + ) + if uop == UOps.DEFINE_GLOBAL: + lvars[u] = func.args[buf_index[args[0]]] + if uop == UOps.DEFINE_ACC: + lvars[u] = const(args, dtype) + reduce_phis.append(u) + if uop == UOps.SPECIAL: + lvars[u] = lvars[args.expr] + if uop == UOps.CONST: + lvars[u] = const(args, dtype) + if uop == UOps.LOAD: + assert dtype is not None + if len(vin) > 2: + gate = bb[-1].trunc(lvars[vin[2]], ir.IntType(1)) + aug_idx = bb[-1].select( + gate, lvars[vin[1]], ir.Constant(ir.IntType(32), 0) + ) + val = bb[-1].load(bb[-1].gep(lvars[vin[0]], [aug_idx], inbounds=True)) + val = cast(bb, val, vin[0].dtype, dtype) + val = bb[-1].select(gate, val, lvars[vin[3]]) + else: + val = bb[-1].load( + bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True) + ) + val = cast(bb, val, vin[0].dtype, dtype) + lvars[u] = val + if uop == UOps.PHI: + lvars[u] = lvars[vin[1]] + # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC + backward = vin[0] + while backward.uop == UOps.PHI: + backward = backward.vin[0] + lvars[backward] = lvars[u] + if uop == UOps.STORE: + element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype) + + def store_op(): + bb[-1].store( + element, bb[-1].gep(lvars[vin[0]], [lvars[vin[1]]], inbounds=True) + ) + + if len(vin) > 3: + with bb[-1].if_then(bb[-1].trunc(lvars[vin[3]], ir.IntType(1))): + store_op() + else: + store_op() + if uop == UOps.ALU: + lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin]) + if uop == UOps.CAST: + lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype) + + bb[-1].ret_void() + return str(module), {} diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index d9293022a..73ff230a2 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -5,24 +5,43 @@ from tinygrad.helpers import diskcache, cpu_time_execution from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage -CLANG_PROGRAM_HEADER = '#include \n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include \n' +CLANG_PROGRAM_HEADER = "#include \n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include \n" + @diskcache -def compile_clang(prg:str, header:str=CLANG_PROGRAM_HEADER) -> bytes: - # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here - with tempfile.NamedTemporaryFile(delete=True) as output_file: - subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c -lm -fPIC --rtlib=compiler-rt - -o '+str(output_file.name)).split(), input=(header+prg).encode('utf-8')) - return pathlib.Path(output_file.name).read_bytes() +def compile_clang(prg: str, header: str = CLANG_PROGRAM_HEADER) -> bytes: + # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here + with tempfile.NamedTemporaryFile(delete=True) as output_file: + subprocess.check_output( + args=( + "clang -shared -O2 -Wall -Werror -x c -lm -fPIC --rtlib=compiler-rt - -o " + + str(output_file.name) + ).split(), + input=(header + prg).encode("utf-8"), + ) + return pathlib.Path(output_file.name).read_bytes() + class ClangProgram: - def __init__(self, name:str, lib:bytes): - self.name, self.lib = name, lib - # write to disk so we can load it - with tempfile.NamedTemporaryFile(delete=True) as cached_file_path: - pathlib.Path(cached_file_path.name).write_bytes(lib) - self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name] + def __init__(self, name: str, lib: bytes): + self.name, self.lib = name, lib + # write to disk so we can load it + with tempfile.NamedTemporaryFile(delete=True) as cached_file_path: + pathlib.Path(cached_file_path.name).write_bytes(lib) + self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name] - def __call__(self, *bufs, vals=(), wait=False): return cpu_time_execution(lambda: self.fxn(*bufs, *vals), enable=wait) + def __call__(self, *bufs, vals=(), wait=False): + return cpu_time_execution(lambda: self.fxn(*bufs, *vals), enable=wait) -renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict", arg_int_prefix="const int")) -ClangDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram) + +renderer = functools.partial( + uops_to_cstyle, + CStyleLanguage(buffer_suffix=" restrict", arg_int_prefix="const int"), +) +ClangDevice = Compiled( + MallocAllocator, + LinearizerOptions(supports_float4=False, has_local=False), + renderer, + compile_clang, + ClangProgram, +) diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index e4c0c56c0..4dd0293f6 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -1,48 +1,135 @@ import numpy as np from typing import Callable, Dict, Tuple from tinygrad.helpers import dtypes, flat_mv -from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op +from tinygrad.ops import ( + BufferOps, + UnaryOps, + BinaryOps, + MovementOps, + ReduceOps, + TernaryOps, + Op, +) from tinygrad.device import Interpreted, Allocator -def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]: - assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions" - return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b) + +def shape_to_axis( + old_shape: Tuple[int, ...], new_shape: Tuple[int, ...] +) -> Tuple[int, ...]: + assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions" + return tuple(i for i, (a, b) in enumerate(zip(old_shape, new_shape)) if a != b) + # TODO: this should be global infrastructure -def output_type(x, y): return x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype +def output_type(x, y): + return ( + x.dtype + if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority + else y.dtype + ) + + def match_types(x, y): - up = output_type(x, y) - return x.astype(up, copy=False), y.astype(up, copy=False) + up = output_type(x, y) + return x.astype(up, copy=False), y.astype(up, copy=False) + def einsum_mulacc(einsum, get_strides, expand): - def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x]) - def axes_slice(strides): return [i for i,s in enumerate(strides) if s != 0], tuple([slice(None) if s != 0 else 0 for i,s in enumerate(strides)]) - def mulacc(a, b, new_shape): - (a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b)) - out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)] - ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices]) - return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape) - return mulacc + def einscripts(x): + return "".join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x]) + + def axes_slice(strides): + return [i for i, s in enumerate(strides) if s != 0], tuple( + [slice(None) if s != 0 else 0 for i, s in enumerate(strides)] + ) + + def mulacc(a, b, new_shape): + (a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice( + get_strides(b) + ) + out = [ + i + for i in range(len(new_shape)) + if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes) + ] + ret = einsum( + f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", + a[a_slices], + b[b_slices], + ) + return expand( + ret.reshape( + [ + (1 if i not in a_axes and i not in b_axes else s) + for i, s in enumerate(new_shape) + ] + ), + new_shape, + ) + + return mulacc + numpy_fxn_for_op: Dict[Op, Callable] = { - BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np), - UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, - UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x), - BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x memoryview: return flat_mv(np.require(src, requirements='C').data) - def copyin(self, dest:np.ndarray, src:memoryview): np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape)) - def copyout(self, dest:memoryview, src:np.ndarray): np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src) + def _alloc(self, size: int): + return np.empty(size, dtype=np.uint8) + + def as_buffer(self, src: np.ndarray) -> memoryview: + return flat_mv(np.require(src, requirements="C").data) + + def copyin(self, dest: np.ndarray, src: memoryview): + np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape)) + + def copyout(self, dest: memoryview, src: np.ndarray): + np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src) + CPUDevice = Interpreted(NumpyAllocator(), numpy_fxn_for_op) diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 92857411b..7d5d6248f 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -3,80 +3,234 @@ import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools from pathlib import Path from typing import Tuple, Optional import gpuctypes.cuda as cuda -from tinygrad.helpers import DEBUG, getenv, diskcache, from_mv, init_c_var, pretty_ptx, cpu_time_execution, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style +from tinygrad.helpers import ( + DEBUG, + getenv, + diskcache, + from_mv, + init_c_var, + pretty_ptx, + cpu_time_execution, + compile_cuda_style, + encode_args_cuda_style, + time_execution_cuda_style, +) from tinygrad.device import Compiled, LRUAllocator, MallocAllocator from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import CUDARenderer CUDACPU = getenv("CUDACPU") == 1 if CUDACPU: - gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot")) - gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] - cuda.cuLaunchKernel = lambda src, gx, gy, gz, lx, ly, lz, shared, stream, unused_extra, args: gpuocelot_lib.ptx_run(src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), lx, ly, lz, gx, gy, gz, shared) + gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot")) + gpuocelot_lib.ptx_run.argtypes = [ + ctypes.c_char_p, + ctypes.c_int, + ctypes.POINTER(ctypes.c_void_p), + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ] + cuda.cuLaunchKernel = lambda src, gx, gy, gz, lx, ly, lz, shared, stream, unused_extra, args: gpuocelot_lib.ptx_run( + src, + len(args), + (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), + lx, + ly, + lz, + gx, + gy, + gz, + shared, + ) + def check(status): - if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}") + if status != 0: + raise RuntimeError( + f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}" + ) + + +def cu_time_execution(cb, enable=False) -> Optional[float]: + return ( + time_execution_cuda_style( + cb, + cuda.CUevent, + cuda.cuEventCreate, + cuda.cuEventRecord, + cuda.cuEventSynchronize, + cuda.cuEventDestroy_v2, + cuda.cuEventElapsedTime, + enable=enable, + ) + if not CUDACPU + else cpu_time_execution(cb, enable=enable) + ) -def cu_time_execution(cb, enable=False) -> Optional[float]: return time_execution_cuda_style(cb, cuda.CUevent, cuda.cuEventCreate, cuda.cuEventRecord, cuda.cuEventSynchronize, cuda.cuEventDestroy_v2, cuda.cuEventElapsedTime, enable=enable) if not CUDACPU else cpu_time_execution(cb, enable=enable) @diskcache -def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architecture={CUDADevice.default_arch_name}', "-I/usr/local/cuda/include", "-I/usr/include"], cuda.nvrtcProgram, cuda.nvrtcCreateProgram, cuda.nvrtcCompileProgram, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check) +def compile_cuda(prg) -> bytes: + return compile_cuda_style( + prg, + [ + f"--gpu-architecture={CUDADevice.default_arch_name}", + "-I/usr/local/cuda/include", + "-I/usr/include", + ], + cuda.nvrtcProgram, + cuda.nvrtcCreateProgram, + cuda.nvrtcCompileProgram, + cuda.nvrtcGetPTX, + cuda.nvrtcGetPTXSize, + cuda.nvrtcGetProgramLog, + cuda.nvrtcGetProgramLogSize, + check, + ) + class CUDAProgram: - def __init__(self, device:CUDADevice, name:str, lib:bytes): - self.device, self.name, self.lib = device, name, lib - if DEBUG >= 5: print(pretty_ptx(lib.decode('utf-8'))) - if DEBUG >= 6: - try: - fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix() - with open(fn + ".ptx", "wb") as f: f.write(lib) - subprocess.run(["ptxas", f"-arch={CUDADevice.default_arch_name}", "-o", fn, fn+".ptx"], check=True) - print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8')) - except Exception as e: print("failed to generate SASS", str(e)) + def __init__(self, device: CUDADevice, name: str, lib: bytes): + self.device, self.name, self.lib = device, name, lib + if DEBUG >= 5: + print(pretty_ptx(lib.decode("utf-8"))) + if DEBUG >= 6: + try: + fn = ( + Path(tempfile.gettempdir()) + / f"tinycuda_{hashlib.md5(lib).hexdigest()}" + ).as_posix() + with open(fn + ".ptx", "wb") as f: + f.write(lib) + subprocess.run( + [ + "ptxas", + f"-arch={CUDADevice.default_arch_name}", + "-o", + fn, + fn + ".ptx", + ], + check=True, + ) + print(subprocess.check_output(["nvdisasm", fn]).decode("utf-8")) + except Exception as e: + print("failed to generate SASS", str(e)) - if not CUDACPU: - check(cuda.cuCtxSetCurrent(self.device.context)) - self.module = init_c_var(cuda.CUmodule(), lambda x: check(cuda.cuModuleLoadData(ctypes.byref(x), lib))) - check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8"))) - self.prg = prg if not CUDACPU else lib + if not CUDACPU: + check(cuda.cuCtxSetCurrent(self.device.context)) + self.module = init_c_var( + cuda.CUmodule(), + lambda x: check(cuda.cuModuleLoadData(ctypes.byref(x), lib)), + ) + check( + cuda.cuModuleGetFunction( + ctypes.byref(prg := cuda.CUfunction()), + self.module, + name.encode("utf-8"), + ) + ) + self.prg = prg if not CUDACPU else lib - def __del__(self): - if not CUDACPU: check(cuda.cuModuleUnload(self.module)) + def __del__(self): + if not CUDACPU: + check(cuda.cuModuleUnload(self.module)) + + def __call__( + self, + *bufs, + global_size: Tuple[int, int, int], + local_size: Tuple[int, int, int], + vals: Tuple[int, ...] = (), + wait=False, + ): + if not CUDACPU: + check(cuda.cuCtxSetCurrent(self.device.context)) + c_kernel_input_config = ( + encode_args_cuda_style(bufs, vals, cuda.CUdeviceptr_v2, (1, 2, 0))[0] + if not CUDACPU + else (bufs + vals) + ) + return cu_time_execution( + lambda: check( + cuda.cuLaunchKernel( + self.prg, + *global_size, + *local_size, + 0, + None, + None, + c_kernel_input_config, + ) + ), + enable=wait, + ) - def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False): - if not CUDACPU: check(cuda.cuCtxSetCurrent(self.device.context)) - c_kernel_input_config = encode_args_cuda_style(bufs, vals, cuda.CUdeviceptr_v2, (1,2,0))[0] if not CUDACPU else (bufs+vals) - return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait) class CUDAAllocator(LRUAllocator): - def __init__(self, device:CUDADevice): - self.device = device - super().__init__() - def _alloc(self, size): - check(cuda.cuCtxSetCurrent(self.device.context)) - return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size))) - def _free(self, opaque): check(cuda.cuMemFree_v2(opaque)) - def copyin(self, dest, src:memoryview): - check(cuda.cuCtxSetCurrent(self.device.context)) - check(cuda.cuMemcpyHtoD_v2(dest, from_mv(src), len(src), None)) - def copyout(self, dest:memoryview, src): - check(cuda.cuCtxSetCurrent(self.device.context)) - check(cuda.cuMemcpyDtoH_v2(from_mv(dest), src, len(dest))) + def __init__(self, device: CUDADevice): + self.device = device + super().__init__() + + def _alloc(self, size): + check(cuda.cuCtxSetCurrent(self.device.context)) + return init_c_var( + cuda.CUdeviceptr(), + lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size)), + ) + + def _free(self, opaque): + check(cuda.cuMemFree_v2(opaque)) + + def copyin(self, dest, src: memoryview): + check(cuda.cuCtxSetCurrent(self.device.context)) + check(cuda.cuMemcpyHtoD_v2(dest, from_mv(src), len(src), None)) + + def copyout(self, dest: memoryview, src): + check(cuda.cuCtxSetCurrent(self.device.context)) + check(cuda.cuMemcpyDtoH_v2(from_mv(dest), src, len(dest))) + class CUDADevice(Compiled): - default_arch_name = "sm_35" - def __init__(self, device:str): - device_id = int(device.split(":")[1]) if ":" in device else 0 - if not CUDACPU: - check(cuda.cuInit(0)) - check(cuda.cuDeviceGet(ctypes.byref(device := cuda.CUdevice()), device_id)) - check(cuda.cuCtxCreate_v2(ctypes.byref(context := cuda.CUcontext()), 0, device)) - self.context = context - check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), device_id)) - if device_id == 0: CUDADevice.default_arch_name = f"sm_{major.value}{minor.value}" + default_arch_name = "sm_35" - from tinygrad.features.graph.cuda import CUDAGraph - super().__init__(CUDAAllocator(self) if not CUDACPU else MallocAllocator, - LinearizerOptions(supports_float4_alu=False, global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]), - CUDARenderer, compile_cuda, functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None) - def synchronize(self): return check(cuda.cuCtxSynchronize()) if not CUDACPU else None + def __init__(self, device: str): + device_id = int(device.split(":")[1]) if ":" in device else 0 + if not CUDACPU: + check(cuda.cuInit(0)) + check(cuda.cuDeviceGet(ctypes.byref(device := cuda.CUdevice()), device_id)) + check( + cuda.cuCtxCreate_v2( + ctypes.byref(context := cuda.CUcontext()), 0, device + ) + ) + self.context = context + check( + cuda.cuDeviceComputeCapability( + ctypes.byref(major := ctypes.c_int()), + ctypes.byref(minor := ctypes.c_int()), + device_id, + ) + ) + if device_id == 0: + CUDADevice.default_arch_name = f"sm_{major.value}{minor.value}" + + from tinygrad.features.graph.cuda import CUDAGraph + + super().__init__( + CUDAAllocator(self) if not CUDACPU else MallocAllocator, + LinearizerOptions( + supports_float4_alu=False, + global_max=[65535, 65535, 2147483647], + local_max=[64, 1024, 1024], + ), + CUDARenderer, + compile_cuda, + functools.partial(CUDAProgram, self), + graph=CUDAGraph if not CUDACPU else None, + ) + + def synchronize(self): + return check(cuda.cuCtxSynchronize()) if not CUDACPU else None diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py index 8fbb4a9ee..b64b0c7ec 100644 --- a/tinygrad/runtime/ops_disk.py +++ b/tinygrad/runtime/ops_disk.py @@ -1,57 +1,100 @@ import os, mmap -try: import _posixshmem -except Exception: pass + +try: + import _posixshmem +except Exception: + pass from typing import Callable, Dict, Tuple from tinygrad.helpers import prod, DType, OSX, dtypes from tinygrad.device import Interpreted, Allocator from tinygrad.ops import Op, MovementOps, UnaryOps from tinygrad.shape.view import strides_for_shape + class UnderlyingDiskBuffer: - def __init__(self, fd, mem): self.fd, self.mem = fd, mem - def __del__(self): - if self.fd: self.fd.close() + def __init__(self, fd, mem): + self.fd, self.mem = fd, mem + + def __del__(self): + if self.fd: + self.fd.close() + class DiskBuffer: - def __init__(self, ud:UnderlyingDiskBuffer, size:int, dtype:DType=dtypes.uint8, offset=0): self.ud, self.size, self.dtype, self.offset = ud, size, dtype, offset - def __repr__(self): return f"" - def cast(self, arg:Tuple[DType, bool]): return DiskBuffer(self.ud, self.size, arg[0], offset=self.offset) - def as_strided(self, arg): - assert strides_for_shape(arg[0]) == arg[1], "disk tensors don't support strides" - return DiskBuffer(self.ud, prod(arg[0]), self.dtype, offset=self.offset+arg[2]*self.dtype.itemsize) - def _buf(self) -> memoryview: return memoryview(self.ud.mem)[self.offset:self.offset+self.size*self.dtype.itemsize] + def __init__( + self, ud: UnderlyingDiskBuffer, size: int, dtype: DType = dtypes.uint8, offset=0 + ): + self.ud, self.size, self.dtype, self.offset = ud, size, dtype, offset -disk_fxn_for_op: Dict[Op, Callable] = { UnaryOps.CAST: DiskBuffer.cast, MovementOps.AS_STRIDED: DiskBuffer.as_strided } + def __repr__(self): + return f"" + + def cast(self, arg: Tuple[DType, bool]): + return DiskBuffer(self.ud, self.size, arg[0], offset=self.offset) + + def as_strided(self, arg): + assert strides_for_shape(arg[0]) == arg[1], "disk tensors don't support strides" + return DiskBuffer( + self.ud, + prod(arg[0]), + self.dtype, + offset=self.offset + arg[2] * self.dtype.itemsize, + ) + + def _buf(self) -> memoryview: + return memoryview(self.ud.mem)[ + self.offset : self.offset + self.size * self.dtype.itemsize + ] + + +disk_fxn_for_op: Dict[Op, Callable] = { + UnaryOps.CAST: DiskBuffer.cast, + MovementOps.AS_STRIDED: DiskBuffer.as_strided, +} MAP_LOCKED, MAP_POPULATE = 0x2000, 0x008000 + + class DiskAllocator(Allocator): - def __init__(self, device): self.device = device - def _alloc(self, size): - if str(self.device).startswith("shm:"): - if OSX: - with open(f"/tmp/shm_{self.device[4:]}", "w+b") as f: - f.truncate(size) - shm = mmap.mmap(f.fileno(), size, flags=mmap.MAP_SHARED) - else: - fd = _posixshmem.shm_open(self.device[4:], os.O_RDWR, 0o600) - # TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need - shm = mmap.mmap(fd, size, flags=mmap.MAP_SHARED | MAP_LOCKED | MAP_POPULATE) - shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore # not on OSX - os.close(fd) - buf = UnderlyingDiskBuffer(None, shm) - else: - f = open(self.device, "a+b") - if os.path.getsize(self.device) < size: os.ftruncate(f.fileno(), size) - buf = UnderlyingDiskBuffer(f, mmap.mmap(f.fileno(), size)) - return DiskBuffer(buf, size) - def as_buffer(self, src:DiskBuffer): return src._buf() - def copyin(self, dest:DiskBuffer, src:memoryview): dest._buf()[:] = src - def copyout(self, dest:memoryview, src:DiskBuffer): - if src.ud.fd is not None: - src.ud.fd.seek(src.offset) - src.ud.fd.readinto(dest) - else: - dest[:] = src._buf() + def __init__(self, device): + self.device = device + + def _alloc(self, size): + if str(self.device).startswith("shm:"): + if OSX: + with open(f"/tmp/shm_{self.device[4:]}", "w+b") as f: + f.truncate(size) + shm = mmap.mmap(f.fileno(), size, flags=mmap.MAP_SHARED) + else: + fd = _posixshmem.shm_open(self.device[4:], os.O_RDWR, 0o600) + # TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need + shm = mmap.mmap( + fd, size, flags=mmap.MAP_SHARED | MAP_LOCKED | MAP_POPULATE + ) + shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore # not on OSX + os.close(fd) + buf = UnderlyingDiskBuffer(None, shm) + else: + f = open(self.device, "a+b") + if os.path.getsize(self.device) < size: + os.ftruncate(f.fileno(), size) + buf = UnderlyingDiskBuffer(f, mmap.mmap(f.fileno(), size)) + return DiskBuffer(buf, size) + + def as_buffer(self, src: DiskBuffer): + return src._buf() + + def copyin(self, dest: DiskBuffer, src: memoryview): + dest._buf()[:] = src + + def copyout(self, dest: memoryview, src: DiskBuffer): + if src.ud.fd is not None: + src.ud.fd.seek(src.offset) + src.ud.fd.readinto(dest) + else: + dest[:] = src._buf() + class DiskDevice(Interpreted): - def __init__(self, device): super().__init__(DiskAllocator(device[5:]), disk_fxn_for_op) \ No newline at end of file + def __init__(self, device): + super().__init__(DiskAllocator(device[5:]), disk_fxn_for_op) diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 39e4cbb4e..a6374d17a 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -2,99 +2,351 @@ from __future__ import annotations from typing import Tuple, Optional, List import ctypes, functools import gpuctypes.opencl as cl -from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, diskcache, OSX, ImageDType, DEBUG +from tinygrad.helpers import ( + init_c_var, + to_char_p_p, + from_mv, + diskcache, + OSX, + ImageDType, + DEBUG, +) from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import OpenCLRenderer from tinygrad.device import Compiled, LRUAllocator -OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something +OSX_TIMING_RATIO = ( + (125 / 3) if OSX else 1.0 +) # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something + def check(status): - if status != 0: raise RuntimeError(f"OpenCL Error {status}") -def checked(ret, status): return (check(status.value), ret)[1] + if status != 0: + raise RuntimeError(f"OpenCL Error {status}") + + +def checked(ret, status): + return (check(status.value), ret)[1] + @diskcache -def compile_cl(prg:str) -> bytes: - assert CLDevice.compiler_context is not None, 'OpenCL requires a "compiler_context" to compile, init a device before you call this' - program = checked(cl.clCreateProgramWithSource(CLDevice.compiler_context.context, 1, to_char_p_p([prg_bytes := prg.encode()]), ctypes.byref(ctypes.c_size_t(len(prg_bytes))), ctypes.byref(status := ctypes.c_int32())), status) - status = cl.clBuildProgram(program, 1, ctypes.byref(CLDevice.compiler_context.device_id), None, cl.clBuildProgram.argtypes[4](), None) - if status != 0: - cl.clGetProgramBuildInfo(program, CLDevice.compiler_context.device_id, cl.CL_PROGRAM_BUILD_LOG, 0, None, ctypes.byref(log_size := ctypes.c_size_t())) - cl.clGetProgramBuildInfo(program, CLDevice.compiler_context.device_id, cl.CL_PROGRAM_BUILD_LOG, log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None) - raise RuntimeError(f"OpenCL Compile Error\n\n{ctypes.string_at(mstr, size=log_size.value).decode()}") - binary_sizes = init_c_var((ctypes.c_size_t * 1)(), lambda x: check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARY_SIZES, ctypes.sizeof(x), ctypes.byref(x), None))) - binary = init_c_var(ctypes.create_string_buffer(binary_sizes[0]), lambda x: check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p), ctypes.byref((ctypes.c_void_p * 1)(ctypes.addressof(x))), None))) - check(cl.clReleaseProgram(program)) - return bytes(binary) +def compile_cl(prg: str) -> bytes: + assert ( + CLDevice.compiler_context is not None + ), 'OpenCL requires a "compiler_context" to compile, init a device before you call this' + program = checked( + cl.clCreateProgramWithSource( + CLDevice.compiler_context.context, + 1, + to_char_p_p([prg_bytes := prg.encode()]), + ctypes.byref(ctypes.c_size_t(len(prg_bytes))), + ctypes.byref(status := ctypes.c_int32()), + ), + status, + ) + status = cl.clBuildProgram( + program, + 1, + ctypes.byref(CLDevice.compiler_context.device_id), + None, + cl.clBuildProgram.argtypes[4](), + None, + ) + if status != 0: + cl.clGetProgramBuildInfo( + program, + CLDevice.compiler_context.device_id, + cl.CL_PROGRAM_BUILD_LOG, + 0, + None, + ctypes.byref(log_size := ctypes.c_size_t()), + ) + cl.clGetProgramBuildInfo( + program, + CLDevice.compiler_context.device_id, + cl.CL_PROGRAM_BUILD_LOG, + log_size.value, + mstr := ctypes.create_string_buffer(log_size.value), + None, + ) + raise RuntimeError( + f"OpenCL Compile Error\n\n{ctypes.string_at(mstr, size=log_size.value).decode()}" + ) + binary_sizes = init_c_var( + (ctypes.c_size_t * 1)(), + lambda x: check( + cl.clGetProgramInfo( + program, + cl.CL_PROGRAM_BINARY_SIZES, + ctypes.sizeof(x), + ctypes.byref(x), + None, + ) + ), + ) + binary = init_c_var( + ctypes.create_string_buffer(binary_sizes[0]), + lambda x: check( + cl.clGetProgramInfo( + program, + cl.CL_PROGRAM_BINARIES, + ctypes.sizeof(ctypes.c_void_p), + ctypes.byref((ctypes.c_void_p * 1)(ctypes.addressof(x))), + None, + ) + ), + ) + check(cl.clReleaseProgram(program)) + return bytes(binary) + class CLProgram: - def __init__(self, device:CLDevice, name:str, lib:bytes): - self.device, self.name, self.lib = device, name, lib - self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, ctypes.byref(device.device_id), (ctypes.c_size_t * 1)(len(lib)), to_char_p_p([lib], ctypes.c_ubyte), - ctypes.byref(binary_status := ctypes.c_int32()), ctypes.byref(errcode_ret := ctypes.c_int32())), errcode_ret) - check(binary_status.value) - check(cl.clBuildProgram(self.program, 1, ctypes.byref(device.device_id), None, cl.clBuildProgram.argtypes[4](), None)) # NOTE: OSX requires this - self.kernel = checked(cl.clCreateKernel(self.program, name.encode(), ctypes.byref(status := ctypes.c_int32())), status) + def __init__(self, device: CLDevice, name: str, lib: bytes): + self.device, self.name, self.lib = device, name, lib + self.program = checked( + cl.clCreateProgramWithBinary( + device.context, + 1, + ctypes.byref(device.device_id), + (ctypes.c_size_t * 1)(len(lib)), + to_char_p_p([lib], ctypes.c_ubyte), + ctypes.byref(binary_status := ctypes.c_int32()), + ctypes.byref(errcode_ret := ctypes.c_int32()), + ), + errcode_ret, + ) + check(binary_status.value) + check( + cl.clBuildProgram( + self.program, + 1, + ctypes.byref(device.device_id), + None, + cl.clBuildProgram.argtypes[4](), + None, + ) + ) # NOTE: OSX requires this + self.kernel = checked( + cl.clCreateKernel( + self.program, name.encode(), ctypes.byref(status := ctypes.c_int32()) + ), + status, + ) - def __del__(self): - check(cl.clReleaseKernel(self.kernel)) - check(cl.clReleaseProgram(self.program)) + def __del__(self): + check(cl.clReleaseKernel(self.kernel)) + check(cl.clReleaseProgram(self.program)) + + def __call__( + self, + *bufs: cl.cl_mem, + global_size: Tuple[int, ...], + local_size: Optional[Tuple[int, ...]] = None, + vals: Tuple[int, ...] = (), + wait=False, + ) -> Optional[float]: + for i, b in enumerate(bufs): + cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)) + for i, b in enumerate(vals, start=len(bufs)): + cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(b))) + if local_size is not None: + global_size = tuple(int(g * l) for g, l in zip(global_size, local_size)) + event = cl.cl_event() if wait else None + check( + cl.clEnqueueNDRangeKernel( + self.device.queue, + self.kernel, + len(global_size), + None, + (ctypes.c_size_t * len(global_size))(*global_size), + (ctypes.c_size_t * len(local_size))(*local_size) + if local_size + else None, + 0, + None, + event, + ) + ) + if wait: + check(cl.clWaitForEvents(1, ctypes.byref(event))) + start = init_c_var( + ctypes.c_ulong(), + lambda x: check( + cl.clGetEventProfilingInfo( + event, + cl.CL_PROFILING_COMMAND_START, + ctypes.sizeof(x), + ctypes.byref(x), + None, + ) + ), + ) + end = init_c_var( + ctypes.c_ulong(), + lambda x: check( + cl.clGetEventProfilingInfo( + event, + cl.CL_PROFILING_COMMAND_END, + ctypes.sizeof(x), + ctypes.byref(x), + None, + ) + ), + ) + return float(end.value - start.value) * OSX_TIMING_RATIO * 1e-9 + return None - def __call__(self, *bufs:cl.cl_mem, global_size:Tuple[int,...], local_size:Optional[Tuple[int,...]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]: - for i,b in enumerate(bufs): cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)) - for i,b in enumerate(vals,start=len(bufs)): cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(b))) - if local_size is not None: global_size = tuple(int(g*l) for g,l in zip(global_size, local_size)) - event = cl.cl_event() if wait else None - check(cl.clEnqueueNDRangeKernel(self.device.queue, self.kernel, len(global_size), None, (ctypes.c_size_t * len(global_size))(*global_size), (ctypes.c_size_t * len(local_size))(*local_size) if local_size else None, 0, None, event)) - if wait: - check(cl.clWaitForEvents(1, ctypes.byref(event))) - start = init_c_var(ctypes.c_ulong(), lambda x: check(cl.clGetEventProfilingInfo(event, cl.CL_PROFILING_COMMAND_START, ctypes.sizeof(x), ctypes.byref(x), None))) - end = init_c_var(ctypes.c_ulong(), lambda x: check(cl.clGetEventProfilingInfo(event, cl.CL_PROFILING_COMMAND_END, ctypes.sizeof(x), ctypes.byref(x), None))) - return float(end.value-start.value) * OSX_TIMING_RATIO * 1e-9 - return None class CLAllocator(LRUAllocator): - def __init__(self, device:CLDevice): - self.device = device - super().__init__() - def _alloc(self, size:int) -> cl.cl_mem: - return checked(cl.clCreateBuffer(self.device.context, cl.CL_MEM_READ_WRITE, size, None, ctypes.byref(status := ctypes.c_int32())), status) - def _free(self, buf:cl.cl_mem): check(cl.clReleaseMemObject(buf)) - def _cast_image(self, buf:cl.cl_mem, dtype:ImageDType, row_pitch:int) -> cl.cl_mem: - desc = cl.cl_image_desc(image_type=cl.CL_MEM_OBJECT_IMAGE2D, image_width=dtype.shape[1], image_height=dtype.shape[0], image_row_pitch=row_pitch) - desc._0.mem_object = buf - return checked(cl.clCreateImage(self.device.context, cl.CL_MEM_READ_WRITE, - cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[dtype.itemsize]), - desc, None, ctypes.byref(status := ctypes.c_int32())), status) - def copyin(self, dest:cl.cl_mem, src:memoryview): - check(cl.clEnqueueWriteBuffer(self.device.queue, dest, False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None)) - self.device.pending_copyin.append(src) # NOTE: these can't be freed until the GPU actually executes this command - def copyout(self, dest:memoryview, src:cl.cl_mem): - check(cl.clEnqueueReadBuffer(self.device.queue, src, False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None)) - self.device.synchronize() + def __init__(self, device: CLDevice): + self.device = device + super().__init__() + + def _alloc(self, size: int) -> cl.cl_mem: + return checked( + cl.clCreateBuffer( + self.device.context, + cl.CL_MEM_READ_WRITE, + size, + None, + ctypes.byref(status := ctypes.c_int32()), + ), + status, + ) + + def _free(self, buf: cl.cl_mem): + check(cl.clReleaseMemObject(buf)) + + def _cast_image( + self, buf: cl.cl_mem, dtype: ImageDType, row_pitch: int + ) -> cl.cl_mem: + desc = cl.cl_image_desc( + image_type=cl.CL_MEM_OBJECT_IMAGE2D, + image_width=dtype.shape[1], + image_height=dtype.shape[0], + image_row_pitch=row_pitch, + ) + desc._0.mem_object = buf + return checked( + cl.clCreateImage( + self.device.context, + cl.CL_MEM_READ_WRITE, + cl.cl_image_format( + cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[dtype.itemsize] + ), + desc, + None, + ctypes.byref(status := ctypes.c_int32()), + ), + status, + ) + + def copyin(self, dest: cl.cl_mem, src: memoryview): + check( + cl.clEnqueueWriteBuffer( + self.device.queue, + dest, + False, + 0, + len(src) * src.itemsize, + from_mv(src), + 0, + None, + None, + ) + ) + self.device.pending_copyin.append( + src + ) # NOTE: these can't be freed until the GPU actually executes this command + + def copyout(self, dest: memoryview, src: cl.cl_mem): + check( + cl.clEnqueueReadBuffer( + self.device.queue, + src, + False, + 0, + len(dest) * dest.itemsize, + from_mv(dest), + 0, + None, + None, + ) + ) + self.device.synchronize() + class CLDevice(Compiled): - device_ids = None # this is global and only initted once - compiler_context = None # this is the first created context. we make an assumption they are all the same for the compiler - def __init__(self, device:str=""): - if CLDevice.device_ids is None: - num_platforms = init_c_var(ctypes.c_uint32(), lambda x: check(cl.clGetPlatformIDs(0, None, ctypes.byref(x)))) - platform_ids = init_c_var((cl.cl_platform_id * num_platforms.value)(), lambda x: check(cl.clGetPlatformIDs(num_platforms.value, x, None))) - for device_type in [cl.CL_DEVICE_TYPE_GPU, cl.CL_DEVICE_TYPE_DEFAULT]: - num_devices = ctypes.c_uint32() - err = cl.clGetDeviceIDs(platform_ids[0], device_type, 0, None, ctypes.byref(num_devices)) - if err == 0 and num_devices.value != 0: break - if DEBUG >= 1: print(f"CLDevice: got {num_platforms.value} platforms and {num_devices.value} devices") - CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None))) + device_ids = None # this is global and only initted once + compiler_context = None # this is the first created context. we make an assumption they are all the same for the compiler - self.device_id = CLDevice.device_ids[0 if ":" not in device else int(device.split(":")[1])] - self.context = checked(cl.clCreateContext(None, 1, ctypes.byref(self.device_id), cl.clCreateContext.argtypes[3](), None, ctypes.byref(status := ctypes.c_int32())), status) - if CLDevice.compiler_context is None: CLDevice.compiler_context = self - self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, ctypes.byref(status)), status) - self.pending_copyin: List[memoryview] = [] - super().__init__(CLAllocator(self), LinearizerOptions(), OpenCLRenderer, compile_cl, functools.partial(CLProgram, self)) - def synchronize(self): - check(cl.clFinish(self.queue)) - self.pending_copyin.clear() + def __init__(self, device: str = ""): + if CLDevice.device_ids is None: + num_platforms = init_c_var( + ctypes.c_uint32(), + lambda x: check(cl.clGetPlatformIDs(0, None, ctypes.byref(x))), + ) + platform_ids = init_c_var( + (cl.cl_platform_id * num_platforms.value)(), + lambda x: check(cl.clGetPlatformIDs(num_platforms.value, x, None)), + ) + for device_type in [cl.CL_DEVICE_TYPE_GPU, cl.CL_DEVICE_TYPE_DEFAULT]: + num_devices = ctypes.c_uint32() + err = cl.clGetDeviceIDs( + platform_ids[0], device_type, 0, None, ctypes.byref(num_devices) + ) + if err == 0 and num_devices.value != 0: + break + if DEBUG >= 1: + print( + f"CLDevice: got {num_platforms.value} platforms and {num_devices.value} devices" + ) + CLDevice.device_ids = init_c_var( + (cl.cl_device_id * num_devices.value)(), + lambda x: check( + cl.clGetDeviceIDs( + platform_ids[0], device_type, num_devices, x, None + ) + ), + ) -GPUDevice = CLDevice # for legacy reasons + self.device_id = CLDevice.device_ids[ + 0 if ":" not in device else int(device.split(":")[1]) + ] + self.context = checked( + cl.clCreateContext( + None, + 1, + ctypes.byref(self.device_id), + cl.clCreateContext.argtypes[3](), + None, + ctypes.byref(status := ctypes.c_int32()), + ), + status, + ) + if CLDevice.compiler_context is None: + CLDevice.compiler_context = self + self.queue = checked( + cl.clCreateCommandQueue( + self.context, + self.device_id, + cl.CL_QUEUE_PROFILING_ENABLE, + ctypes.byref(status), + ), + status, + ) + self.pending_copyin: List[memoryview] = [] + super().__init__( + CLAllocator(self), + LinearizerOptions(), + OpenCLRenderer, + compile_cl, + functools.partial(CLProgram, self), + ) + + def synchronize(self): + check(cl.clFinish(self.queue)) + self.pending_copyin.clear() + + +GPUDevice = CLDevice # for legacy reasons diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index 87272dbb1..70edc2deb 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -1,68 +1,182 @@ import ctypes, functools, subprocess from typing import Tuple, TypeVar import gpuctypes.hip as hip -from tinygrad.helpers import DEBUG, getenv, diskcache, from_mv, init_c_var, compile_cuda_style, encode_args_cuda_style, time_execution_cuda_style +from tinygrad.helpers import ( + DEBUG, + getenv, + diskcache, + from_mv, + init_c_var, + compile_cuda_style, + encode_args_cuda_style, + time_execution_cuda_style, +) from tinygrad.device import Compiled, LRUAllocator, MallocAllocator from tinygrad.renderer.cstyle import HIPRenderer from tinygrad.codegen.kernel import LinearizerOptions # The default HIP stream is used for everything. -MOCKHIP = getenv("MOCKHIP") # for CI. don't run kernels, only check if they compile +MOCKHIP = getenv("MOCKHIP") # for CI. don't run kernels, only check if they compile + def check(status): - if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}") + if status != 0: + raise RuntimeError( + f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}" + ) + + +def hip_time_execution(cb, enable=False): + return time_execution_cuda_style( + cb, + hip.hipEvent_t, + hip.hipEventCreate, + hip.hipEventRecord, + hip.hipEventSynchronize, + hip.hipEventDestroy, + hip.hipEventElapsedTime, + enable=enable, + ) -def hip_time_execution(cb, enable=False): return time_execution_cuda_style(cb, hip.hipEvent_t, hip.hipEventCreate, hip.hipEventRecord, hip.hipEventSynchronize, hip.hipEventDestroy, hip.hipEventElapsedTime, enable=enable) @diskcache -def compile_hip(prg) -> bytes: return compile_cuda_style(prg, [f'--offload-arch={HIPDevice.default_arch_name}'], hip.hiprtcProgram, hip.hiprtcCreateProgram, hip.hiprtcCompileProgram, hip.hiprtcGetCode, hip.hiprtcGetCodeSize, hip.hiprtcGetProgramLog, hip.hiprtcGetProgramLogSize, check) +def compile_hip(prg) -> bytes: + return compile_cuda_style( + prg, + [f"--offload-arch={HIPDevice.default_arch_name}"], + hip.hiprtcProgram, + hip.hiprtcCreateProgram, + hip.hiprtcCompileProgram, + hip.hiprtcGetCode, + hip.hiprtcGetCodeSize, + hip.hiprtcGetProgramLog, + hip.hiprtcGetProgramLogSize, + check, + ) + class HIPProgram: - def __init__(self, device:int, name:str, lib:bytes): - self.device, self.name, self.lib = device, name, lib + def __init__(self, device: int, name: str, lib: bytes): + self.device, self.name, self.lib = device, name, lib - if DEBUG >= 6: - asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib) - print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x])) + if DEBUG >= 6: + asm = subprocess.check_output( + ["/opt/rocm/llvm/bin/llvm-objdump", "-d", "-"], input=lib + ) + print( + "\n".join( + [ + x + for x in asm.decode("utf-8").split("\n") + if "s_code_end" not in x + ] + ) + ) - if MOCKHIP: return - check(hip.hipSetDevice(self.device)) - self.module = init_c_var(hip.hipModule_t(), lambda x: check(hip.hipModuleLoadData(ctypes.byref(x), lib))) - self.prg = init_c_var(hip.hipFunction_t(), lambda x: check(hip.hipModuleGetFunction(ctypes.byref(x), self.module, name.encode("utf-8")))) + if MOCKHIP: + return + check(hip.hipSetDevice(self.device)) + self.module = init_c_var( + hip.hipModule_t(), + lambda x: check(hip.hipModuleLoadData(ctypes.byref(x), lib)), + ) + self.prg = init_c_var( + hip.hipFunction_t(), + lambda x: check( + hip.hipModuleGetFunction( + ctypes.byref(x), self.module, name.encode("utf-8") + ) + ), + ) - def __del__(self): - if not MOCKHIP: check(hip.hipModuleUnload(self.module)) + def __del__(self): + if not MOCKHIP: + check(hip.hipModuleUnload(self.module)) + + def __call__( + self, + *args, + global_size: Tuple[int, int, int], + local_size: Tuple[int, int, int], + vals: Tuple[int, ...] = (), + wait=False, + ): + if MOCKHIP: + return float("inf") + check(hip.hipSetDevice(self.device)) + return hip_time_execution( + lambda: check( + hip.hipModuleLaunchKernel( + self.prg, + *global_size, + *local_size, + 0, + None, + None, + encode_args_cuda_style( + args, vals, hip.hipDeviceptr_t, marks=(1, 2, 3) + )[0], + ) + ), + enable=wait, + ) - def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False): - if MOCKHIP: return float("inf") - check(hip.hipSetDevice(self.device)) - return hip_time_execution(lambda: check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, encode_args_cuda_style(args, vals, hip.hipDeviceptr_t, marks=(1,2,3))[0])), enable=wait) T = TypeVar("T") + + class HIPAllocator(LRUAllocator): - def __init__(self, device): - self.device = device - super().__init__() - def _alloc(self, size:int): - check(hip.hipSetDevice(self.device)) - return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size))) - def _free(self, opaque:T): check(hip.hipFree(opaque)) - def copyin(self, dest:T, src: memoryview): - check(hip.hipSetDevice(self.device)) - check(hip.hipMemcpyAsync(dest, from_mv(src), len(src), hip.hipMemcpyHostToDevice, None)) - def copyout(self, dest:memoryview, src:T): - check(hip.hipSetDevice(self.device)) - check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost)) - def transfer(self, dest:T, src:T, sz:int): - check(hip.hipSetDevice(self.device)) - check(hip.hipMemcpy(dest, src, sz, hip.hipMemcpyDeviceToDevice)) + def __init__(self, device): + self.device = device + super().__init__() + + def _alloc(self, size: int): + check(hip.hipSetDevice(self.device)) + return init_c_var( + hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)) + ) + + def _free(self, opaque: T): + check(hip.hipFree(opaque)) + + def copyin(self, dest: T, src: memoryview): + check(hip.hipSetDevice(self.device)) + check( + hip.hipMemcpyAsync( + dest, from_mv(src), len(src), hip.hipMemcpyHostToDevice, None + ) + ) + + def copyout(self, dest: memoryview, src: T): + check(hip.hipSetDevice(self.device)) + check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost)) + + def transfer(self, dest: T, src: T, sz: int): + check(hip.hipSetDevice(self.device)) + check(hip.hipMemcpy(dest, src, sz, hip.hipMemcpyDeviceToDevice)) + class HIPDevice(Compiled): - default_arch_name = "gfx1100" - def __init__(self, device:str=""): - self.device = int(device.split(":")[1]) if ":" in device else 0 - if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() + default_arch_name = "gfx1100" - from tinygrad.features.graph.hip import HIPGraph - super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self.device), LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, functools.partial(HIPProgram, self.device), HIPGraph) - def synchronize(self): hip.hipDeviceSynchronize() \ No newline at end of file + def __init__(self, device: str = ""): + self.device = int(device.split(":")[1]) if ":" in device else 0 + if self.device == 0 and not MOCKHIP: + HIPDevice.default_arch_name = init_c_var( + hip.hipDeviceProp_t(), + lambda x: check(hip.hipGetDeviceProperties(x, self.device)), + ).gcnArchName.decode() + + from tinygrad.features.graph.hip import HIPGraph + + super().__init__( + MallocAllocator if MOCKHIP else HIPAllocator(self.device), + LinearizerOptions(device="HIP"), + HIPRenderer, + compile_hip, + functools.partial(HIPProgram, self.device), + HIPGraph, + ) + + def synchronize(self): + hip.hipDeviceSynchronize() diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index 9f9240265..6a8ebe019 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -10,57 +10,78 @@ import llvmlite.binding as llvm LLVMOPT = bool(getenv("LLVMOPT")) + class LLVM: - target_machine: ClassVar[llvm.targets.TargetMachine] = None - engine: ClassVar[llvm.executionengine.ExecutionEngine] = None - optimizer: ClassVar[llvm.passmanagers.ModulePassManager] = None + target_machine: ClassVar[llvm.targets.TargetMachine] = None + engine: ClassVar[llvm.executionengine.ExecutionEngine] = None + optimizer: ClassVar[llvm.passmanagers.ModulePassManager] = None - def __init__(self): - if LLVM.engine is not None: return - llvm.initialize() - llvm.initialize_native_target() - llvm.initialize_native_asmprinter() - llvm.initialize_native_asmparser() - target = llvm.Target.from_triple(llvm.get_process_triple()) - LLVM.optimizer = llvm.create_module_pass_manager() - LLVM.target_machine = target.create_target_machine(opt=2) # this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA - LLVM.target_machine.add_analysis_passes(LLVM.optimizer) + def __init__(self): + if LLVM.engine is not None: + return + llvm.initialize() + llvm.initialize_native_target() + llvm.initialize_native_asmprinter() + llvm.initialize_native_asmparser() + target = llvm.Target.from_triple(llvm.get_process_triple()) + LLVM.optimizer = llvm.create_module_pass_manager() + LLVM.target_machine = target.create_target_machine( + opt=2 + ) # this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA + LLVM.target_machine.add_analysis_passes(LLVM.optimizer) - # TODO: this makes compile times so much faster - if LLVMOPT: - llvm.set_option(str(), '-force-vector-interleave=4') # this makes sum the same speed as torch, it also doubles the (slow) conv speed - if DEBUG >= 4: llvm.set_option(str(), '--debug-only=loop-vectorize') - #llvm.set_option(str(), '--debug') + # TODO: this makes compile times so much faster + if LLVMOPT: + llvm.set_option( + str(), "-force-vector-interleave=4" + ) # this makes sum the same speed as torch, it also doubles the (slow) conv speed + if DEBUG >= 4: + llvm.set_option(str(), "--debug-only=loop-vectorize") + # llvm.set_option(str(), '--debug') - # does this do anything? - builder = llvm.create_pass_manager_builder() - builder.opt_level = 3 - builder.size_level = 0 - builder.loop_vectorize = True - builder.slp_vectorize = True - builder.populate(LLVM.optimizer) + # does this do anything? + builder = llvm.create_pass_manager_builder() + builder.opt_level = 3 + builder.size_level = 0 + builder.loop_vectorize = True + builder.slp_vectorize = True + builder.populate(LLVM.optimizer) + + LLVM.target_machine.set_asm_verbosity(True) + backing_mod = llvm.parse_assembly(str()) + backing_mod.triple = llvm.get_process_triple() + LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine) - LLVM.target_machine.set_asm_verbosity(True) - backing_mod = llvm.parse_assembly(str()) - backing_mod.triple = llvm.get_process_triple() - LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine) @diskcache def compile_llvm(prg, llvmopt=LLVMOPT) -> bytes: - mod = llvm.parse_assembly(prg) - mod.verify() - LLVM().optimizer.run(mod) - if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(mod)) - return LLVM.target_machine.emit_object(mod) + mod = llvm.parse_assembly(prg) + mod.verify() + LLVM().optimizer.run(mod) + if DEBUG >= 5: + print(LLVM.target_machine.emit_assembly(mod)) + return LLVM.target_machine.emit_object(mod) + class LLVMProgram: - def __init__(self, name:str, lib:bytes): - self.name, self.lib = name, lib - LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib)) - self.fxn = LLVM.engine.get_function_address(name) + def __init__(self, name: str, lib: bytes): + self.name, self.lib = name, lib + LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib)) + self.fxn = LLVM.engine.get_function_address(name) - def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False): - self.cfunc = CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*len(bufs)), *([ctypes.c_int32]*len(vals)))(self.fxn) - return cpu_time_execution(lambda: self.cfunc(*bufs, *vals), enable=wait) + def __call__(self, *bufs, vals: Tuple[int, ...] = (), wait=False): + self.cfunc = CFUNCTYPE( + ctypes.c_int, + *([ctypes.c_void_p] * len(bufs)), + *([ctypes.c_int32] * len(vals)) + )(self.fxn) + return cpu_time_execution(lambda: self.cfunc(*bufs, *vals), enable=wait) -LLVMDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, compile_llvm, LLVMProgram) + +LLVMDevice = Compiled( + MallocAllocator, + LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), + uops_to_llvm_ir, + compile_llvm, + LLVMProgram, +) diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 1f8329643..408adfdff 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -7,82 +7,146 @@ from tinygrad.helpers import prod, getenv, DEBUG, diskcache, unwrap2 from tinygrad.device import Compiled, LRUAllocator from tinygrad.renderer.cstyle import MetalRenderer + @diskcache def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes: - assert MetalDevice.compiler_device, "metal device creation is required for metal compile" - if use_xcode: - # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode - air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8')) - return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air) - options = Metal.MTLCompileOptions.new() - library = unwrap2(MetalDevice.compiler_device.newLibraryWithSource_options_error_(prg, options, None)) - return library.libraryDataContents().bytes().tobytes() + assert ( + MetalDevice.compiler_device + ), "metal device creation is required for metal compile" + if use_xcode: + # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode + air = subprocess.check_output( + ["xcrun", "-sdk", "macosx", "metal", "-x", "metal", "-c", "-", "-o", "-"], + input=prg.encode("utf-8"), + ) + return subprocess.check_output( + ["xcrun", "-sdk", "macosx", "metallib", "-", "-o", "-"], input=air + ) + options = Metal.MTLCompileOptions.new() + library = unwrap2( + MetalDevice.compiler_device.newLibraryWithSource_options_error_( + prg, options, None + ) + ) + return library.libraryDataContents().bytes().tobytes() + class MetalProgram: - def __init__(self, device:MetalDevice, name:str, lib:bytes): - self.device, self.name, self.lib = device, name, lib - if DEBUG >= 6: - with tempfile.NamedTemporaryFile(delete=True) as shader: - shader.write(lib) - shader.flush() - os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}") - data = libdispatch.dispatch_data_create(lib, len(lib), None, None) - self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None)) - self.fxn = self.library.newFunctionWithName_(name) - self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithFunction_error_(self.fxn, None)) + def __init__(self, device: MetalDevice, name: str, lib: bytes): + self.device, self.name, self.lib = device, name, lib + if DEBUG >= 6: + with tempfile.NamedTemporaryFile(delete=True) as shader: + shader.write(lib) + shader.flush() + os.system( + f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}" + ) + data = libdispatch.dispatch_data_create(lib, len(lib), None, None) + self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None)) + self.fxn = self.library.newFunctionWithName_(name) + self.pipeline_state = unwrap2( + self.device.device.newComputePipelineStateWithFunction_error_( + self.fxn, None + ) + ) + + def __call__( + self, + *bufs, + global_size: Tuple[int, int, int], + local_size: Tuple[int, int, int], + vals: Tuple[int, ...] = (), + wait=False, + ): + assert ( + prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup() + ), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" + command_buffer = self.device.mtl_queue.commandBuffer() + encoder = command_buffer.computeCommandEncoder() + encoder.setComputePipelineState_(self.pipeline_state) + for i, a in enumerate(bufs): + encoder.setBuffer_offset_atIndex_(a, 0, i) + for i, a in enumerate(vals, start=len(bufs)): + encoder.setBytes_length_atIndex_(ctypes.c_int32(a), 4, i) + encoder.dispatchThreadgroups_threadsPerThreadgroup_( + Metal.MTLSize(*global_size), Metal.MTLSize(*local_size) + ) + encoder.endEncoding() + command_buffer.commit() + if wait: + command_buffer.waitUntilCompleted() + return command_buffer.GPUEndTime() - command_buffer.GPUStartTime() + self.device.mtl_buffers_in_flight.append(command_buffer) - def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False): - assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" - command_buffer = self.device.mtl_queue.commandBuffer() - encoder = command_buffer.computeCommandEncoder() - encoder.setComputePipelineState_(self.pipeline_state) - for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a, 0, i) - for i,a in enumerate(vals,start=len(bufs)): encoder.setBytes_length_atIndex_(ctypes.c_int32(a), 4, i) - encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) - encoder.endEncoding() - command_buffer.commit() - if wait: - command_buffer.waitUntilCompleted() - return command_buffer.GPUEndTime() - command_buffer.GPUStartTime() - self.device.mtl_buffers_in_flight.append(command_buffer) class MetalAllocator(LRUAllocator): - def __init__(self, device:MetalDevice): - self.device:MetalDevice = device - super().__init__() - def _alloc(self, size:int) -> Any: - ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared) - if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}") - return ret - def transfer(self, dest:Any, src:Any, sz:int): - command_buffer = self.device.mtl_queue.commandBuffer() - encoder = command_buffer.blitCommandEncoder() - encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src, 0, dest, 0, sz) - encoder.endEncoding() - command_buffer.commit() - self.device.mtl_buffers_in_flight.append(command_buffer) - def from_buffer(self, src:memoryview) -> Optional[Any]: - ret = self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_(src, len(src), Metal.MTLResourceStorageModeShared, None) - if ret: self.device.mv_in_metal.append(src) - return ret - def _free(self, opaque:Any): opaque.release() - def as_buffer(self, src:Any) -> memoryview: - self.device.synchronize() - return src.contents().as_buffer(src.length()) - def copyin(self, dest:Any, src:memoryview): self.as_buffer(dest)[:] = src - def copyout(self, dest:memoryview, src:Any): dest[:] = self.as_buffer(src) + def __init__(self, device: MetalDevice): + self.device: MetalDevice = device + super().__init__() + + def _alloc(self, size: int) -> Any: + ret = self.device.device.newBufferWithLength_options_( + size, Metal.MTLResourceStorageModeShared + ) + if ret is None: + raise MemoryError(f"Metal OOM while allocating {size=}") + return ret + + def transfer(self, dest: Any, src: Any, sz: int): + command_buffer = self.device.mtl_queue.commandBuffer() + encoder = command_buffer.blitCommandEncoder() + encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_( + src, 0, dest, 0, sz + ) + encoder.endEncoding() + command_buffer.commit() + self.device.mtl_buffers_in_flight.append(command_buffer) + + def from_buffer(self, src: memoryview) -> Optional[Any]: + ret = self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_( + src, len(src), Metal.MTLResourceStorageModeShared, None + ) + if ret: + self.device.mv_in_metal.append(src) + return ret + + def _free(self, opaque: Any): + opaque.release() + + def as_buffer(self, src: Any) -> memoryview: + self.device.synchronize() + return src.contents().as_buffer(src.length()) + + def copyin(self, dest: Any, src: memoryview): + self.as_buffer(dest)[:] = src + + def copyout(self, dest: memoryview, src: Any): + dest[:] = self.as_buffer(src) + class MetalDevice(Compiled): - compiler_device = None - def __init__(self, device:str): - self.device = Metal.MTLCreateSystemDefaultDevice() - if MetalDevice.compiler_device is None: MetalDevice.compiler_device = self.device - self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024) - self.mtl_buffers_in_flight: List[Any] = [] - self.mv_in_metal: List[memoryview] = [] - from tinygrad.features.graph.metal import MetalGraph - super().__init__(MetalAllocator(self), LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, functools.partial(MetalProgram, self), functools.partial(MetalGraph, self)) - def synchronize(self): - for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted() - self.mv_in_metal.clear() - self.mtl_buffers_in_flight.clear() + compiler_device = None + + def __init__(self, device: str): + self.device = Metal.MTLCreateSystemDefaultDevice() + if MetalDevice.compiler_device is None: + MetalDevice.compiler_device = self.device + self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024) + self.mtl_buffers_in_flight: List[Any] = [] + self.mv_in_metal: List[memoryview] = [] + from tinygrad.features.graph.metal import MetalGraph + + super().__init__( + MetalAllocator(self), + LinearizerOptions(device="METAL"), + MetalRenderer, + compile_metal, + functools.partial(MetalProgram, self), + functools.partial(MetalGraph, self), + ) + + def synchronize(self): + for cbuf in self.mtl_buffers_in_flight: + cbuf.waitUntilCompleted() + self.mv_in_metal.clear() + self.mtl_buffers_in_flight.clear() diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py index 4efb244ff..0305c7cfa 100644 --- a/tinygrad/runtime/ops_torch.py +++ b/tinygrad/runtime/ops_torch.py @@ -1,49 +1,119 @@ import torch import numpy as np from typing import Dict, Callable -from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, ReduceOps, Op +from tinygrad.ops import ( + BufferOps, + UnaryOps, + BinaryOps, + MovementOps, + TernaryOps, + ReduceOps, + Op, +) from tinygrad.device import Interpreted, Allocator from tinygrad.helpers import getenv, dtypes from tinygrad.runtime.ops_cpu import einsum_mulacc, shape_to_axis -device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")) -type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16} -inverse_type_map = {v:k for k,v in type_map.items()} +device = torch.device( + "cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu") +) +type_map = { + torch.float64: dtypes.float64, + torch.float16: dtypes.float16, + torch.float32: dtypes.float32, + torch.int8: dtypes.int8, + torch.int32: dtypes.int32, + torch.int64: dtypes.int64, + torch.uint8: dtypes.uint8, + torch.bool: dtypes.bool, + torch.int16: dtypes.int16, +} +inverse_type_map = {v: k for k, v in type_map.items()} + + +def output_type(x, y): + return ( + x.dtype if type_map[x.dtype].priority > type_map[y.dtype].priority else y.dtype + ) + -def output_type(x, y): return x.dtype if type_map[x.dtype].priority > type_map[y.dtype].priority else y.dtype def match_types(x, y, disallow_bool=False): - up = output_type(x, y) - if disallow_bool and up == torch.bool: up = torch.float - return x.type(up), y.type(up) + up = output_type(x, y) + if disallow_bool and up == torch.bool: + up = torch.float + return x.type(up), y.type(up) + def as_strided(x, arg): - if any(i < 0 for i in arg[1]): - return torch.as_strided(x.contiguous(), arg[0], tuple(abs(i) for i in arg[1]), - arg[2] + sum((s-1)*a if a < 0 else 0 for (s,a) in zip(arg[0], arg[1]))).flip([i for i,a in enumerate(arg[1]) if a < 0]) - return torch.as_strided(x.contiguous(), arg[0], arg[1], arg[2]) + if any(i < 0 for i in arg[1]): + return torch.as_strided( + x.contiguous(), + arg[0], + tuple(abs(i) for i in arg[1]), + arg[2] + sum((s - 1) * a if a < 0 else 0 for (s, a) in zip(arg[0], arg[1])), + ).flip([i for i, a in enumerate(arg[1]) if a < 0]) + return torch.as_strided(x.contiguous(), arg[0], arg[1], arg[2]) + torch_fxn_for_op: Dict[Op, Callable] = { - # TODO: torch.tensor should work here. it doesn't due to "overflow" in uint8 - #BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]), - BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).to(device), - UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin, - UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x), - BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x Node: - expr = [valid] if valid is not None else [] - if view.mask is not None: - acc = 1 - for ns,(x,y) in reversed(list(zip(view.shape, view.mask))): - if x != 0 or y != ns: - base = ((idx//acc) % ns) - expr += [base >= x, base < y] - acc *= ns - return Variable.ands(expr) + +def expr_node_mask(view: View, idx: Node, valid: Optional[Node] = None) -> Node: + expr = [valid] if valid is not None else [] + if view.mask is not None: + acc = 1 + for ns, (x, y) in reversed(list(zip(view.shape, view.mask))): + if x != 0 or y != ns: + base = (idx // acc) % ns + expr += [base >= x, base < y] + acc *= ns + return Variable.ands(expr) + # generate an expression if you have a single idx variable -def expr_node(view:View, idx:Optional[Node]=None) -> Node: - if idx is None: idx = Variable('idx', 0, prod(view.shape)-1) - ret: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset] if view.offset else [] - acc = 1 - for d,s,_ in reversed(_merge_dims(view.shape, view.strides)): - ret.append(((idx//acc)%d)*s) - acc *= d - return Variable.sum(ret) +def expr_node(view: View, idx: Optional[Node] = None) -> Node: + if idx is None: + idx = Variable("idx", 0, prod(view.shape) - 1) + ret: List[Node] = ( + [NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + if view.offset + else [] + ) + acc = 1 + for d, s, _ in reversed(_merge_dims(view.shape, view.strides)): + ret.append(((idx // acc) % d) * s) + acc *= d + return Variable.sum(ret) + # generate an expression if you have a variable or expression for each index -def expr_idxs(view:View, idxs:Tuple[Node, ...]) -> Node: - assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}" - return Variable.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0]) +def expr_idxs(view: View, idxs: Tuple[Node, ...]) -> Node: + assert len(idxs) == len( + view.shape + ), f"need an idx for all dimensions {idxs} vs {view.shape}" + return Variable.sum( + [NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + + [ + idx * st + for idx, sh, st in zip(idxs, view.shape, view.strides) + if sh != 1 and st != 0 + ] + ) + @functools.lru_cache(maxsize=None) -def merge_views(vm2:View, vm1:View) -> Optional[View]: - if vm2.mask or vm1.offset != 0: return None # this isn't supported yet - if None in (strides := ShapeTracker((vm2, vm1)).real_strides()): return None - return View.create(vm1.shape, cast(Tuple[sint, ...], strides), vm2.offset, vm1.mask) +def merge_views(vm2: View, vm1: View) -> Optional[View]: + if vm2.mask or vm1.offset != 0: + return None # this isn't supported yet + if None in (strides := ShapeTracker((vm2, vm1)).real_strides()): + return None + return View.create(vm1.shape, cast(Tuple[sint, ...], strides), vm2.offset, vm1.mask) + @functools.lru_cache(maxsize=None) -def idxs_to_idx(shape:Tuple[int, ...], idxs:Tuple[Node, ...]) -> Node: - assert len(idxs) == len(shape), "need an idx for all dimensions" - acc = 1 - ret = [] - for tidx,d in reversed(list(zip(idxs, shape))): - ret.append(tidx * acc) - acc *= d - return Variable.sum(ret) +def idxs_to_idx(shape: Tuple[int, ...], idxs: Tuple[Node, ...]) -> Node: + assert len(idxs) == len(shape), "need an idx for all dimensions" + acc = 1 + ret = [] + for tidx, d in reversed(list(zip(idxs, shape))): + ret.append(tidx * acc) + acc *= d + return Variable.sum(ret) + @dataclass(frozen=True) class ShapeTracker: - views: Tuple[View, ...] - def __post_init__(self): assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views" + views: Tuple[View, ...] - @staticmethod - def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),)) + def __post_init__(self): + assert isinstance(self.views, tuple) and all( + isinstance(v, View) for v in self.views + ), "ShapeTracker must be created with a tuple of Views" - @property - def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous + @staticmethod + def from_shape(shape: Tuple[sint, ...]): + return ShapeTracker((View.create(shape),)) - @property - def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape + @property + def contiguous(self) -> bool: + return len(self.views) == 1 and self.views[0].contiguous - def size(self) -> int: - if 0 in self.shape: return 0 - ret = self.expr_idxs()[0].max - while not isinstance(ret, int): ret = ret.max # TODO: this is a while loop?!? it should be more clear what max does - assert isinstance(ret, int), f"ret must be integer, {ret=} isn't" - return ret+1 + @property + def shape(self) -> Tuple[sint, ...]: + return self.views[-1].shape - def vars(self) -> Set[Variable]: return set.union(*[v.vars() for v in self.views], set()) + def size(self) -> int: + if 0 in self.shape: + return 0 + ret = self.expr_idxs()[0].max + while not isinstance(ret, int): + ret = ( + ret.max + ) # TODO: this is a while loop?!? it should be more clear what max does + assert isinstance(ret, int), f"ret must be integer, {ret=} isn't" + return ret + 1 - @property - def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()]) + def vars(self) -> Set[Variable]: + return set.union(*[v.vars() for v in self.views], set()) - def unbind(self) -> ShapeTracker: return ShapeTracker(tuple(v.unbind() for v in self.views)) + @property + def var_vals(self) -> Dict[Variable, int]: + return merge_dicts([dict([v.unbind()]) for v in self.vars()]) - def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]: - to_apply:List[Tuple[MovementOps, Tuple]] = [] - for v in self.views: - real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape - real_offset = 0 if 0 in real_shape else (v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)) - # first, we apply the offset - # then, we make it the correct shape - # then, we apply permutations - to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset))) - # then, we apply pre expand pads - if v.mask is not None: - pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides)) - post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides)) - if any(x != (0,0) for x in pre_expand_pads): - to_apply.append((MovementOps.PAD, pre_expand_pads)) - real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads)) - # then, we do any expands - # NOTE: this is a good idea even without masks, since torch doesn't support negative strides and has to make a copy - if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape)) - # lastly, we apply post expand pads - if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads)) - return to_apply + def unbind(self) -> ShapeTracker: + return ShapeTracker(tuple(v.unbind() for v in self.views)) - # NOTE: if a stride is not always valid, it will be None - def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]: - if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides - idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] - idx, valid = self.expr_idxs(idxs) - ret: List[Optional[sint]] = [None] * len(self.views[-1].shape) - for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]): - idx_maybe, stride_maybe = (this_dim.a, this_dim.b) if isinstance(this_dim, MulNode) else (this_dim, 1) - try: ret[idxs.index(idx_maybe)] = stride_maybe - except ValueError: pass - idx_vars, valid_vars = idx.vars(), valid.vars() - for i,tidx in enumerate(idxs): - if tidx in valid_vars and not ignore_valid: ret[i] = None - elif tidx not in idx_vars: ret[i] = 0 - return tuple(ret) + def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]: + to_apply: List[Tuple[MovementOps, Tuple]] = [] + for v in self.views: + real_shape = tuple(y - x for x, y in v.mask) if v.mask else v.shape + real_offset = ( + 0 + if 0 in real_shape + else ( + v.offset + + ( + sum(x * st for (x, _), st in zip(v.mask, v.strides)) + if v.mask + else 0 + ) + ) + ) + # first, we apply the offset + # then, we make it the correct shape + # then, we apply permutations + to_apply.append( + ( + MovementOps.AS_STRIDED, + ( + tuple( + [ + s if st != 0 else 1 + for s, st in zip(real_shape, v.strides) + ] + ), + v.strides, + real_offset, + ), + ) + ) + # then, we apply pre expand pads + if v.mask is not None: + pre_expand_pads = tuple( + (x, s - y) if st != 0 else (0, 0) + for (x, y), s, st in zip(v.mask, v.shape, v.strides) + ) + post_expand_pads = tuple( + (x, s - y) if st == 0 else (0, 0) + for (x, y), s, st in zip(v.mask, v.shape, v.strides) + ) + if any(x != (0, 0) for x in pre_expand_pads): + to_apply.append((MovementOps.PAD, pre_expand_pads)) + real_shape = tuple( + x + s[0] + s[1] for x, s in zip(real_shape, pre_expand_pads) + ) + # then, we do any expands + # NOTE: this is a good idea even without masks, since torch doesn't support negative strides and has to make a copy + if any(s != 1 and st == 0 for s, st in zip(real_shape, v.strides)): + to_apply.append((MovementOps.EXPAND, real_shape)) + # lastly, we apply post expand pads + if v.mask is not None and any(x != (0, 0) for x in post_expand_pads): + to_apply.append((MovementOps.PAD, post_expand_pads)) + return to_apply - def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1] + # NOTE: if a stride is not always valid, it will be None + def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]: + if len(self.views) == 1 and self.views[-1].mask is None: + return self.views[-1].strides + idxs: List[Node] = [ + Variable(f"idx{i}", 0, s - 1) for i, s in enumerate(self.shape) + ] + idx, valid = self.expr_idxs(idxs) + ret: List[Optional[sint]] = [None] * len(self.views[-1].shape) + for this_dim in idx.nodes if isinstance(idx, SumNode) else [idx]: + idx_maybe, stride_maybe = ( + (this_dim.a, this_dim.b) + if isinstance(this_dim, MulNode) + else (this_dim, 1) + ) + try: + ret[idxs.index(idx_maybe)] = stride_maybe + except ValueError: + pass + idx_vars, valid_vars = idx.vars(), valid.vars() + for i, tidx in enumerate(idxs): + if tidx in valid_vars and not ignore_valid: + ret[i] = None + elif tidx not in idx_vars: + ret[i] = 0 + return tuple(ret) - def _expr_idx(self, idx:Node, valid:Node) -> Tuple[Node, Node]: - for v in reversed(self.views[0:-1]): - if valid.max == 0: return NumNode(-1), valid - valid = expr_node_mask(v, idx, valid) - idx = expr_node(v, idx) - return idx, valid + def unit_stride_axes(self, ignore_valid=False) -> List[int]: + return [i for i, st in enumerate(self.real_strides(ignore_valid)) if st == 1] - def simplify(self) -> ShapeTracker: - if len(self.views) >= 2: - if (new_view := merge_views(self.views[-2], self.views[-1])) is not None: - if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}") - return ShapeTracker(self.views[:-2] + (new_view,)).simplify() - return self + def _expr_idx(self, idx: Node, valid: Node) -> Tuple[Node, Node]: + for v in reversed(self.views[0:-1]): + if valid.max == 0: + return NumNode(-1), valid + valid = expr_node_mask(v, idx, valid) + idx = expr_node(v, idx) + return idx, valid - def expr_idxs(self, idxs:Optional[Iterable[Node]]=None): - if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] - idx = expr_idxs(self.views[-1], tuple(idxs)) - valid = expr_node_mask(self.views[-1], idxs_to_idx(self.views[-1].shape, tuple(idxs))) - return self._expr_idx(idx, valid) + def simplify(self) -> ShapeTracker: + if len(self.views) >= 2: + if (new_view := merge_views(self.views[-2], self.views[-1])) is not None: + if DEBUG >= 4: + print( + f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}" + ) + return ShapeTracker(self.views[:-2] + (new_view,)).simplify() + return self - def expr_node(self, idx:Union[Node,str]='idx'): - if isinstance(idx, str): idx = Variable(idx, 0, prod(self.shape)-1) - return self._expr_idx(expr_node(self.views[-1], idx), expr_node_mask(self.views[-1], idx)) + def expr_idxs(self, idxs: Optional[Iterable[Node]] = None): + if idxs is None: + idxs = [Variable(f"idx{i}", 0, s - 1) for i, s in enumerate(self.shape)] + idx = expr_idxs(self.views[-1], tuple(idxs)) + valid = expr_node_mask( + self.views[-1], idxs_to_idx(self.views[-1].shape, tuple(idxs)) + ) + return self._expr_idx(idx, valid) - def axis_is_masked(self, axis:int) -> bool: - _, valid = self.expr_idxs() - return f'idx{axis}' in [v.expr for v in valid.vars()] + def expr_node(self, idx: Union[Node, str] = "idx"): + if isinstance(idx, str): + idx = Variable(idx, 0, prod(self.shape) - 1) + return self._expr_idx( + expr_node(self.views[-1], idx), expr_node_mask(self.views[-1], idx) + ) - # *** under this line are the movement ops *** + def axis_is_masked(self, axis: int) -> bool: + _, valid = self.expr_idxs() + return f"idx{axis}" in [v.expr for v in valid.vars()] - def pad(self, arg: Tuple[Tuple[int, int], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), )) - def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), )) - def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), )) - def permute(self, axis: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), )) - def stride(self, mul: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), )) + # *** under this line are the movement ops *** + + def pad(self, arg: Tuple[Tuple[int, int], ...]) -> ShapeTracker: + return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg),)) + + def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: + return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg),)) + + def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker: + return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape),)) + + def permute(self, axis: Tuple[int, ...]) -> ShapeTracker: + return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis),)) + + def stride(self, mul: Tuple[int, ...]) -> ShapeTracker: + return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul),)) + + def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker: + if (new_view := self.views[-1].reshape(new_shape)) is not None: + return ShapeTracker(self.views[0:-1] + (new_view,)) + return ShapeTracker(self.views + (View.create(new_shape),)) - def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker: - if (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,)) - return ShapeTracker(self.views + (View.create(new_shape), )) # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape # TODO: if we remove movementops from lazy.py we can delete this -def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]: - acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul)) - try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new] - except ValueError: return None - return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])] +def get_contraction( + old_shape: Tuple[sint, ...], new_shape: Tuple[sint, ...] +) -> Optional[List[List[int]]]: + acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list( + itertools.accumulate(new_shape, operator.mul) + ) + try: + split = [acc_old.index(acc) + 1 if acc != 1 else 0 for acc in acc_new] + except ValueError: + return None + return [ + list(range(st, ed)) + for st, ed in zip([0] + split[:-1], split[:-1] + [len(old_shape)]) + ] diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 5ffeafa90..0d96da86a 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -3,342 +3,629 @@ import functools from math import gcd from itertools import product from tinygrad.helpers import partition -from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator, Set +from typing import ( + List, + Dict, + Callable, + Tuple, + Type, + Union, + Optional, + Any, + Iterator, + Set, +) # NOTE: Python has different behavior for negative mod and floor div than c # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod -def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node)) + +def is_sym_int(x: Any) -> bool: + return isinstance(x, (int, Node)) + class Node: - b: Union[Node, int] - min: int - max: int - def render(self, ops=None, ctx=None) -> Any: - if ops is None: ops = render_python - assert self.__class__ in (Variable, NumNode) or self.min != self.max - return ops[type(self)](self, ops, ctx) - def vars(self) -> Set[Variable]: return set() + b: Union[Node, int] + min: int + max: int - def expand_idx(self) -> VariableOrNum: return next((v for v in self.vars() if v.expr is None), NumNode(0)) - # expand a Node into List[Node] that enumerates the underlying Variables from min to max - # expand increments earlier variables faster than later variables (as specified in the argument) - @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none - def expand(self, idxs:Optional[Tuple[VariableOrNum, ...]]=None) -> List[Node]: - if idxs is None: idxs = (self.expand_idx(),) - return [self.substitute(dict(zip(idxs, (NumNode(x) for x in rep)))) for rep in Node.iter_idxs(idxs)] - @staticmethod - def iter_idxs(idxs:Tuple[VariableOrNum, ...]) -> Iterator[Tuple[int,...]]: - yield from (x[::-1] for x in product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]])) - # substitute Variables with the values in var_vals - def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__) - def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None + def render(self, ops=None, ctx=None) -> Any: + if ops is None: + ops = render_python + assert self.__class__ in (Variable, NumNode) or self.min != self.max + return ops[type(self)](self, ops, ctx) - @functools.cached_property - def key(self) -> str: return self.render(ctx="DEBUG") - @functools.cached_property - def hash(self) -> int: return hash(self.key) - def __repr__(self): return self.render(ctx="REPR") - def __str__(self): return "<"+self.key+">" - def __hash__(self): return self.hash - def __bool__(self): return not (self.max == self.min == 0) - def __eq__(self, other:object) -> bool: - if not isinstance(other, Node): return NotImplemented - return self.key == other.key - def __neg__(self): return self*-1 - def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else NumNode(b)]) - def __radd__(self, b:int): return self+b - def __sub__(self, b:Union[Node,int]): return self+-b - def __rsub__(self, b:int): return -self+b - def __le__(self, b:Union[Node,int]): return self < (b+1) - def __gt__(self, b:Union[Node,int]): return (-self) < (-b) - def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1) - def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b)) - def __mul__(self, b:Union[Node, int]): - if b == 0: return NumNode(0) - if b == 1: return self - if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b - return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b)) - def __rmul__(self, b:int): return self*b + def vars(self) -> Set[Variable]: + return set() - # *** complex ops *** + def expand_idx(self) -> VariableOrNum: + return next((v for v in self.vars() if v.expr is None), NumNode(0)) - def __rfloordiv__(self, b:int): - if self.min > b >= 0: return NumNode(0) - if isinstance(self, NumNode): return NumNode(b // self.b) - raise RuntimeError(f"not supported: {b} // {self}") - def __floordiv__(self, b:Union[Node,int], factoring_allowed=True): - if isinstance(b, Node): - if b.__class__ is NumNode: return self // b.b - if self == b: return NumNode(1) - if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node - raise RuntimeError(f"not supported: {self} // {b}") - assert b != 0 - if b < 0: return (self//-b)*-1 - if b == 1: return self + # expand a Node into List[Node] that enumerates the underlying Variables from min to max + # expand increments earlier variables faster than later variables (as specified in the argument) + @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none + def expand(self, idxs: Optional[Tuple[VariableOrNum, ...]] = None) -> List[Node]: + if idxs is None: + idxs = (self.expand_idx(),) + return [ + self.substitute(dict(zip(idxs, (NumNode(x) for x in rep)))) + for rep in Node.iter_idxs(idxs) + ] - # the numerator of div is not allowed to be negative - if self.min < 0: - offset = self.min//b - # factor out an "offset" to make the numerator positive. don't allowing factoring again - return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset - return create_node(DivNode(self, b)) + @staticmethod + def iter_idxs(idxs: Tuple[VariableOrNum, ...]) -> Iterator[Tuple[int, ...]]: + yield from ( + x[::-1] + for x in product( + *[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]] + ) + ) - def __rmod__(self, b:int): - if self.min > b >= 0: return NumNode(b) - if isinstance(self, NumNode): return NumNode(b % self.b) - raise RuntimeError(f"not supported: {b} % {self}") - def __mod__(self, b:Union[Node,int]): - if isinstance(b, Node): - if b.__class__ is NumNode: return self % b.b - if self == b: return NumNode(0) - if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node - raise RuntimeError(f"not supported: {self} % {b}") - assert b > 0 - if b == 1: return NumNode(0) - if isinstance(self.max, int) and isinstance(self.min, int): - if self.min >= 0 and self.max < b: return self - if (self.min//b) == (self.max//b): return self - (b*(self.min//b)) - if self.min < 0: return (self - ((self.min//b)*b)) % b - return create_node(ModNode(self, b)) + # substitute Variables with the values in var_vals + def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: + raise RuntimeError(self.__class__.__name__) - @staticmethod - def sum(nodes:List[Node]) -> Node: - nodes = [x for x in nodes if x.max or x.min] - if not nodes: return NumNode(0) - if len(nodes) == 1: return nodes[0] + def unbind(self) -> Tuple[Node, Optional[int]]: + return ( + self.substitute( + {v: v.unbind()[0] for v in self.vars() if v.val is not None} + ), + None, + ) - mul_groups: Dict[Node, int] = {} - num_node_sum = 0 - for node in SumNode(nodes).flat_components: - if node.__class__ is NumNode: num_node_sum += node.b - elif node.__class__ is MulNode: mul_groups[node.a] = mul_groups.get(node.a, 0) + node.b - else: mul_groups[node] = mul_groups.get(node, 0) + 1 - new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0] - if num_node_sum: new_nodes.append(NumNode(num_node_sum)) - return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0) + @functools.cached_property + def key(self) -> str: + return self.render(ctx="DEBUG") - @staticmethod - def ands(nodes:List[Node]) -> Node: - if not nodes: return NumNode(1) - if len(nodes) == 1: return nodes[0] - if any(not x for x in nodes): return NumNode(0) + @functools.cached_property + def hash(self) -> int: + return hash(self.key) + + def __repr__(self): + return self.render(ctx="REPR") + + def __str__(self): + return "<" + self.key + ">" + + def __hash__(self): + return self.hash + + def __bool__(self): + return not (self.max == self.min == 0) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Node): + return NotImplemented + return self.key == other.key + + def __neg__(self): + return self * -1 + + def __add__(self, b: Union[Node, int]): + return Variable.sum([self, b if isinstance(b, Node) else NumNode(b)]) + + def __radd__(self, b: int): + return self + b + + def __sub__(self, b: Union[Node, int]): + return self + -b + + def __rsub__(self, b: int): + return -self + b + + def __le__(self, b: Union[Node, int]): + return self < (b + 1) + + def __gt__(self, b: Union[Node, int]): + return (-self) < (-b) + + def __ge__(self, b: Union[Node, int]): + return (-self) < (-b + 1) + + def __lt__(self, b: Union[Node, int]): + return create_node(LtNode(self, b)) + + def __mul__(self, b: Union[Node, int]): + if b == 0: + return NumNode(0) + if b == 1: + return self + if self.__class__ is NumNode: + return NumNode(self.b * b) if isinstance(b, int) else b * self.b + return ( + create_node(MulNode(self, b.b)) + if isinstance(b, NumNode) + else create_node(MulNode(self, b)) + ) + + def __rmul__(self, b: int): + return self * b + + # *** complex ops *** + + def __rfloordiv__(self, b: int): + if self.min > b >= 0: + return NumNode(0) + if isinstance(self, NumNode): + return NumNode(b // self.b) + raise RuntimeError(f"not supported: {b} // {self}") + + def __floordiv__(self, b: Union[Node, int], factoring_allowed=True): + if isinstance(b, Node): + if b.__class__ is NumNode: + return self // b.b + if self == b: + return NumNode(1) + if (b - self).min > 0 and self.min >= 0: + return NumNode(0) # b - self simplifies the node + raise RuntimeError(f"not supported: {self} // {b}") + assert b != 0 + if b < 0: + return (self // -b) * -1 + if b == 1: + return self + + # the numerator of div is not allowed to be negative + if self.min < 0: + offset = self.min // b + # factor out an "offset" to make the numerator positive. don't allowing factoring again + return (self + -offset * b).__floordiv__( + b, factoring_allowed=False + ) + offset + return create_node(DivNode(self, b)) + + def __rmod__(self, b: int): + if self.min > b >= 0: + return NumNode(b) + if isinstance(self, NumNode): + return NumNode(b % self.b) + raise RuntimeError(f"not supported: {b} % {self}") + + def __mod__(self, b: Union[Node, int]): + if isinstance(b, Node): + if b.__class__ is NumNode: + return self % b.b + if self == b: + return NumNode(0) + if (b - self).min > 0 and self.min >= 0: + return self # b - self simplifies the node + raise RuntimeError(f"not supported: {self} % {b}") + assert b > 0 + if b == 1: + return NumNode(0) + if isinstance(self.max, int) and isinstance(self.min, int): + if self.min >= 0 and self.max < b: + return self + if (self.min // b) == (self.max // b): + return self - (b * (self.min // b)) + if self.min < 0: + return (self - ((self.min // b) * b)) % b + return create_node(ModNode(self, b)) + + @staticmethod + def sum(nodes: List[Node]) -> Node: + nodes = [x for x in nodes if x.max or x.min] + if not nodes: + return NumNode(0) + if len(nodes) == 1: + return nodes[0] + + mul_groups: Dict[Node, int] = {} + num_node_sum = 0 + for node in SumNode(nodes).flat_components: + if node.__class__ is NumNode: + num_node_sum += node.b + elif node.__class__ is MulNode: + mul_groups[node.a] = mul_groups.get(node.a, 0) + node.b + else: + mul_groups[node] = mul_groups.get(node, 0) + 1 + new_nodes = [ + MulNode(a, b_sum) if b_sum != 1 else a + for a, b_sum in mul_groups.items() + if b_sum != 0 + ] + if num_node_sum: + new_nodes.append(NumNode(num_node_sum)) + return ( + create_rednode(SumNode, new_nodes) + if len(new_nodes) > 1 + else new_nodes[0] + if len(new_nodes) == 1 + else NumNode(0) + ) + + @staticmethod + def ands(nodes: List[Node]) -> Node: + if not nodes: + return NumNode(1) + if len(nodes) == 1: + return nodes[0] + if any(not x for x in nodes): + return NumNode(0) + + # filter 1s + nodes = [x for x in nodes if x.min != x.max] + return ( + create_rednode(AndNode, nodes) + if len(nodes) > 1 + else (nodes[0] if len(nodes) == 1 else NumNode(1)) + ) - # filter 1s - nodes = [x for x in nodes if x.min != x.max] - return create_rednode(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1)) # 4 basic node types -class Variable(Node): - def __new__(cls, expr:Optional[str], nmin:int, nmax:int): - assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}" - if nmin == nmax: return NumNode(nmin) - return super().__new__(cls) - def __init__(self, expr:Optional[str], nmin:int, nmax:int): - self.expr, self.min, self.max = expr, nmin, nmax - self._val: Optional[int] = None - @property - def val(self): - assert self._val is not None, f"Variable isn't bound, can't access val of {self}" - return self._val - def bind(self, val): - assert self._val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}" - self._val = val - return self - def unbind(self) -> Tuple[Variable, int]: - assert self.val is not None, f"cannot unbind {self}" - return Variable(self.expr, self.min, self.max), self.val - def vars(self): return {self} - def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return var_vals[self] if self in var_vals else self +class Variable(Node): + def __new__(cls, expr: Optional[str], nmin: int, nmax: int): + assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}" + if nmin == nmax: + return NumNode(nmin) + return super().__new__(cls) + + def __init__(self, expr: Optional[str], nmin: int, nmax: int): + self.expr, self.min, self.max = expr, nmin, nmax + self._val: Optional[int] = None + + @property + def val(self): + assert ( + self._val is not None + ), f"Variable isn't bound, can't access val of {self}" + return self._val + + def bind(self, val): + assert ( + self._val is None and self.min <= val <= self.max + ), f"cannot bind {val} to {self}" + self._val = val + return self + + def unbind(self) -> Tuple[Variable, int]: + assert self.val is not None, f"cannot unbind {self}" + return Variable(self.expr, self.min, self.max), self.val + + def vars(self): + return {self} + + def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: + return var_vals[self] if self in var_vals else self + class NumNode(Node): - def __init__(self, num:int): - assert isinstance(num, int), f"{num} is not an int" - self.b:int = num - self.min, self.max = num, num - def bind(self, val): - assert self.b == val, f"cannot bind {val} to {self}" - return self - def __eq__(self, other): return self.b == other - def __hash__(self): return self.hash # needed with __eq__ override - def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self + def __init__(self, num: int): + assert isinstance(num, int), f"{num} is not an int" + self.b: int = num + self.min, self.max = num, num + + def bind(self, val): + assert self.b == val, f"cannot bind {val} to {self}" + return self + + def __eq__(self, other): + return self.b == other + + def __hash__(self): + return self.hash # needed with __eq__ override + + def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: + return self + + +def create_node(ret: Node): + assert ( + ret.min <= ret.max + ), f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}" + if ret.min == ret.max: + return NumNode(ret.min) + return ret -def create_node(ret:Node): - assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}" - if ret.min == ret.max: return NumNode(ret.min) - return ret class OpNode(Node): - def __init__(self, a:Node, b:Union[Node, int]): - self.a, self.b = a, b - self.min, self.max = self.get_bounds() - def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set()) - def get_bounds(self) -> Tuple[int, int]: raise NotImplementedError("must be implemented") + def __init__(self, a: Node, b: Union[Node, int]): + self.a, self.b = a, b + self.min, self.max = self.get_bounds() + + def vars(self): + return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set()) + + def get_bounds(self) -> Tuple[int, int]: + raise NotImplementedError("must be implemented") + class LtNode(OpNode): - def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b) - def get_bounds(self) -> Tuple[int, int]: - if isinstance(self.b, int): - return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1) - return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1) - def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)) + def __floordiv__(self, b: Union[Node, int], _=False): + return (self.a // b) < (self.b // b) + + def get_bounds(self) -> Tuple[int, int]: + if isinstance(self.b, int): + return ( + (1, 1) + if self.a.max < self.b + else (0, 0) + if self.a.min >= self.b + else (0, 1) + ) + return ( + (1, 1) + if self.a.max < self.b.min + else (0, 0) + if self.a.min >= self.b.max + else (0, 1) + ) + + def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: + return self.a.substitute(var_vals) < ( + self.b if isinstance(self.b, int) else self.b.substitute(var_vals) + ) + class MulNode(OpNode): - def __lt__(self, b: Union[Node, int]): - if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1: return Node.__lt__(self, b) - sgn = 1 if self.b > 0 else -1 - return Node.__lt__(self.a*sgn, (b + abs(self.b) - 1)//abs(self.b)) - def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul - def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right - if self.b % b == 0: return self.a*(self.b//b) - if b % self.b == 0 and self.b > 0: return self.a//(b//self.b) - return Node.__floordiv__(self, b, factoring_allowed) - def __mod__(self, b: Union[Node, int]): - a = (self.a * (self.b%b)) - return Node.__mod__(a, b) - def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b) - def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)) + def __lt__(self, b: Union[Node, int]): + if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1: + return Node.__lt__(self, b) + sgn = 1 if self.b > 0 else -1 + return Node.__lt__(self.a * sgn, (b + abs(self.b) - 1) // abs(self.b)) + + def __mul__(self, b: Union[Node, int]): + return self.a * (self.b * b) # two muls in one mul + + def __floordiv__( + self, b: Union[Node, int], factoring_allowed=False + ): # NOTE: mod negative isn't handled right + if self.b % b == 0: + return self.a * (self.b // b) + if b % self.b == 0 and self.b > 0: + return self.a // (b // self.b) + return Node.__floordiv__(self, b, factoring_allowed) + + def __mod__(self, b: Union[Node, int]): + a = self.a * (self.b % b) + return Node.__mod__(a, b) + + def get_bounds(self) -> Tuple[int, int]: + return ( + (self.a.min * self.b, self.a.max * self.b) + if self.b >= 0 + else (self.a.max * self.b, self.a.min * self.b) + ) + + def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: + return self.a.substitute(var_vals) * ( + self.b if isinstance(self.b, int) else self.b.substitute(var_vals) + ) + class DivNode(OpNode): - def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div - def get_bounds(self) -> Tuple[int, int]: - assert self.a.min >= 0 and isinstance(self.b, int) - return self.a.min//self.b, self.a.max//self.b - def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) // self.b + def __floordiv__(self, b: Union[Node, int], _=False): + return self.a // (self.b * b) # two divs is one div + + def get_bounds(self) -> Tuple[int, int]: + assert self.a.min >= 0 and isinstance(self.b, int) + return self.a.min // self.b, self.a.max // self.b + + def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: + return self.a.substitute(var_vals) // self.b + class ModNode(OpNode): - def __mod__(self, b: Union[Node, int]): - if isinstance(b, Node) or isinstance(self.b, Node): return Node.__mod__(self, b) - return self.a % b if gcd(self.b, b) == b else Node.__mod__(self, b) - def __floordiv__(self, b: Union[Node, int], factoring_allowed=True): - if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod - return Node.__floordiv__(self, b, factoring_allowed) - def get_bounds(self) -> Tuple[int, int]: - assert self.a.min >= 0 and isinstance(self.b, int) - return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b) - def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) % self.b + def __mod__(self, b: Union[Node, int]): + if isinstance(b, Node) or isinstance(self.b, Node): + return Node.__mod__(self, b) + return self.a % b if gcd(self.b, b) == b else Node.__mod__(self, b) + + def __floordiv__(self, b: Union[Node, int], factoring_allowed=True): + if self.b % b == 0: + return (self.a // b) % (self.b // b) # put the div inside mod + return Node.__floordiv__(self, b, factoring_allowed) + + def get_bounds(self) -> Tuple[int, int]: + assert self.a.min >= 0 and isinstance(self.b, int) + return ( + (0, self.b - 1) + if self.a.max - self.a.min >= self.b + or (self.a.min != self.a.max and self.a.min % self.b >= self.a.max % self.b) + else (self.a.min % self.b, self.a.max % self.b) + ) + + def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: + return self.a.substitute(var_vals) % self.b + class RedNode(Node): - def __init__(self, nodes:List[Node]): self.nodes = nodes - def vars(self) -> Set[Variable]: return set.union(*[x.vars() for x in self.nodes], set()) + def __init__(self, nodes: List[Node]): + self.nodes = nodes + + def vars(self) -> Set[Variable]: + return set.union(*[x.vars() for x in self.nodes], set()) + class SumNode(RedNode): - @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none - def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum - @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none - def __floordiv__(self, b: Union[Node, int], factoring_allowed=True): - fully_divided: List[Node] = [] - rest: List[Node] = [] - if isinstance(b, SumNode): - nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode) - de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode) - if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b - if isinstance(b, Node): - for x in self.flat_components: - if x % b == 0: fully_divided.append(x // b) - else: rest.append(x) - if (sum_fully_divided:=create_rednode(SumNode, fully_divided)) != 0: return sum_fully_divided + create_rednode(SumNode, rest) // b - return Node.__floordiv__(self, b, False) - if b == 1: return self - if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed) - fully_divided, rest = [], [] - _gcd = b - divisor = 1 - for x in self.flat_components: - if x.__class__ in (NumNode, MulNode): - if x.b%b == 0: fully_divided.append(x//b) - else: - rest.append(x) - if isinstance(x.b, int): - _gcd = gcd(_gcd, x.b) - if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b - else: - _gcd = 1 - else: - rest.append(x) - _gcd = 1 - if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (b//_gcd) - if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor) - return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b) + @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none + def __mul__(self, b: Union[Node, int]): + return Node.sum([x * b for x in self.nodes]) # distribute mul into sum - @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none - def __mod__(self, b: Union[Node, int]): - if isinstance(b, SumNode): - nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode) - de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode) - if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b - if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node - new_nodes: List[Node] = [] - for x in self.nodes: - if x.__class__ is NumNode: new_nodes.append(NumNode(x.b%b)) - elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b)) - else: new_nodes.append(x) - return Node.__mod__(Node.sum(new_nodes), b) + @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none + def __floordiv__(self, b: Union[Node, int], factoring_allowed=True): + fully_divided: List[Node] = [] + rest: List[Node] = [] + if isinstance(b, SumNode): + nu_num = sum( + node.b for node in self.flat_components if node.__class__ is NumNode + ) + de_num = sum( + node.b for node in b.flat_components if node.__class__ is NumNode + ) + if nu_num > 0 and de_num and (d := nu_num // de_num) > 0: + return NumNode(d) + (self - b * d) // b + if isinstance(b, Node): + for x in self.flat_components: + if x % b == 0: + fully_divided.append(x // b) + else: + rest.append(x) + if (sum_fully_divided := create_rednode(SumNode, fully_divided)) != 0: + return sum_fully_divided + create_rednode(SumNode, rest) // b + return Node.__floordiv__(self, b, False) + if b == 1: + return self + if not factoring_allowed: + return Node.__floordiv__(self, b, factoring_allowed) + fully_divided, rest = [], [] + _gcd = b + divisor = 1 + for x in self.flat_components: + if x.__class__ in (NumNode, MulNode): + if x.b % b == 0: + fully_divided.append(x // b) + else: + rest.append(x) + if isinstance(x.b, int): + _gcd = gcd(_gcd, x.b) + if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: + divisor = x.b + else: + _gcd = 1 + else: + rest.append(x) + _gcd = 1 + if _gcd > 1: + return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // ( + b // _gcd + ) + if divisor > 1: + return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // ( + b // divisor + ) + return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b) - def __lt__(self, b:Union[Node,int]): - lhs: Node = self - if isinstance(b, int): - new_sum = [] - for x in self.nodes: - # TODO: should we just force the last one to always be the number - if isinstance(x, NumNode): b -= x.b - else: new_sum.append(x) - lhs = Node.sum(new_sum) - nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs] - assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported" - muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b) - if muls: - # NOTE: gcd in python 3.8 takes exactly 2 args - mul_gcd = b - for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above - all_others = Variable.sum(others) - if all_others.min >= 0 and all_others.max < mul_gcd: - lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd - return Node.__lt__(lhs, b) + @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none + def __mod__(self, b: Union[Node, int]): + if isinstance(b, SumNode): + nu_num = sum( + node.b for node in self.flat_components if node.__class__ is NumNode + ) + de_num = sum( + node.b for node in b.flat_components if node.__class__ is NumNode + ) + if nu_num > 0 and de_num and (d := nu_num // de_num) > 0: + return (self - b * d) % b + if isinstance(b, Node) and (b - self).min > 0: + return self # b - self simplifies the node + new_nodes: List[Node] = [] + for x in self.nodes: + if x.__class__ is NumNode: + new_nodes.append(NumNode(x.b % b)) + elif isinstance(x, MulNode): + new_nodes.append(x.a * (x.b % b)) + else: + new_nodes.append(x) + return Node.__mod__(Node.sum(new_nodes), b) - def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Variable.sum([node.substitute(var_vals) for node in self.nodes]) + def __lt__(self, b: Union[Node, int]): + lhs: Node = self + if isinstance(b, int): + new_sum = [] + for x in self.nodes: + # TODO: should we just force the last one to always be the number + if isinstance(x, NumNode): + b -= x.b + else: + new_sum.append(x) + lhs = Node.sum(new_sum) + nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs] + assert all( + not isinstance(node, MulNode) or isinstance(node.b, int) + for node in nodes + ), "not supported" + muls, others = partition( + nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b + ) + if muls: + # NOTE: gcd in python 3.8 takes exactly 2 args + mul_gcd = b + for x in muls: + mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above + all_others = Variable.sum(others) + if all_others.min >= 0 and all_others.max < mul_gcd: + lhs, b = ( + Variable.sum([mul // mul_gcd for mul in muls]), + b // mul_gcd, + ) + return Node.__lt__(lhs, b) + + def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: + return Variable.sum([node.substitute(var_vals) for node in self.nodes]) + + @property + def flat_components(self): # recursively expand sumnode components + new_nodes = [] + for x in self.nodes: + new_nodes += x.flat_components if isinstance(x, SumNode) else [x] + return new_nodes - @property - def flat_components(self): # recursively expand sumnode components - new_nodes = [] - for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x]) - return new_nodes class AndNode(RedNode): - def __floordiv__(self, b: Union[Node, int], _=True): return Variable.ands([x//b for x in self.nodes]) - def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: - subed = [] - for node in self.nodes: - if not (sub:=node.substitute(var_vals)): return NumNode(0) - subed.append(sub) - return Variable.ands(subed) + def __floordiv__(self, b: Union[Node, int], _=True): + return Variable.ands([x // b for x in self.nodes]) + + def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: + subed = [] + for node in self.nodes: + if not (sub := node.substitute(var_vals)): + return NumNode(0) + subed.append(sub) + return Variable.ands(subed) + + +def create_rednode(typ: Type[RedNode], nodes: List[Node]): + ret = typ(nodes) + if typ == SumNode: + ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes])) + elif typ == AndNode: + ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes])) + return create_node(ret) + + +def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: + return str(a) if isinstance(a, int) else a.render(ops, ctx) -def create_rednode(typ:Type[RedNode], nodes:List[Node]): - ret = typ(nodes) - if typ == SumNode: ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes])) - elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes])) - return create_node(ret) -def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx) def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int: - if isinstance(a, (int, float)): return a - ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) - assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}" - return ret.b + if isinstance(a, (int, float)): + return a + ret = a.substitute({k: NumNode(v) for k, v in var_vals.items()}) + assert isinstance( + ret, NumNode + ), f"sym_infer didn't produce NumNode from {a} with {var_vals}" + return ret.b + # symbolic int sint = Union[Node, int] VariableOrNum = Union[Variable, NumNode] render_python: Dict[Type, Callable] = { - Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" else f"{self.expr}"), - NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}", - MulNode: lambda self,ops,ctx: f"({sym_render(self.b,ops,ctx)}*{self.a.render(ops,ctx)})" if isinstance(self.a,Variable) and isinstance(self.b,Variable) and self.a.expr and self.b.expr and self.b.expr < self.a.expr else f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})", - DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})", - ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})", - LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})", - SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})", - AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})" -} \ No newline at end of file + Variable: lambda self, ops, ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" + if ctx == "DEBUG" + else ( + f"Variable('{self.expr}', {self.min}, {self.max})" + + (f".bind({self.val})" if self._val is not None else "") + if ctx == "REPR" + else f"{self.expr}" + ), + NumNode: lambda self, ops, ctx: f"NumNode({self.b})" + if ctx == "REPR" + else f"{self.b}", + MulNode: lambda self, ops, ctx: f"({sym_render(self.b,ops,ctx)}*{self.a.render(ops,ctx)})" + if isinstance(self.a, Variable) + and isinstance(self.b, Variable) + and self.a.expr + and self.b.expr + and self.b.expr < self.a.expr + else f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})", + DivNode: lambda self, ops, ctx: f"({self.a.render(ops,ctx)}//{self.b})", + ModNode: lambda self, ops, ctx: f"({self.a.render(ops,ctx)}%{self.b})", + LtNode: lambda self, ops, ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})", + SumNode: lambda self, ops, ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})", + AndNode: lambda self, ops, ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})", +} diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index eb985854c..eef08e600 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -5,180 +5,393 @@ from typing import Tuple, List, Optional, Dict, cast from tinygrad.helpers import prod, all_int from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, Set, sint -@functools.lru_cache(maxsize=None) -def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]: - return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape)) @functools.lru_cache(maxsize=None) -def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]: - strides = [1] if shape else [] - for d in reversed(shape[1:]): strides.append(d*strides[-1]) - return filter_strides(shape, tuple(reversed(strides))) +def filter_strides(shape: Tuple[int, ...], strides: Tuple[int, ...]) -> Tuple[int, ...]: + return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape)) + @functools.lru_cache(maxsize=None) -def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]] = None) -> Tuple[Tuple[int, int, int], ...]: - # merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...] - if not shape: return tuple() - assert len(shape) == len(strides) # state (0, 1, 2) -> (none, in-progress, done). wrt merging zero strided dimensions. - ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)] - state = 1 if mask and strides[0] == 0 and shape[0] != 1 and mask[0][1] - mask[0][0] == 1 else 0 - for i, (sh, st) in enumerate(zip(shape[1:], strides[1:]), start=1): - if sh == 1: continue - if state == 1 or ret[-1][1] == sh * st: # mergeable - ret[-1] = (ret[-1][0] * sh, st, (sh if state == 1 else ret[-1][2] * sh) if st else 0) - else: ret.append((sh, st, sh if st else 0)) # begin new - # merging ends with either non-zero strided dim or zero strided dim with mask range > 1 - state = 1 if mask and st == 0 and mask[i][1] - mask[i][0] == 1 else (2 if state != 0 else 0) - return tuple(ret) +def strides_for_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]: + strides = [1] if shape else [] + for d in reversed(shape[1:]): + strides.append(d * strides[-1]) + return filter_strides(shape, tuple(reversed(strides))) + @functools.lru_cache(maxsize=None) -def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tuple[Tuple[sint, sint], ...]], Optional[Tuple[sint, ...]], bool]: - if view.mask is None: return view.mask, tuple(), False - new_mask: List[Tuple[int, int]] = [] +def _merge_dims( + shape: Tuple[int, ...], + strides: Tuple[int, ...], + mask: Optional[Tuple[Tuple[int, int], ...]] = None, +) -> Tuple[Tuple[int, int, int], ...]: + # merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...] + if not shape: + return tuple() + assert len(shape) == len( + strides + ) # state (0, 1, 2) -> (none, in-progress, done). wrt merging zero strided dimensions. + ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)] + state = ( + 1 + if mask and strides[0] == 0 and shape[0] != 1 and mask[0][1] - mask[0][0] == 1 + else 0 + ) + for i, (sh, st) in enumerate(zip(shape[1:], strides[1:]), start=1): + if sh == 1: + continue + if state == 1 or ret[-1][1] == sh * st: # mergeable + ret[-1] = ( + ret[-1][0] * sh, + st, + (sh if state == 1 else ret[-1][2] * sh) if st else 0, + ) + else: + ret.append((sh, st, sh if st else 0)) # begin new + # merging ends with either non-zero strided dim or zero strided dim with mask range > 1 + state = ( + 1 + if mask and st == 0 and mask[i][1] - mask[i][0] == 1 + else (2 if state != 0 else 0) + ) + return tuple(ret) - r_masks, r_shape, r_new_shape = reversed(view.mask), reversed(view.shape), reversed(new_shape) - curr_stride, off, offsets, old_dim, new_dim, mask = 1, 0, [], next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) - # off represents offset while combining masks of range one & zero stride - if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), tuple(), False # invalid mask - while len(new_mask) < len(new_shape): - (l, r), next_stride = (mask[0], mask[1]), new_dim * curr_stride +@functools.lru_cache(maxsize=None) +def _reshape_mask( + view: View, new_shape: Tuple[sint, ...] +) -> Tuple[Optional[Tuple[Tuple[sint, sint], ...]], Optional[Tuple[sint, ...]], bool]: + if view.mask is None: + return view.mask, tuple(), False + new_mask: List[Tuple[int, int]] = [] - if old_dim >= new_dim: # need to split mask. - offsets.append(off) + r_masks, r_shape, r_new_shape = ( + reversed(view.mask), + reversed(view.shape), + reversed(new_shape), + ) + curr_stride, off, offsets, old_dim, new_dim, mask = ( + 1, + 0, + [], + next(r_shape, 1), + next(r_new_shape, 1), + next(r_masks, (0, 1)), + ) + # off represents offset while combining masks of range one & zero stride + if mask[1] - mask[0] < 1: + return ((0, 0),) * len(new_shape), tuple(), False # invalid mask - if old_dim == next_stride: # simply copy the mask and get next batch for merging - new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1)) - curr_stride, off, old_dim, new_dim, mask = 1, 0, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) - if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), tuple(), False # invalid mask + while len(new_mask) < len(new_shape): + (l, r), next_stride = (mask[0], mask[1]), new_dim * curr_stride - else: # mask can only be splitted if reshape doesn't cut across the mask. - if ((l % (ns := next_stride) != 0 or r % ns != 0) and l // ns != (r - 1) // ns): return view.mask, tuple(), True - new_mask.append((l % ns // curr_stride, (r - 1) % ns // curr_stride + 1)) - curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension + if old_dim >= new_dim: # need to split mask. + offsets.append(off) - elif old_dim < new_dim * curr_stride: - next_mask = next(r_masks, (0, 1)) - # combine if the mask can unfold continuously - if (l != 0 or r != old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, tuple(), True - if next_mask != (0, 1) and mask != (0, 1) and (next_mask[1] - next_mask[0] == 1): off += next_mask[0] * old_dim - mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1) + if ( + old_dim == next_stride + ): # simply copy the mask and get next batch for merging + new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1)) + curr_stride, off, old_dim, new_dim, mask = ( + 1, + 0, + next(r_shape, 1), + next(r_new_shape, 1), + next(r_masks, (0, 1)), + ) + if mask[1] - mask[0] < 1: + return ((0, 0),) * len(new_shape), tuple(), False # invalid mask - for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1) - if mask != (0, 1): return ((0, 0),) * len(new_shape), tuple(), False + else: # mask can only be splitted if reshape doesn't cut across the mask. + if (l % (ns := next_stride) != 0 or r % ns != 0) and l // ns != ( + r - 1 + ) // ns: + return view.mask, tuple(), True + new_mask.append( + (l % ns // curr_stride, (r - 1) % ns // curr_stride + 1) + ) + curr_stride, new_dim = next_stride, next( + r_new_shape, 1 + ) # need to get mask for next dimension + + elif old_dim < new_dim * curr_stride: + next_mask = next(r_masks, (0, 1)) + # combine if the mask can unfold continuously + if (l != 0 or r != old_dim) and next_mask[1] - next_mask[0] != 1: + return view.mask, tuple(), True + if ( + next_mask != (0, 1) + and mask != (0, 1) + and (next_mask[1] - next_mask[0] == 1) + ): + off += next_mask[0] * old_dim + mask, old_dim = ( + next_mask[0] * old_dim + l, + (next_mask[1] - 1) * old_dim + r, + ), old_dim * next(r_shape, 1) + + for ( + mask + ) in ( + r_masks + ): # if the old shape has leading 1s, need to make sure their mask is (0,1) + if mask != (0, 1): + return ((0, 0),) * len(new_shape), tuple(), False + + return tuple(reversed(new_mask)), tuple(offsets), False - return tuple(reversed(new_mask)), tuple(offsets), False @dataclass(frozen=True) class View: - shape:Tuple[sint, ...] - strides:Tuple[sint, ...] - offset:sint - mask:Optional[Tuple[Tuple[sint, sint], ...]] - contiguous:bool + shape: Tuple[sint, ...] + strides: Tuple[sint, ...] + offset: sint + mask: Optional[Tuple[Tuple[sint, sint], ...]] + contiguous: bool - @staticmethod - @functools.lru_cache(maxsize=None) - def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None): - strides = filter_strides(shape, strides) if strides else strides_for_shape(shape) - contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape) - return View(shape, strides, offset, mask, contiguous) + @staticmethod + @functools.lru_cache(maxsize=None) + def create( + shape: Tuple[sint, ...], + strides: Optional[Tuple[sint, ...]] = None, + offset: sint = 0, + mask: Optional[Tuple[Tuple[sint, sint], ...]] = None, + ): + strides = ( + filter_strides(shape, strides) if strides else strides_for_shape(shape) + ) + contiguous = ( + offset == 0 and mask is None and strides == strides_for_shape(shape) + ) + return View(shape, strides, offset, mask, contiguous) - def vars(self) -> Set[Variable]: - flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple() - return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set()) + def vars(self) -> Set[Variable]: + flatten_mask = ( + tuple(x for m in self.mask for x in m) if self.mask is not None else tuple() + ) + return functools.reduce( + operator.or_, + [ + x.vars() + for x in self.shape + self.strides + (self.offset,) + flatten_mask + if isinstance(x, Node) + ], + set(), + ) - def unbind(self) -> View: - unbound_vars:Dict[VariableOrNum,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None} - new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape]) - new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides]) - new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars) - new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None - return View.create(new_shape, new_strides, new_offset, new_mask) + def unbind(self) -> View: + unbound_vars: Dict[VariableOrNum, Node] = { + v: v.unbind()[0] for v in self.vars() if v.val is not None + } + new_shape = tuple( + [ + s if isinstance(s, int) else s.substitute(unbound_vars) + for s in self.shape + ] + ) + new_strides = tuple( + [ + s if isinstance(s, int) else s.substitute(unbound_vars) + for s in self.strides + ] + ) + new_offset = ( + self.offset + if isinstance(self.offset, int) + else self.offset.substitute(unbound_vars) + ) + new_mask = ( + tuple( + ( + a if isinstance(a, int) else a.substitute(unbound_vars), + b if isinstance(b, int) else b.substitute(unbound_vars), + ) + for (a, b) in self.mask + ) + if self.mask is not None + else None + ) + return View.create(new_shape, new_strides, new_offset, new_mask) - # MovementOps live here now + # MovementOps live here now - def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View: - offset = sum([s * x[0] for s, x in zip(self.strides,arg)]) - if self.mask: - # move the old mask - nmask = tuple([(max(0, min(mx-ax,ay-ax)), max(0, min(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)]) - # merge the masks if we have two - mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask - shape = [y-x for x,y in arg] - return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask) + def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View: + offset = sum([s * x[0] for s, x in zip(self.strides, arg)]) + if self.mask: + # move the old mask + nmask = tuple( + [ + (max(0, min(mx - ax, ay - ax)), max(0, min(my - ax, ay - ax))) + for (mx, my), (ax, ay) in zip(self.mask, arg) + ] + ) + # merge the masks if we have two + mask = ( + tuple( + [ + (max(mx1, mx2), min(my1, my2)) + for (mx1, my1), (mx2, my2) in zip(nmask, mask) + ] + ) + if mask is not None + else nmask + ) + shape = [y - x for x, y in arg] + return View.create( + tuple(s.b if isinstance(s, NumNode) else s for s in shape), + self.strides, + self.offset + offset, + mask, + ) - @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none - def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View: - assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape) - if any(b or e for b, e in arg): - zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)]) - mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)]) - return self.__unsafe_resize(zvarg, mask=mask) - return self + @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none + def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View: + assert all((b >= 0 and e >= 0) for b, e in arg) and len(arg) == len(self.shape) + if any(b or e for b, e in arg): + zvarg = tuple([(-b, s + e) for s, (b, e) in zip(self.shape, arg)]) + mask = tuple([(b, s + b) for s, (b, _) in zip(self.shape, arg)]) + return self.__unsafe_resize(zvarg, mask=mask) + return self - @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none - def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View: - assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape) - return self.__unsafe_resize(arg) + @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none + def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View: + assert all((b >= 0 and e <= s) for s, (b, e) in zip(self.shape, arg)) and len( + arg + ) == len(self.shape) + return self.__unsafe_resize(arg) - @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none - def expand(self, new_shape: Tuple[sint, ...]) -> View: - if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}") - if 0 in self.shape: - assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}" - return View.create(new_shape) - assert all((s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}" - # NOTE: can the mask ever be (0,0)? - mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None - return View.create(new_shape, self.strides, self.offset, mask) + @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none + def expand(self, new_shape: Tuple[sint, ...]) -> View: + if len(new_shape) != len(self.shape): + raise ValueError( + f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}" + ) + if 0 in self.shape: + assert all( + (s == x == 0) or (s > 0 and (x % s) == 0) + for s, x in zip(self.shape, new_shape) + ), f"can't expand {self.shape} into {new_shape}" + return View.create(new_shape) + assert all( + (s == x or (s == 1 and st == 0)) + for s, x, st in zip(self.shape, new_shape, self.strides) + ), f"can't expand {self.shape} into {new_shape}" + # NOTE: can the mask ever be (0,0)? + mask = ( + tuple( + [ + (((0, 0) if m != (0, 1) else (0, ns)) if s != ns else m) + for m, s, ns in zip(self.mask, self.shape, new_shape) + ] + ) + if self.mask + else None + ) + return View.create(new_shape, self.strides, self.offset, mask) - @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none - def permute(self, axis: Tuple[int, ...]) -> View: - assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}" - assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}" - return View.create(tuple([self.shape[a] for a in axis]), tuple([self.strides[a] for a in axis]), self.offset, tuple([self.mask[a] for a in axis]) if self.mask is not None else None) + @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none + def permute(self, axis: Tuple[int, ...]) -> View: + assert all( + isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis + ), f"invalid permute {axis} for {self.shape}" + assert len(set(axis)) == len(axis) and len(axis) == len( + self.shape + ), f"can't permute {self.shape} with {axis}" + return View.create( + tuple([self.shape[a] for a in axis]), + tuple([self.strides[a] for a in axis]), + self.offset, + tuple([self.mask[a] for a in axis]) if self.mask is not None else None, + ) - @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none - def stride(self, mul: Tuple[int, ...]) -> View: - # except for the negative case, you can build this from the others. invertible in the negative case - assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}" - strides = tuple([z*m for z,m in zip(self.strides, mul)]) - new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)]) - offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0]) - mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None - return View.create(new_shape, strides, self.offset + offset, mask) + @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none + def stride(self, mul: Tuple[int, ...]) -> View: + # except for the negative case, you can build this from the others. invertible in the negative case + assert all( + isinstance(x, int) and x != 0 for x in mul + ), f"invalid stride {mul} for {self.shape}" + strides = tuple([z * m for z, m in zip(self.strides, mul)]) + new_shape = tuple( + [(s + (abs(m) - 1)) // abs(m) for s, m in zip(self.shape, mul)] + ) + offset = sum( + [(s - 1) * z for s, z, m in zip(self.shape, self.strides, mul) if m < 0] + ) + mask = ( + tuple( + [ + ( + ((mx if m > 0 else s - my) + (abs(m) - 1)) // abs(m), + ((my if m > 0 else s - mx) + (abs(m) - 1)) // abs(m), + ) + for (mx, my), s, m in zip(self.mask, self.shape, mul) + ] + ) + if self.mask is not None + else None + ) + return View.create(new_shape, strides, self.offset + offset, mask) - @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none - def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]: - if self.shape == new_shape: return self + @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none + def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]: + if self.shape == new_shape: + return self - assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}" - if 0 in self.shape: - assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}" - return View.create(new_shape) - # check for the same size - if all_int(self.shape): - assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim" - if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") + assert all( + x >= 0 for x in new_shape + ), f"shape can't contain negative numbers {new_shape}" + if 0 in self.shape: + assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}" + return View.create(new_shape) + # check for the same size + if all_int(self.shape): + assert all( + isinstance(s, (int, Variable)) for s in new_shape + ), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim" + if prod(self.shape) != prod( + [s if isinstance(s, int) else cast(Variable, s).val for s in new_shape] + ): + raise ValueError( + f"size mismatched, can't reshape {self.shape=} -> {new_shape=}" + ) - if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None + if new_shape == () and self.mask and any(mx == my for (mx, my) in self.mask): + return None - # after the asserts, it's okay to check contiguous - if self.contiguous: return View.create(new_shape) + # after the asserts, it's okay to check contiguous + if self.contiguous: + return View.create(new_shape) - strides, r_new_shape = [], reversed(new_shape) - for merged_dim, s, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)): - acc, new_stride = 1, s - while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)): - strides.append(new_stride if new_dim !=1 else 0) - if new_dim == 1: continue - new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0) - if acc != merged_dim: break - else: - strides += [0,] * (len(new_shape) - len(strides)) - mask, off_mask, extra = _reshape_mask(self, new_shape) - total_offset = sum([off * s for off, s in zip(off_mask, strides)]) if off_mask else 0 - if not extra: return View.create(new_shape, tuple(reversed(strides)), self.offset - total_offset, mask) + strides, r_new_shape = [], reversed(new_shape) + for merged_dim, s, real_dim in reversed( + _merge_dims(self.shape, self.strides, self.mask) + ): + acc, new_stride = 1, s + while ( + acc <= merged_dim + and acc != merged_dim + and (new_dim := next(r_new_shape, None)) + ): + strides.append(new_stride if new_dim != 1 else 0) + if new_dim == 1: + continue + new_stride *= new_dim if (acc := acc * new_dim) < real_dim else 0 + if acc != merged_dim: + break + else: + strides += [ + 0, + ] * (len(new_shape) - len(strides)) + mask, off_mask, extra = _reshape_mask(self, new_shape) + total_offset = ( + sum([off * s for off, s in zip(off_mask, strides)]) if off_mask else 0 + ) + if not extra: + return View.create( + new_shape, + tuple(reversed(strides)), + self.offset - total_offset, + mask, + ) - return None + return None diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 835fa1725..2955fa9c4 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1,839 +1,1864 @@ # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py from __future__ import annotations import time, math -from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set +from typing import ( + List, + Tuple, + Callable, + Optional, + ClassVar, + Type, + Union, + Sequence, + Any, + Iterable, + Set, +) from collections import defaultdict from functools import partialmethod, reduce from itertools import accumulate import numpy as np -from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int, round_up +from tinygrad.helpers import ( + ImageDType, + argfix, + make_pair, + getenv, + IMAGE, + DEBUG, + flatten, + DType, + dtypes, + prod, + all_int, + round_up, +) from tinygrad.lazy import LazyBuffer from tinygrad.ops import LoadOps from tinygrad.device import Device, Buffer from tinygrad.shape.symbolic import sint from tinygrad.realize import run_schedule + class Function: - def __init__(self, device:str, *tensors:Tensor): - self.device = device - self.needs_input_grad = [t.requires_grad for t in tensors] - self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False - if self.requires_grad: self.parents = tensors + def __init__(self, device: str, *tensors: Tensor): + self.device = device + self.needs_input_grad = [t.requires_grad for t in tensors] + self.requires_grad = ( + True + if any(self.needs_input_grad) + else None + if None in self.needs_input_grad + else False + ) + if self.requires_grad: + self.parents = tensors - def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}") - def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}") + def forward(self, *args, **kwargs): + raise NotImplementedError(f"forward not implemented for {type(self)}") + + def backward(self, *args, **kwargs): + raise RuntimeError(f"backward not implemented for {type(self)}") + + @classmethod + def apply(fxn: Type[Function], *x: Tensor, **kwargs) -> Tensor: + ctx = fxn(x[0].device, *x) + ret = Tensor( + ctx.forward(*[t.lazydata for t in x], **kwargs), + device=ctx.device, + requires_grad=ctx.requires_grad, + ) + if ctx.requires_grad and not Tensor.no_grad: + ret._ctx = ctx # used by autograd engine + return ret - @classmethod - def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor: - ctx = fxn(x[0].device, *x) - ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad) - if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx # used by autograd engine - return ret import tinygrad.mlops as mlops # **** start with two base classes, Tensor and Function **** + class Tensor: - __slots__ = "lazydata", "requires_grad", "grad", "_ctx" - __deletable__ = ('_ctx',) - training: ClassVar[bool] = False - class train: - def __init__(self, val=True): self.val = val - def __enter__(self): self.prev, Tensor.training = Tensor.training, self.val - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev - - no_grad: ClassVar[bool] = False - default_type: ClassVar[DType] = dtypes.float32 - def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): - assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}" - device = Device.canonicalize(device) - # tensors have gradients, buffers do not - self.grad: Optional[Tensor] = None - - # NOTE: this can be in three states. False and None: no gradient, True: gradient - # None (the default) will be updated to True if it's put in an optimizer - self.requires_grad: Optional[bool] = requires_grad - - # internal variables used for autograd graph construction - self._ctx: Optional[Function] = None - if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported" - elif isinstance(data, (int, float)): - data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data) - elif data is None or data.__class__ is list: - assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype" - data = LazyBuffer.fromCPU(np.array([] if data is None else data, dtype=(dtype or Tensor.default_type).np)) - elif isinstance(data, bytes): - data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8)) - elif isinstance(data, np.ndarray): - assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype" - if data.shape == (): - data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item()) - else: - data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data) - - # data is a LazyBuffer, but it might be on the wrong device - if not isinstance(data, LazyBuffer): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") - self.lazydata = data if data.device == device else data.copy_to_device(device) - - def __repr__(self): - return f"" - - # Python has a non moving GC, so this should be okay - def __hash__(self): return id(self) - - @property - def device(self) -> str: return self.lazydata.device - - @property - def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape - - @property - def dtype(self) -> DType: return self.lazydata.dtype - - # ***** data handlers **** - - @staticmethod - def corealize(lst:Iterable[Tensor]): - seen:Set[LazyBuffer] = set() - sched = [] - for t in lst: sched += t.lazydata.schedule(seen) - run_schedule(sched) - - def realize(self) -> Tensor: - run_schedule(self.lazydata.schedule()) - return self - - def assign(self, x) -> Tensor: - # TODO: this is a hack for writing to DISK. remove with working assign - if self.device.startswith("DISK"): - if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype) - self.contiguous().realize().lazydata.realized.copyin(x.numpy().data) - return self - if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype) - assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}" - assert not x.requires_grad # self requires_grad is okay? - if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}") - if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized - self.lazydata = x.lazydata - return self - - def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False) - def numpy(self) -> np.ndarray: - assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}" - assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}" - if 0 in self.shape: return np.zeros(self.shape, dtype=self.dtype.np) - return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape) - def item(self) -> Union[float, int]: - assert self.numel() == 1, "must have one element for item" - return self.realize().lazydata.realized.toCPU().item() - - def to(self, device:Optional[str]) -> Tensor: - if device is None or device == self.device: return self - ret = Tensor(self.lazydata, device) - if self.grad: ret.grad = self.grad.to(device) - return ret - - def to_(self, device:Optional[str]): - if device is None or device == self.device: return - if self.grad: self.grad = self.grad.to_(device) - _ret = Tensor(self.lazydata, device) - self.lazydata = _ret.lazydata - - # ***** creation llop entrypoint ***** - - @staticmethod - def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs): - assert isinstance(sz, int), f"cannot create with symbolic size {sz}" - return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs) - - @staticmethod - def empty(*shape, **kwargs): - return Tensor._loadop(LoadOps.EMPTY, prod((shape:=argfix(*shape))), **kwargs).reshape(shape) - - _seed: int = int(time.time()) - @staticmethod - def manual_seed(seed=0): Tensor._seed = seed - - @staticmethod - def rand(*shape, **kwargs): - return Tensor._loadop(LoadOps.CUSTOM, prod((shape:=argfix(*shape))), arg=custom_random, **kwargs).reshape(shape) - - # ***** creation helper functions ***** - - @staticmethod - def full(shape:Tuple[sint, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape) - - @staticmethod - def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs) - - @staticmethod - def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs) - - @staticmethod - def arange(start, stop=None, step=1, **kwargs): - if stop is None: stop, start = start, 0 - return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step) - - @staticmethod - def eye(dim:int, **kwargs): return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim) - - def full_like(self, fill_value, **kwargs): return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs) - def zeros_like(self, **kwargs): return self.full_like(0, **kwargs) - def ones_like(self, **kwargs): return self.full_like(1, **kwargs) - - # ***** rng hlops ***** - - @staticmethod - def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor: - # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform - src = Tensor.rand(2, *shape, **kwargs) - return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype) - - @staticmethod - def randint(*shape, low=0, high=10, **kwargs) -> Tensor: - return (Tensor.rand(*shape, **kwargs)*(high-low)+low).cast(dtypes.int32) - - @staticmethod - def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean - - @staticmethod - def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor: - dtype = kwargs.pop("dtype", Tensor.default_type) - return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low - - @staticmethod - def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(shape)**-0.5) - - # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform - @staticmethod - def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5) - - # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_ - @staticmethod - def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor: - bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:])) - return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs) - - # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_ - @staticmethod - def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor: - std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:])) - return Tensor.normal(*shape, mean=0.0, std=std, **kwargs) - - def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor: - assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive" - assert replacement or num_samples == 1, "no replacement only supports num_samples = 1" - weight = self.unsqueeze(0) if self.ndim == 1 else self - cdf = (cw := weight.cumsum(1)) / cw[:, -1].unsqueeze(1) - unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1) - indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0)) - return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32) - - # ***** toposort and backward pass ***** - - def deepwalk(self): - def _deepwalk(node, visited, nodes): - visited.add(node) - if getattr(node, "_ctx", None): - for i in node._ctx.parents: - if i not in visited: _deepwalk(i, visited, nodes) - nodes.append(node) - return nodes - return _deepwalk(self, set(), []) - - def backward(self) -> Tensor: - assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})" - - # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous - # this is "implicit gradient creation" - self.grad = Tensor(1, device=self.device, requires_grad=False) - - for t0 in reversed(self.deepwalk()): - assert (t0.grad is not None) - grads = t0._ctx.backward(t0.grad.lazydata) - grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None - for g in ([grads] if len(t0._ctx.parents) == 1 else grads)] - for t, g in zip(t0._ctx.parents, grads): - if g is not None and t.requires_grad: - assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" - t.grad = g if t.grad is None else (t.grad + g) - del t0._ctx - return self - - # ***** movement mlops ***** - - def reshape(self, shape, *args) -> Tensor: - new_shape = argfix(shape, *args) - return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])) - def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))])) - def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args)) - def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)]) - def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape))) if any(x is not None and x != (0,s) for x,s in zip(arg, self.shape)) else self - def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor: - if all(x is None or x == (0,0) for x in arg): return self - ret = mlops.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg))) - return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value) - - # ***** movement hlops ***** - - # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element - # - A slice i:j returns the elements with indices in [i, j) - # - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence - # - Negative values for i and j are taken relative to the end of the sequence - # - Both i and j will be clamped to the range (-N, N], where N in the length of the sequence - # - Indexing with None on a given axis will add a new dimension of size one before that axis - # - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends). - # - For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len). - # - Strides > 1 and < 0 are now allowed!: - # - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional) - # - Idea of stride < 0 support: - # - Do the slice first, flip the axes were slice.step is negative, do slice.step -> -slice.step. Go to steps below. - # - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink): - # - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s]. - # - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s] - # is possible. - # - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s]. - # - Fancy indexing and combined indexing is supported - # - Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t to Tensors passed in -> fancy indexing - # - Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively - # - The first iteration will expand the dim of self while consecutive iterations will reduce the dim - # - There's a special case where a permute is needed at the end: - # - if first Tensor passed in (expand dims) is not at dim 0 - # - and following Tensors does not follow consecutively to the end of fancy indexing's dims - def __getitem__(self, indices) -> Tensor: # indices: Union[int, slice, Tensor, None, Ellipsis, List, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]] - def normalize_int(e, i, dim_sz): - if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1 - raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}") - - # TODO: if indices is a tuple of any sequence, or if indices is a list, it's for advanced indexing - orig_slices = list(indices) if isinstance(indices, tuple) else [indices] - count = defaultdict(list) - for i,v in enumerate(orig_slices): count[type(v)].append(i) - - # TODO: boolean indices - if (num_slices := len(count[int]) + len(count[slice]) + len(count[Tensor]) + len(count[list])) > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}") - if len(ellipsis_found := count[type(Ellipsis)]) > 1: raise IndexError("an index can only have a single ellipsis ('...')") - - # replace ellipsis with equivalent number of slice(None) - # TODO: move all slice(None) to the end and transpose non-None to the front - ellipsis_idx = ellipsis_found[0] if ellipsis_found else len(orig_slices) - orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices) - - valid_slices = [v for v in orig_slices if v is not None] - valid_slices = [v if isinstance(v, slice) else slice(y_ := normalize_int(v, i, dim_sz), y_+1) if isinstance(v, int) else slice(None) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))] - - start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ()) - new_slice = tuple(((0, 0) if e < s else (s, e)) if st > 0 else ((0, 0) if e > s else (e+1, s+1)) for s, e, st in zip(start, stop, strides)) - sliced_tensor = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0]) - new_shape = sliced_tensor.shape - if any(abs(s) != 1 for s in strides): - strides = tuple(abs(s) for s in strides) - # Pad: add pad at the end: [dim_sz] -> [dim_sz_padded] - padded_tensor = sliced_tensor.pad(tuple((0, s-(dim_sz % s) if dim_sz % s != 0 else 0) for s, dim_sz in zip(strides, sliced_tensor.shape))) - # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s] - reshaped_tensor = padded_tensor.reshape(flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides))) - new_shape = reshaped_tensor.shape[::2] - # Shrink: do [:, 0] - sliced_tensor = reshaped_tensor.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in new_shape))) - - final_shape, it_shape, dim, tensors, dim_collapsed = [], iter(new_shape), [], [], 0 - for i,s in enumerate(orig_slices): - if s is None: final_shape.append(1) - else: # s is int or slice or Tensor - dim_shape = next(it_shape) - if isinstance(s, list): s = Tensor(s) - if isinstance(s, int): dim_collapsed += 1 - else: - assert isinstance(dim_shape, int), f"does not support symbolic shape {dim_shape}" - final_shape.append(dim_shape) - if isinstance(s, Tensor): - tensors.append(s) - dim.append(i-dim_collapsed) - ret = sliced_tensor.reshape(tuple(final_shape)) - - if tensors: # Fancy/tensor indexing - # normalize idx - # TODO: first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm - idx = [t.sign().contiguous().__neg__().contiguous().relu() * ret.shape[d] + t for d,t in zip(dim, tensors)] - max_dim = max(i.ndim for i in idx) - # compute sum_dim, arange, and idx - sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(dim)] - arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, dim))] - first_idx = [idx[0].reshape(*[1]*dim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - dim[0] - 1))] - rest_idx = [i.reshape(*[1]*dim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - dim[0] - n)) for n,i in enumerate(idx[1:], 1)] - idx = first_idx + rest_idx - ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:]) - # iteratively fancy index - for a,i,sd in zip(arange, idx, sum_dim): ret = (a==i).mul(ret).sum(sd) - # special permute case - if dim[0] != 0 and len(dim) != 1 and dim != list(range(dim[0], dim[-1]+1)): - ret_dims = list(range(ret.ndim)) - ret = ret.permute(ret_dims[dim[0]:dim[0]+max_dim] + ret_dims[:dim[0]] + ret_dims[dim[0]+max_dim:]) - return ret - - def __setitem__(self,indices,v): return self.__getitem__(indices).assign(v) - - # NOTE: using slice is discouraged and things should migrate to pad and shrink - def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor: - arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)]) - padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)]) - return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)])) - - def gather(self:Tensor, idx:Tensor, dim:int) -> Tensor: - assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim" - assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape" - if dim < 0: dim += self.ndim - idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1) - permarg = list(range(self.ndim)) - permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]] - return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim) - - def cat(self, *args:Tensor, dim:int=0) -> Tensor: - dim = (dim + len(self.shape)) if dim < 0 else dim - assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args) - catargs = [self, *args] - assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated" - shapes = [s.shape[dim] for s in catargs] - shape_cumsum = [0, *accumulate(shapes)] - slc:List[List[Tuple[sint, sint]]] = [[(0, 0) for _ in self.shape] for _ in catargs] - for shp,k,s in zip(shapes, shape_cumsum[:-1], slc): s[dim] = (k, shape_cumsum[-1] - k - shp) - return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)]) - - @staticmethod - def stack(tensors:Sequence[Tensor], dim:int=0) -> Tensor: - first = tensors[0].unsqueeze(dim) - unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors[1:]] - # checks for shapes and number of dimensions delegated to cat - return first.cat(*unsqueezed_tensors, dim=dim) - - def repeat(self, repeats:Sequence[int]) -> Tensor: - base_shape = (1,) * (len(repeats) - self.ndim) + self.shape - new_shape = [x for b in base_shape for x in [1, b]] - expand_shape = [x for rs in zip(repeats, base_shape) for x in rs] - final_shape = [r*s for r,s in zip(repeats, base_shape)] - return self.reshape(new_shape).expand(expand_shape).reshape(final_shape) - - def chunk(self, num:int, dim:int=0) -> List[Tensor]: - assert all_int(self.shape), f"does not support symbolic shape {self.shape}" - dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num) - slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)] - return [self[tuple(sl)] for sl in slice_params] - - def squeeze(self, dim:Optional[int]=None) -> Tensor: - if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1]) - if dim <= 0 and self.ndim == 0: return self # This is to match PyTorch behavior - if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})") - if dim < 0: dim += self.ndim - return self if self.shape[dim] != 1 else self.reshape(*[size for idx, size in enumerate(self.shape) if idx != dim]) - - def unsqueeze(self, dim:int) -> Tensor: - if dim < 0: dim = len(self.shape) + dim + 1 - return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:]) - - # (padding_left, padding_right, padding_top, padding_bottom) - def pad2d(self, padding:Union[List[int], Tuple[int, ...]], value:float=0) -> Tensor: - slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1] - return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value) - - @property - def T(self) -> Tensor: return self.transpose() - def transpose(self, ax1=1, ax2=0) -> Tensor: - order = list(range(len(self.shape))) - order[ax1], order[ax2] = order[ax2], order[ax1] - return self.permute(order) - def flatten(self, start_dim=0): return self.reshape(shape=self.shape[:start_dim] + (-1,)) - - # ***** reduce ops ***** - - def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor: - axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis)) - axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_] - shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_) - if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0, mlops.Max: -float("inf")}[fxn]) - ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)])) - return ret if keepdim else ret.reshape(shape=shape) - - def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim) - def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim) - def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim)) - - def mean(self, axis=None, keepdim=False): - assert all_int(self.shape), "does not support symbolic shape" - out = self.sum(axis=axis, keepdim=keepdim) - return out.mul(prod(out.shape)/prod(self.shape)) if 0 not in self.shape else out - def std(self, axis=None, keepdim=False, correction=1): - assert all_int(self.shape), "does not support symbolic shape" - square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim) - return square_sum.div(prod(self.shape)/prod(square_sum.shape)-correction).sqrt() - def _softmax(self, axis): - m = self - self.max(axis=axis, keepdim=True) - e = m.exp() - return m, e, e.sum(axis=axis, keepdim=True) - - def softmax(self, axis=-1): - _, e, ss = self._softmax(axis) - return e.div(ss) - - def log_softmax(self, axis=-1): - m, _, ss = self._softmax(axis) - return m - ss.log() - - def argmax(self, axis=None, keepdim=False): - if axis is None: - idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape) - return prod(self.shape) - idx.max() - 1 - axis = axis + len(self.shape) if axis < 0 else axis - m = self == self.max(axis=axis, keepdim=True) - idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1)) - return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1 - def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim) - - # ***** processing ops ***** - - def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor: - assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}" - assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}" - s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_)) - assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}" - slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):] - if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_): - o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)] - e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding - xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)]) - # slide by dilation - xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)]) - xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_))) - xup = xup.slice(slc_prefix + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))) - # handle stride, and permute to move reduce to the end - xup = xup.reshape(*prefix, *flatten((k,o,s) for k,o,s in zip(k_, o_, s_))) - xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))) - xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_))) - return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))]) - # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker - o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)] - xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)]) - xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_)))) - xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_))) - return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))]) - - # NOTE: these work for more than 2D - def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) - def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) - - def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor: - HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1)) - x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing) - stride = make_pair(stride, len(HW)) - if any(s>1 for s in stride): - x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:])) - x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride))) - x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)]) - x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)])) - padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW))))))) - return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding) - - wino = int(getenv("WINO", "0")) - def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor: - (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:] - assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" - if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" - padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1]) - - # conv2d is a pooling op (with padding) - x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W) - rcout, oyx = cout//groups, x.shape[2:-len(HW)] - if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not Tensor.wino: - # normal conv - x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) - - # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW) - ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx) - return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW))) - - # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308 - def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat]) - HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles - winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]] - winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]] - winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order almost doubles compilation time - - # todo: stride == dilation - # use padding to round up to 4x4 output tiles - d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # (bs, cin_, tyx, HWI) - d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))).contiguous_backward() # move HW to the front: # (HWI, bs, cin_, tyx) - tyx = d.shape[-len(HWI):] # dim of tiling - - g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front - - # compute 6x6 winograd tiles: GgGt, BtdB - gfactors = apply_matrix(winograd_G, g).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1)) - dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx) # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx) - - ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW))) # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx) - - ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]]) # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO) - ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx])) # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final - - return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward() - - def dot(self, w:Tensor) -> Tensor: - n1, n2 = len(self.shape), len(w.shape) - assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" - assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" - x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1]) - w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2)) - return (x*w).sum(-1) - - def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1) - def cumsum(self, axis:int=0) -> Tensor: - # TODO: someday the optimizer will find this on it's own - # for now this is a two stage cumsum - SPLIT = 256 - if self.shape[axis] <= SPLIT*2: return self._cumsum(axis) - ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0)) - ret = ret.reshape(*ret.shape[0:-1], ret.shape[-1]//SPLIT, SPLIT)._cumsum(-1) - base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1] - base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1]) - def fix(x:Tensor): return x.reshape(*ret.shape[0:-2], ret.shape[-2] * ret.shape[-1])[..., -self.shape[axis]:].transpose(axis,-1) - return fix(ret) + fix(base_add) - - @staticmethod - def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c) - def triu(self, k:int=0) -> Tensor: - assert all_int(self.shape), f"does not support symbolic shape {self.shape}" - return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self)) - def tril(self, k:int=0) -> Tensor: - assert all_int(self.shape), f"does not support symbolic shape {self.shape}" - return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self) - - # ***** mlops (unary) ***** - - def neg(self): return mlops.Neg.apply(self) - def contiguous(self): return mlops.Contiguous.apply(self) - def contiguous_backward(self): return mlops.ContiguousBackward.apply(self) - def log(self): return mlops.Log.apply(self) - def log2(self): return mlops.Log.apply(self)/math.log(2) - def exp(self): return mlops.Exp.apply(self) - def exp2(self): return mlops.Exp.apply(self*math.log(2)) - def relu(self): return mlops.Relu.apply(self) - def sigmoid(self): return mlops.Sigmoid.apply(self) - def sin(self): return mlops.Sin.apply(self) - def sqrt(self): return mlops.Sqrt.apply(self) - def rsqrt(self): return (1/self).sqrt() - def cos(self): return ((math.pi/2)-self).sin() - def tan(self): return self.sin() / self.cos() - - # ***** math functions (unary) ***** - - def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype) - def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b) - def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b) - - def square(self): return self*self - def clip(self, min_, max_): return self.maximum(min_).minimum(max_) - def abs(self): return self.relu() + (-self).relu() - def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype) - def reciprocal(self): return 1.0/self - - # ***** activation functions (unary) ***** - - def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu() - def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0) - def swish(self): return self * self.sigmoid() - def silu(self): return self.swish() # The SiLU function is also known as the swish function. - def relu6(self): return self.relu() - (self-6).relu() - def hardswish(self): return self * (self+3).relu6() * (1/6) - def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0 - def sinh(self): return (self.exp() - self.neg().exp()) / 2 - def cosh(self): return (self.exp() + self.neg().exp()) / 2 - def atanh(self): return ((1 + self)/(1 - self)).log() / 2 - def asinh(self): return (self + (self.square() + 1).sqrt()).log() - def acosh(self): return (self + (self.square() - 1).sqrt()).log() - def hardtanh(self, min_val=-1, max_val=1): return self.clip(min_val, max_val) - def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh()) - def quick_gelu(self): return self * (self * 1.702).sigmoid() - def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu() - def mish(self): return self * self.softplus().tanh() - def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log() - def softsign(self): return self / (1 + self.abs()) - - # ***** broadcasted binary mlops ***** - - def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tensor, Tensor]: - x: Tensor = self - if not isinstance(y, Tensor): - if 0 in x.shape: return x, x.full_like(y) - y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32) - if reverse: x, y = y, x - if (xshape:=x.shape) == (yshape:=y.shape): return (x, y) - - shape_delta = len(xshape) - len(yshape) - if shape_delta > 0: y = y.reshape((1,) * shape_delta + yshape) - elif shape_delta < 0: x = x.reshape((1,) * -shape_delta + xshape) - if (xshape:=x.shape) == (yshape:=y.shape): return (x, y) - - shape_ret = tuple([max(x, y) for x, y in zip(xshape, yshape)]) - if xshape != shape_ret: x = x.expand(shape_ret) - if yshape != shape_ret: y = y.expand(shape_ret) - return (x, y) - - def _to_float(self, x:Union[Tensor, float]): - return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_contiguous_const() \ - and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x - - def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: - x = self._to_float(x) - return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self - def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: - x = self._to_float(x) - return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else (-self if reverse else self) - def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: - x = self._to_float(x) - if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self) - if x.__class__ is not Tensor and x == -1.0: return -self - return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self - def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: - x = self._to_float(x) - return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x) - def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: - x = self._to_float(x) - if x.__class__ is not Tensor and not reverse: - # simple pow identities - if x < 0: return self.reciprocal().pow(-x) - if x == 3.0: return self*self*self - if x == 2.0: return self*self - if x == 1.0: return self - if x == 0.5: return self.sqrt() - if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp() - ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp() - # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power) - sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos() - # we only need to correct the sign if the base is negative - base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2 - # we need 0 to be positive so we need to correct base_sign when the base is 0 - base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x)))))) - # inject nan if the base is negative and the power is not an integer - to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign - inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan") - return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan) - def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x) - - def maximum(self, x:Union[Tensor, float]) -> Tensor: return (selfx).detach().where(self, (self+x)/2)) - def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x)) - - def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]): - x_,y = self._broadcasted(input_) - x,z = x_._broadcasted(other) - return mlops.Where.apply(x, *y._broadcasted(z)) - - # ***** op wrappers (wasted lines to make the typechecker happy) ***** - - def __neg__(self) -> Tensor: return self.neg() - - def __add__(self, x) -> Tensor: return self.add(x) - def __sub__(self, x) -> Tensor: return self.sub(x) - def __mul__(self, x) -> Tensor: return self.mul(x) - def __pow__(self, x) -> Tensor: return self.pow(x) - def __truediv__(self, x) -> Tensor: return self.div(x) - def __matmul__(self, x) -> Tensor: return self.matmul(x) - - def __radd__(self, x) -> Tensor: return self.add(x, True) - def __rsub__(self, x) -> Tensor: return self.sub(x, True) - def __rmul__(self, x) -> Tensor: return self.mul(x, True) - def __rpow__(self, x) -> Tensor: return self.pow(x, True) - def __rtruediv__(self, x) -> Tensor: return self.div(x, True) - def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True) - - def __iadd__(self, x) -> Tensor: return self.assign(self.add(x)) - def __isub__(self, x) -> Tensor: return self.assign(self.sub(x)) - def __imul__(self, x) -> Tensor: return self.assign(self.mul(x)) - def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x)) - def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x)) - def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x)) - - def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False)) - def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True)) - def __ge__(self, x) -> Tensor: return 1.0-(self Tensor: return 1.0-(self>x) - def __ne__(self, x) -> Tensor: return (selfx) # type: ignore[override] - def __eq__(self, x) -> Tensor: return 1.0-(self != x) # type: ignore[override] - - # ***** functional nn ops ***** - - def linear(self, weight:Tensor, bias:Optional[Tensor]=None): - x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight) - return x.add(bias) if bias is not None else x - - def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return reduce(lambda x,f: f(x), ll, self) - - def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor: - y = (self - self.mean(axis, keepdim=True)) - return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt()) - - def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor: - x = (self - mean.reshape(shape=[1, -1, 1, 1])) - if weight: x = x * weight.reshape(shape=[1, -1, 1, 1]) - ret = x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd) - return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret - - def dropout(self, p=0.5) -> Tensor: - if not Tensor.training or p == 0: return self - mask = (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p).cast(dtypes.bool) - return self * mask * (1/(1.0 - p)) - - def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: - # NOTE: it works if key, value have symbolic shape - assert all_int(self.shape), f"does not support symbolic shape {self.shape}" - if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool) - if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0) - return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value - - def binary_crossentropy(self, y:Tensor) -> Tensor: - return (-y*self.log() - (1-y)*(1-self).log()).mean() - - def binary_crossentropy_logits(self, y:Tensor) -> Tensor: - return (self.maximum(0) - y * self + (1 + self.abs().__neg__().exp()).log()).mean() - - def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor: - # NOTE: self is a logits input - loss_mask = Y != ignore_index - y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) - y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1]) - return self.log_softmax().mul(y).sum() / loss_mask.sum() - - # ***** cast ops ***** - - def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self - def bitcast(self, dtype:DType) -> Tensor: - assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes" - return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self - def float(self) -> Tensor: return self.cast(dtypes.float32) - def half(self) -> Tensor: return self.cast(dtypes.float16) - - # ***** convenience stuff ***** - - @property - def ndim(self) -> int: return len(self.shape) - def numel(self) -> sint: return prod(self.shape) - def element_size(self) -> int: return self.dtype.itemsize - def nbytes(self) -> int: return self.numel() * self.element_size() - def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype) + __slots__ = "lazydata", "requires_grad", "grad", "_ctx" + __deletable__ = ("_ctx",) + training: ClassVar[bool] = False + + class train: + def __init__(self, val=True): + self.val = val + + def __enter__(self): + self.prev, Tensor.training = Tensor.training, self.val + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): + Tensor.training = self.prev + + no_grad: ClassVar[bool] = False + default_type: ClassVar[DType] = dtypes.float32 + + def __init__( + self, + data: Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], + device: Optional[str] = None, + dtype: Optional[DType] = None, + requires_grad: Optional[bool] = None, + ): + assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}" + device = Device.canonicalize(device) + # tensors have gradients, buffers do not + self.grad: Optional[Tensor] = None + + # NOTE: this can be in three states. False and None: no gradient, True: gradient + # None (the default) will be updated to True if it's put in an optimizer + self.requires_grad: Optional[bool] = requires_grad + + # internal variables used for autograd graph construction + self._ctx: Optional[Function] = None + if isinstance(data, LazyBuffer): + assert ( + dtype is None or dtype == data.dtype + ), "dtype doesn't match, and casting isn't supported" + elif isinstance(data, (int, float)): + data = LazyBuffer.loadop( + LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data + ) + elif data is None or data.__class__ is list: + assert ( + dtype is None or dtype.np is not None + ), f"{dtype} doesn't have a numpy dtype" + data = LazyBuffer.fromCPU( + np.array( + [] if data is None else data, + dtype=(dtype or Tensor.default_type).np, + ) + ) + elif isinstance(data, bytes): + data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8)) + elif isinstance(data, np.ndarray): + assert ( + dtype is None or dtype.np is not None + ), f"{dtype} doesn't have a numpy dtype" + if data.shape == (): + data = LazyBuffer.loadop( + LoadOps.CONST, + tuple(), + dtype or dtypes.from_np(data.dtype), + device, + data.item(), + ) + else: + data = LazyBuffer.fromCPU( + data.astype(dtype.np) + if dtype is not None and dtype.np is not None + else data + ) + + # data is a LazyBuffer, but it might be on the wrong device + if not isinstance(data, LazyBuffer): + raise RuntimeError( + f"can't create Tensor from {data!r} with type {type(data)}" + ) + self.lazydata = data if data.device == device else data.copy_to_device(device) + + def __repr__(self): + return f"" + + # Python has a non moving GC, so this should be okay + def __hash__(self): + return id(self) + + @property + def device(self) -> str: + return self.lazydata.device + + @property + def shape(self) -> Tuple[sint, ...]: + return self.lazydata.shape + + @property + def dtype(self) -> DType: + return self.lazydata.dtype + + # ***** data handlers **** + + @staticmethod + def corealize(lst: Iterable[Tensor]): + seen: Set[LazyBuffer] = set() + sched = [] + for t in lst: + sched += t.lazydata.schedule(seen) + run_schedule(sched) + + def realize(self) -> Tensor: + run_schedule(self.lazydata.schedule()) + return self + + def assign(self, x) -> Tensor: + # TODO: this is a hack for writing to DISK. remove with working assign + if self.device.startswith("DISK"): + if x.__class__ is not Tensor: + x = Tensor(x, device="CPU", dtype=self.dtype) + self.contiguous().realize().lazydata.realized.copyin(x.numpy().data) + return self + if x.__class__ is not Tensor: + x = Tensor(x, device=self.device, dtype=self.dtype) + assert ( + self.shape == x.shape and self.device == x.device + ), f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}" + assert not x.requires_grad # self requires_grad is okay? + if DEBUG >= 4: + print(f"assign {self.lazydata} <- {x.lazydata}") + if ( + self.dtype == x.dtype + and self.lazydata.realized is not None + and not getenv("DISALLOW_ASSIGN") + ): + x.lazydata.output_buffer = self.lazydata.realized + self.lazydata = x.lazydata + return self + + def detach(self) -> Tensor: + return Tensor(self.lazydata, device=self.device, requires_grad=False) + + def numpy(self) -> np.ndarray: + assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}" + assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}" + if 0 in self.shape: + return np.zeros(self.shape, dtype=self.dtype.np) + return ( + self.detach() + .cast(dtypes.from_np(self.dtype.np)) + .contiguous() + .to("CPU") + .realize() + .lazydata.realized.toCPU() + .astype(self.dtype.np, copy=True) + .reshape(self.shape) + ) + + def item(self) -> Union[float, int]: + assert self.numel() == 1, "must have one element for item" + return self.realize().lazydata.realized.toCPU().item() + + def to(self, device: Optional[str]) -> Tensor: + if device is None or device == self.device: + return self + ret = Tensor(self.lazydata, device) + if self.grad: + ret.grad = self.grad.to(device) + return ret + + def to_(self, device: Optional[str]): + if device is None or device == self.device: + return + if self.grad: + self.grad = self.grad.to_(device) + _ret = Tensor(self.lazydata, device) + self.lazydata = _ret.lazydata + + # ***** creation llop entrypoint ***** + + @staticmethod + def _loadop( + op, + sz, + device: Optional[str] = None, + dtype: Optional[DType] = None, + arg=None, + **kwargs, + ): + assert isinstance(sz, int), f"cannot create with symbolic size {sz}" + return Tensor( + LazyBuffer.loadop( + op, + (sz,), + Tensor.default_type if dtype is None else dtype, + Device.canonicalize(device), + arg, + ), + dtype=dtype, + device=device, + **kwargs, + ) + + @staticmethod + def empty(*shape, **kwargs): + return Tensor._loadop( + LoadOps.EMPTY, prod((shape := argfix(*shape))), **kwargs + ).reshape(shape) + + _seed: int = int(time.time()) + + @staticmethod + def manual_seed(seed=0): + Tensor._seed = seed + + @staticmethod + def rand(*shape, **kwargs): + return Tensor._loadop( + LoadOps.CUSTOM, prod((shape := argfix(*shape))), arg=custom_random, **kwargs + ).reshape(shape) + + # ***** creation helper functions ***** + + @staticmethod + def full(shape: Tuple[sint, ...], fill_value, **kwargs): + return ( + Tensor(fill_value, **kwargs) + .reshape([1] * len(new_shape := argfix(shape))) + .expand(new_shape) + ) + + @staticmethod + def zeros(*shape, **kwargs): + return Tensor.full(argfix(*shape), 0, **kwargs) + + @staticmethod + def ones(*shape, **kwargs): + return Tensor.full(argfix(*shape), 1, **kwargs) + + @staticmethod + def arange(start, stop=None, step=1, **kwargs): + if stop is None: + stop, start = start, 0 + return Tensor.full( + (math.ceil((stop - start) / step),), step, **kwargs + ).cumsum() + (start - step) + + @staticmethod + def eye(dim: int, **kwargs): + return ( + Tensor.full((dim, 1), 1, **kwargs) + .pad(((0, 0), (0, dim))) + .reshape(dim * (dim + 1)) + .shrink(((0, dim * dim),)) + .reshape(dim, dim) + ) + + def full_like(self, fill_value, **kwargs): + return Tensor.full( + self.shape, + fill_value=fill_value, + dtype=kwargs.pop("dtype", self.dtype), + device=kwargs.pop("device", self.device), + **kwargs, + ) + + def zeros_like(self, **kwargs): + return self.full_like(0, **kwargs) + + def ones_like(self, **kwargs): + return self.full_like(1, **kwargs) + + # ***** rng hlops ***** + + @staticmethod + def randn(*shape, dtype: Optional[DType] = None, **kwargs) -> Tensor: + # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform + src = Tensor.rand(2, *shape, **kwargs) + return ( + src[0] + .mul(2 * math.pi) + .cos() + .mul((1 - src[1]).log().mul(-2).sqrt()) + .cast(Tensor.default_type if dtype is None else dtype) + ) + + @staticmethod + def randint(*shape, low=0, high=10, **kwargs) -> Tensor: + return (Tensor.rand(*shape, **kwargs) * (high - low) + low).cast(dtypes.int32) + + @staticmethod + def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: + return (std * Tensor.randn(*shape, **kwargs)) + mean + + @staticmethod + def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor: + dtype = kwargs.pop("dtype", Tensor.default_type) + return ((high - low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low + + @staticmethod + def scaled_uniform(*shape, **kwargs) -> Tensor: + return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul( + prod(shape) ** -0.5 + ) + + # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform + @staticmethod + def glorot_uniform(*shape, **kwargs) -> Tensor: + return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul( + (6 / (shape[0] + prod(shape[1:]))) ** 0.5 + ) + + # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_ + @staticmethod + def kaiming_uniform(*shape, a: float = 0.01, **kwargs) -> Tensor: + bound = ( + math.sqrt(3.0) * math.sqrt(2.0 / (1 + a**2)) / math.sqrt(prod(shape[1:])) + ) + return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs) + + # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_ + @staticmethod + def kaiming_normal(*shape, a: float = 0.01, **kwargs) -> Tensor: + std = math.sqrt(2.0 / (1 + a**2)) / math.sqrt(prod(shape[1:])) + return Tensor.normal(*shape, mean=0.0, std=std, **kwargs) + + def multinomial( + self: Tensor, num_samples: int = 1, replacement: bool = False + ) -> Tensor: + assert ( + 1 <= self.ndim <= 2 and num_samples > 0 + ), f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive" + assert ( + replacement or num_samples == 1 + ), "no replacement only supports num_samples = 1" + weight = self.unsqueeze(0) if self.ndim == 1 else self + cdf = (cw := weight.cumsum(1)) / cw[:, -1].unsqueeze(1) + unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1) + indices = ( + (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0)) + ) + return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32) + + # ***** toposort and backward pass ***** + + def deepwalk(self): + def _deepwalk(node, visited, nodes): + visited.add(node) + if getattr(node, "_ctx", None): + for i in node._ctx.parents: + if i not in visited: + _deepwalk(i, visited, nodes) + nodes.append(node) + return nodes + + return _deepwalk(self, set(), []) + + def backward(self) -> Tensor: + assert ( + self.shape == tuple() + ), f"backward can only be called for scalar tensors, but it has shape {self.shape})" + + # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous + # this is "implicit gradient creation" + self.grad = Tensor(1, device=self.device, requires_grad=False) + + for t0 in reversed(self.deepwalk()): + assert t0.grad is not None + grads = t0._ctx.backward(t0.grad.lazydata) + grads = [ + Tensor(g, device=self.device, requires_grad=False) + if g is not None + else None + for g in ([grads] if len(t0._ctx.parents) == 1 else grads) + ] + for t, g in zip(t0._ctx.parents, grads): + if g is not None and t.requires_grad: + assert ( + g.shape == t.shape + ), f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" + t.grad = g if t.grad is None else (t.grad + g) + del t0._ctx + return self + + # ***** movement mlops ***** + + def reshape(self, shape, *args) -> Tensor: + new_shape = argfix(shape, *args) + return mlops.Reshape.apply( + self, + shape=tuple( + [ + -prod(self.shape) // prod(new_shape) + if s == -1 + else (s if s is not None else self.shape[i]) + for i, s in enumerate(new_shape) + ] + ), + ) + + def expand(self, shape, *args) -> Tensor: + return mlops.Expand.apply( + self, + shape=tuple( + [x if x != -1 else s for s, x in zip(self.shape, argfix(shape, *args))] + ), + ) + + def permute(self, order, *args) -> Tensor: + return mlops.Permute.apply(self, order=argfix(order, *args)) + + def flip(self, axis, *args) -> Tensor: + return mlops.Flip.apply( + self, + axis=[x if x >= 0 else x + len(self.shape) for x in argfix(axis, *args)], + ) + + def shrink(self, arg: Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: + return ( + mlops.Shrink.apply( + self, + arg=tuple( + x if x is not None else (0, s) for x, s in zip(arg, self.shape) + ), + ) + if any(x is not None and x != (0, s) for x, s in zip(arg, self.shape)) + else self + ) + + def pad( + self, arg: Tuple[Optional[Tuple[sint, sint]], ...], value: float = 0.0 + ) -> Tensor: + if all(x is None or x == (0, 0) for x in arg): + return self + ret = mlops.Pad.apply( + self, arg=(narg := tuple(x if x is not None else (0, 0) for x in arg)) + ) + return ( + ret + if 0 == value + else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value) + ) + + # ***** movement hlops ***** + + # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element + # - A slice i:j returns the elements with indices in [i, j) + # - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence + # - Negative values for i and j are taken relative to the end of the sequence + # - Both i and j will be clamped to the range (-N, N], where N in the length of the sequence + # - Indexing with None on a given axis will add a new dimension of size one before that axis + # - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends). + # - For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len). + # - Strides > 1 and < 0 are now allowed!: + # - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional) + # - Idea of stride < 0 support: + # - Do the slice first, flip the axes were slice.step is negative, do slice.step -> -slice.step. Go to steps below. + # - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink): + # - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s]. + # - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s] + # is possible. + # - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s]. + # - Fancy indexing and combined indexing is supported + # - Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t to Tensors passed in -> fancy indexing + # - Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively + # - The first iteration will expand the dim of self while consecutive iterations will reduce the dim + # - There's a special case where a permute is needed at the end: + # - if first Tensor passed in (expand dims) is not at dim 0 + # - and following Tensors does not follow consecutively to the end of fancy indexing's dims + def __getitem__( + self, indices + ) -> ( + Tensor + ): # indices: Union[int, slice, Tensor, None, Ellipsis, List, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]] + def normalize_int(e, i, dim_sz): + if -dim_sz <= e < dim_sz: + return e if e != -1 else dim_sz - 1 + raise IndexError( + f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}" + ) + + # TODO: if indices is a tuple of any sequence, or if indices is a list, it's for advanced indexing + orig_slices = list(indices) if isinstance(indices, tuple) else [indices] + count = defaultdict(list) + for i, v in enumerate(orig_slices): + count[type(v)].append(i) + + # TODO: boolean indices + if ( + num_slices := len(count[int]) + + len(count[slice]) + + len(count[Tensor]) + + len(count[list]) + ) > len(self.shape): + raise IndexError( + f"too many indices for tensor of dimension {len(self.shape)}" + ) + if len(ellipsis_found := count[type(Ellipsis)]) > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + + # replace ellipsis with equivalent number of slice(None) + # TODO: move all slice(None) to the end and transpose non-None to the front + ellipsis_idx = ellipsis_found[0] if ellipsis_found else len(orig_slices) + orig_slices[ellipsis_idx : ellipsis_idx + 1] = [slice(None)] * ( + len(self.shape) - num_slices + ) + + valid_slices = [v for v in orig_slices if v is not None] + valid_slices = [ + v + if isinstance(v, slice) + else slice(y_ := normalize_int(v, i, dim_sz), y_ + 1) + if isinstance(v, int) + else slice(None) + for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape)) + ] + + start, stop, strides = ( + zip(*y) + if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) + else ((), (), ()) + ) + new_slice = tuple( + ((0, 0) if e < s else (s, e)) + if st > 0 + else ((0, 0) if e > s else (e + 1, s + 1)) + for s, e, st in zip(start, stop, strides) + ) + sliced_tensor = self.shrink(new_slice).flip( + axis=[i for i, s in enumerate(strides) if s < 0] + ) + new_shape = sliced_tensor.shape + if any(abs(s) != 1 for s in strides): + strides = tuple(abs(s) for s in strides) + # Pad: add pad at the end: [dim_sz] -> [dim_sz_padded] + padded_tensor = sliced_tensor.pad( + tuple( + (0, s - (dim_sz % s) if dim_sz % s != 0 else 0) + for s, dim_sz in zip(strides, sliced_tensor.shape) + ) + ) + # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s] + reshaped_tensor = padded_tensor.reshape( + flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides)) + ) + new_shape = reshaped_tensor.shape[::2] + # Shrink: do [:, 0] + sliced_tensor = reshaped_tensor.shrink( + tuple(flatten(((0, sh), (0, 1)) for sh in new_shape)) + ) + + final_shape, it_shape, dim, tensors, dim_collapsed = ( + [], + iter(new_shape), + [], + [], + 0, + ) + for i, s in enumerate(orig_slices): + if s is None: + final_shape.append(1) + else: # s is int or slice or Tensor + dim_shape = next(it_shape) + if isinstance(s, list): + s = Tensor(s) + if isinstance(s, int): + dim_collapsed += 1 + else: + assert isinstance( + dim_shape, int + ), f"does not support symbolic shape {dim_shape}" + final_shape.append(dim_shape) + if isinstance(s, Tensor): + tensors.append(s) + dim.append(i - dim_collapsed) + ret = sliced_tensor.reshape(tuple(final_shape)) + + if tensors: # Fancy/tensor indexing + # normalize idx + # TODO: first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm + idx = [ + t.sign().contiguous().__neg__().contiguous().relu() * ret.shape[d] + t + for d, t in zip(dim, tensors) + ] + max_dim = max(i.ndim for i in idx) + # compute sum_dim, arange, and idx + sum_dim = [d if n == 0 else d + max_dim - n for n, d in enumerate(dim)] + arange = [ + Tensor.arange( + ret.shape[d], + dtype=dtypes.int32, + requires_grad=False, + device=self.device, + ).reshape( + *[1] * sd, ret.shape[d], *[1] * (ret.ndim + max_dim - n - sd - 1) + ) + for n, (sd, d) in enumerate(zip(sum_dim, dim)) + ] + first_idx = [ + idx[0].reshape( + *[1] * dim[0], + *[1] * (1 + max_dim - idx[0].ndim), + *idx[0].shape, + *[1] * (ret.ndim - dim[0] - 1), + ) + ] + rest_idx = [ + i.reshape( + *[1] * dim[0], + *[1] * (max_dim - i.ndim), + *i.shape, + *[1] * (ret.ndim - dim[0] - n), + ) + for n, i in enumerate(idx[1:], 1) + ] + idx = first_idx + rest_idx + ret = ret.reshape( + *ret.shape[: sum_dim[0] + 1], + *[1] * max_dim, + *ret.shape[sum_dim[0] + 1 :], + ) + # iteratively fancy index + for a, i, sd in zip(arange, idx, sum_dim): + ret = (a == i).mul(ret).sum(sd) + # special permute case + if ( + dim[0] != 0 + and len(dim) != 1 + and dim != list(range(dim[0], dim[-1] + 1)) + ): + ret_dims = list(range(ret.ndim)) + ret = ret.permute( + ret_dims[dim[0] : dim[0] + max_dim] + + ret_dims[: dim[0]] + + ret_dims[dim[0] + max_dim :] + ) + return ret + + def __setitem__(self, indices, v): + return self.__getitem__(indices).assign(v) + + # NOTE: using slice is discouraged and things should migrate to pad and shrink + def slice( + self, arg: Sequence[Optional[Tuple[int, sint]]], value: float = 0 + ) -> Tensor: + arg_ = tuple([a if a is not None else (0, s) for s, a in zip(self.shape, arg)]) + padding = tuple( + [(max(0, -p[0]), max(0, p[1] - self.shape[i])) for i, p in enumerate(arg_)] + ) + return self.pad(padding, value=value).shrink( + tuple( + [ + (p[0] + padding[i][0], p[1] + padding[i][0]) + for i, p in enumerate(arg_) + ] + ) + ) + + def gather(self: Tensor, idx: Tensor, dim: int) -> Tensor: + assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim" + assert all( + s >= i for s, i in zip(self.shape, idx.shape) + ), "all dim of idx.shape must be smaller than self.shape" + if dim < 0: + dim += self.ndim + idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1) + permarg = list(range(self.ndim)) + permarg = ( + permarg[1:dim] + [permarg[0]] + permarg[dim + 1 :] + [permarg[dim]] + if dim != 0 + else permarg[1:] + [permarg[0]] + ) + return ( + ( + ( + idx + == Tensor.arange( + self.shape[dim], + dtype=dtypes.int32, + requires_grad=False, + device=self.device, + ) + ) + * self.permute(*permarg) + .shrink( + tuple([*[(0, sh) for sh in idx.shape[1:-1]], (0, self.shape[dim])]) + ) + .unsqueeze(0) + ) + .sum(-1) + .transpose(ax1=0, ax2=dim) + ) + + def cat(self, *args: Tensor, dim: int = 0) -> Tensor: + dim = (dim + len(self.shape)) if dim < 0 else dim + assert all( + len(y.shape) == len(self.shape) + and all(y.shape[i] == s for i, s in enumerate(self.shape) if i != dim) + for y in args + ) + catargs = [self, *args] + assert all( + t.shape for t in catargs + ), "zero-dimensional tensor cannot be concatenated" + shapes = [s.shape[dim] for s in catargs] + shape_cumsum = [0, *accumulate(shapes)] + slc: List[List[Tuple[sint, sint]]] = [ + [(0, 0) for _ in self.shape] for _ in catargs + ] + for shp, k, s in zip(shapes, shape_cumsum[:-1], slc): + s[dim] = (k, shape_cumsum[-1] - k - shp) + return reduce( + Tensor.__add__, [arg.pad(tuple(s)) for arg, s in zip(catargs, slc)] + ) + + @staticmethod + def stack(tensors: Sequence[Tensor], dim: int = 0) -> Tensor: + first = tensors[0].unsqueeze(dim) + unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors[1:]] + # checks for shapes and number of dimensions delegated to cat + return first.cat(*unsqueezed_tensors, dim=dim) + + def repeat(self, repeats: Sequence[int]) -> Tensor: + base_shape = (1,) * (len(repeats) - self.ndim) + self.shape + new_shape = [x for b in base_shape for x in [1, b]] + expand_shape = [x for rs in zip(repeats, base_shape) for x in rs] + final_shape = [r * s for r, s in zip(repeats, base_shape)] + return self.reshape(new_shape).expand(expand_shape).reshape(final_shape) + + def chunk(self, num: int, dim: int = 0) -> List[Tensor]: + assert all_int(self.shape), f"does not support symbolic shape {self.shape}" + dim, step = dim + self.ndim if dim < 0 else dim, math.ceil( + self.shape[dim] / num + ) + slice_params = [ + [slice(None)] * dim + [slice(k, k + step)] + for k in range(0, self.shape[dim], step) + ] + return [self[tuple(sl)] for sl in slice_params] + + def squeeze(self, dim: Optional[int] = None) -> Tensor: + if dim is None: + return ( + self + if 1 not in self.shape + else self.reshape(*[size for size in self.shape if size != 1]) + ) + if dim <= 0 and self.ndim == 0: + return self # This is to match PyTorch behavior + if not -self.ndim <= dim < self.ndim: + raise IndexError( + f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})" + ) + if dim < 0: + dim += self.ndim + return ( + self + if self.shape[dim] != 1 + else self.reshape( + *[size for idx, size in enumerate(self.shape) if idx != dim] + ) + ) + + def unsqueeze(self, dim: int) -> Tensor: + if dim < 0: + dim = len(self.shape) + dim + 1 + return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:]) + + # (padding_left, padding_right, padding_top, padding_bottom) + def pad2d( + self, padding: Union[List[int], Tuple[int, ...]], value: float = 0 + ) -> Tensor: + slc = [ + (-p0, s + p1) + for p0, p1, s in zip(padding[::2], padding[1::2], self.shape[::-1]) + ][::-1] + return self.slice( + [(0, s) for s in self.shape[: -(len(padding) // 2)]] + slc, value=value + ) + + @property + def T(self) -> Tensor: + return self.transpose() + + def transpose(self, ax1=1, ax2=0) -> Tensor: + order = list(range(len(self.shape))) + order[ax1], order[ax2] = order[ax2], order[ax1] + return self.permute(order) + + def flatten(self, start_dim=0): + return self.reshape(shape=self.shape[:start_dim] + (-1,)) + + # ***** reduce ops ***** + + def _reduce( + self, + fxn: Type[Function], + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdim=False, + ) -> Tensor: + axis_: List[int] = ( + list(range(len(self.shape))) + if axis is None + else ([axis] if isinstance(axis, int) else list(axis)) + ) + axis_ = [x if x >= 0 else x + len(self.shape) for x in axis_] + shape = tuple(s for i, s in enumerate(self.shape) if i not in axis_) + if 0 in self.shape and 0 not in shape: + return Tensor.full( + tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, + {mlops.Sum: 0, mlops.Max: -float("inf")}[fxn], + ) + ret = fxn.apply( + self, + new_shape=tuple([1 if i in axis_ else s for i, s in enumerate(self.shape)]), + ) + return ret if keepdim else ret.reshape(shape=shape) + + def sum(self, axis=None, keepdim=False): + return self._reduce(mlops.Sum, axis, keepdim) + + def max(self, axis=None, keepdim=False): + return self._reduce(mlops.Max, axis, keepdim) + + def min(self, axis=None, keepdim=False): + return -((-self).max(axis=axis, keepdim=keepdim)) + + def mean(self, axis=None, keepdim=False): + assert all_int(self.shape), "does not support symbolic shape" + out = self.sum(axis=axis, keepdim=keepdim) + return ( + out.mul(prod(out.shape) / prod(self.shape)) if 0 not in self.shape else out + ) + + def std(self, axis=None, keepdim=False, correction=1): + assert all_int(self.shape), "does not support symbolic shape" + square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum( + axis=axis, keepdim=keepdim + ) + return square_sum.div( + prod(self.shape) / prod(square_sum.shape) - correction + ).sqrt() + + def _softmax(self, axis): + m = self - self.max(axis=axis, keepdim=True) + e = m.exp() + return m, e, e.sum(axis=axis, keepdim=True) + + def softmax(self, axis=-1): + _, e, ss = self._softmax(axis) + return e.div(ss) + + def log_softmax(self, axis=-1): + m, _, ss = self._softmax(axis) + return m - ss.log() + + def argmax(self, axis=None, keepdim=False): + if axis is None: + idx = (self == self.max(axis)) * Tensor.arange( + prod(self.shape) - 1, + -1, + -1, + dtype=dtypes.int32, + requires_grad=False, + device=self.device, + ).reshape(self.shape) + return prod(self.shape) - idx.max() - 1 + axis = axis + len(self.shape) if axis < 0 else axis + m = self == self.max(axis=axis, keepdim=True) + idx = m * Tensor.arange( + self.shape[axis] - 1, + -1, + -1, + dtype=dtypes.int32, + requires_grad=False, + device=self.device, + ).reshape(self.shape[axis], *[1] * (self.ndim - axis - 1)) + return self.shape[axis] - idx.max(axis=axis, keepdim=keepdim) - 1 + + def argmin(self, axis=None, keepdim=False): + return (-self).argmax(axis=axis, keepdim=keepdim) + + # ***** processing ops ***** + + def _pool( + self, + k_: Tuple[sint, ...], + stride: Union[Tuple[int, ...], int] = 1, + dilation: Union[Tuple[int, ...], int] = 1, + ) -> Tensor: + assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}" + assert all_int(self.shape) and all_int( + k_ + ), f"does not support symbolic {self.shape=}, {k_=}" + s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_)) + assert len(k_) == len(s_) and len(k_) == len( + d_ + ), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}" + slc_prefix, prefix, i_ = ( + [(0, x) for x in self.shape[0 : -len(k_)]], + self.shape[0 : -len(k_)], + self.shape[-len(k_) :], + ) + if any(k > s for k, s in zip(k_, s_)) or any(d != 1 for d in d_): + o_ = [(i - d * (k - 1) - 1) // s + 1 for i, d, k, s in zip(i_, d_, k_, s_)] + e_ = [ + math.ceil(k * (i + d) / i) for k, i, d in zip(k_, i_, d_) + ] # expands such that we don't need padding + xup = ( + self.reshape(*prefix, *flatten((1, i) for i in i_)) + .expand(*prefix, *flatten((e, i) for e, i in zip(e_, i_))) + .reshape(*prefix, *[e * i for e, i in zip(e_, i_)]) + ) + # slide by dilation + xup = xup.slice( + slc_prefix + [(0, k * (i + d)) for k, i, d in zip(k_, i_, d_)] + ) + xup = xup.reshape( + *prefix, *flatten((k, i + d) for k, i, d in zip(k_, i_, d_)) + ) + xup = xup.slice( + slc_prefix + + flatten(((0, k), (0, o * s)) for k, o, s in zip(k_, o_, s_)) + ) + # handle stride, and permute to move reduce to the end + xup = xup.reshape( + *prefix, *flatten((k, o, s) for k, o, s in zip(k_, o_, s_)) + ) + xup = xup.slice( + slc_prefix + flatten(((0, k), (0, o), (0, 1)) for k, o in zip(k_, o_)) + ) + xup = xup.reshape(*prefix, *flatten((k, o) for k, o in zip(k_, o_))) + return xup.permute( + *range(len(prefix)), + *[len(prefix) + i * 2 + 1 for i in range(len(k_))], + *[len(prefix) + i * 2 for i in range(len(k_))], + ) + # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker + o_ = [(i + (s - k)) // s for i, s, k in zip(i_, s_, k_)] + xup = self.slice(slc_prefix + [(0, o * s) for o, s in zip(o_, s_)]) + xup = xup.reshape(*prefix, *flatten(((o, s) for o, s in zip(o_, s_)))) + xup = xup.slice(slc_prefix + flatten(((0, o), (0, k)) for o, k in zip(o_, k_))) + return xup.permute( + *range(len(prefix)), + *[len(prefix) + i * 2 for i in range(len(k_))], + *[len(prefix) + i * 2 + 1 for i in range(len(k_))], + ) + + # NOTE: these work for more than 2D + def avg_pool2d(self, kernel_size=(2, 2), stride=None, dilation=1): + return self._pool( + make_pair(kernel_size), + stride if stride is not None else kernel_size, + dilation, + ).mean(axis=tuple(range(0 - len(make_pair(kernel_size)), 0))) + + def max_pool2d(self, kernel_size=(2, 2), stride=None, dilation=1): + return self._pool( + make_pair(kernel_size), + stride if stride is not None else kernel_size, + dilation, + ).max(axis=tuple(range(0 - len(make_pair(kernel_size)), 0))) + + def conv_transpose2d( + self, + weight: Tensor, + bias: Optional[Tensor] = None, + groups=1, + stride=1, + dilation=1, + padding=0, + output_padding=0, + ) -> Tensor: + HW, trailing = weight.shape[2:], list(range(3, len(weight.shape) + 1)) + x, w = self, weight.reshape( + groups, weight.shape[0] // groups, weight.shape[1], *weight.shape[2:] + ).permute(0, 2, 1, *trailing).flip(trailing) + stride = make_pair(stride, len(HW)) + if any(s > 1 for s in stride): + x = x.reshape(*x.shape[:2], *flatten((k, 1) for k in x.shape[2:])) + x = x.pad(((0, 0), (0, 0), *flatten(((0, 0), (0, s - 1)) for s in stride))) + x = x.reshape(*x.shape[:2], *[k * s for k, s in zip(x.shape[2::2], stride)]) + x = x.shrink( + ( + (0, x.shape[0]), + (0, x.shape[1]), + *[(0, k - (s - 1)) for k, s in zip(x.shape[2:], stride)], + ) + ) + padding = flatten( + ( + ((k - 1) * d - p, (k - 1) * d - p + op) + for k, d, p, op in reversed( + list( + zip( + HW, + make_pair(dilation, len(HW)), + make_pair(padding, len(HW)), + make_pair(output_padding, len(HW)), + ) + ) + ) + ) + ) + return x.conv2d( + w.reshape(w.shape[0] * w.shape[1], *w.shape[2:]), + groups=groups, + bias=bias, + dilation=dilation, + padding=padding, + ) + + wino = int(getenv("WINO", "0")) + + def conv2d( + self, + weight: Tensor, + bias: Optional[Tensor] = None, + groups=1, + stride=1, + dilation=1, + padding=0, + ) -> Tensor: + (bs, cin_), (cout, cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:] + assert groups * cin == cin_ and len(self.shape) == len( + weight.shape + ), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" + if isinstance(padding, (tuple, list)): + assert len(padding) == 2 * len(HW) or len(padding) == len( + HW + ), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" + padding_ = ( + [padding] * 2 * len(HW) + if isinstance(padding, int) + else ( + padding + if len(padding) == 2 * len(HW) + else [p for p in padding for _ in range(2)][::-1] + ) + ) + + # conv2d is a pooling op (with padding) + x = self.pad2d(padding_)._pool( + HW, stride, dilation + ) # (bs, groups*cin, oy, ox, H, W) + rcout, oyx = cout // groups, x.shape[2 : -len(HW)] + if ( + not all(x == 3 for x in HW) + or stride != 1 + or dilation != 1 + or not Tensor.wino + ): + # normal conv + x = ( + x.reshape(bs, groups, cin, 1, *oyx, *HW) + .expand(bs, groups, cin, rcout, *oyx, *HW) + .permute( + 0, + 1, + 3, + *[4 + i for i in range(len(oyx))], + 2, + *[4 + len(oyx) + i for i in range(len(HW))], + ) + ) + + # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW) + ret = ( + (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)) + .sum([-1 - i for i in range(1 + len(oyx))], keepdim=True) + .reshape(bs, cout, *oyx) + ) + return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW))) + + # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308 + def apply_matrix(mat, t, dim=0): + return ( + t + if dim == len(HW) + else Tensor.stack( + [ + apply_matrix( + mat, + sum(mm * t[j] for j, mm in enumerate(m) if mm), + dim=dim + 1, + ) + for m in mat + ] + ) + ) + + HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles + winograd_Bt = [ + [4, 0, -5, 0, 1, 0], + [0, -4, -4, 1, 1, 0], + [0, 4, -4, -1, 1, 0], + [0, -2, -1, 2, 1, 0], + [0, 2, -1, -2, 1, 0], + [0, 4, 0, -5, 0, 1], + ] + winograd_G = [ + [1 / 4, 0, 0], + [-1 / 6, -1 / 6, -1 / 6], + [-1 / 6, 1 / 6, -1 / 6], + [1 / 24, 1 / 12, 1 / 6], + [1 / 24, -1 / 12, 1 / 6], + [0, 0, 1], + ] + winograd_At = [ + [1, 1, 1, 1, 1, 0], + [0, 1, -1, 2, -2, 0], + [0, 1, 1, 4, 4, 0], + [0, 1, -1, 8, -8, 1], + ] # applying At in pre-order almost doubles compilation time + + # todo: stride == dilation + # use padding to round up to 4x4 output tiles + d = self.pad2d( + sum( + [ + [ + padding_[i * 2], + padding_[i * 2 + 1] + + (-(dim + sum(padding_[i * 2 : (i + 1) * 2]) - 2) % 4), + ] + for i, dim in enumerate(self.shape[-len(HW) :]) + ], + [], + ) + )._pool( + HWI, HWO + ) # (bs, cin_, tyx, HWI) + d = d.permute( + *range(len(d.shape) - len(HW), len(d.shape)), *range(len(d.shape) - len(HW)) + ).contiguous_backward() # move HW to the front: # (HWI, bs, cin_, tyx) + tyx = d.shape[-len(HWI) :] # dim of tiling + + g = weight.permute( + *range(len(weight.shape) - len(HW), len(weight.shape)), + *range(len(weight.shape) - len(HW)), + ) # move HW to the front + + # compute 6x6 winograd tiles: GgGt, BtdB + gfactors = ( + apply_matrix(winograd_G, g) + .contiguous() + .reshape(*HWI, 1, groups, rcout, cin, *([1] * len(tyx))) + ) # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1)) + dfactors = ( + apply_matrix(winograd_Bt, d) + .contiguous() + .reshape(*HWI, bs, groups, 1, cin, *tyx) + ) # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx) + + ret = apply_matrix( + winograd_At, (gfactors * dfactors).sum(axis=-1 - len(HW)) + ) # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx) + + ret = ret.permute( + [ + *range(len(HW), len(ret.shape) - len(HW)), + *[i + o for i in range(len(HW)) for o in [len(ret.shape) - len(HW), 0]], + ] + ) # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO) + ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink( + tuple((0, s) for s in [bs, cout, *oyx]) + ) # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final + + return ( + ( + ret + if bias is None + else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))])) + ) + .contiguous() + .contiguous_backward() + ) + + def dot(self, w: Tensor) -> Tensor: + n1, n2 = len(self.shape), len(w.shape) + assert ( + n1 != 0 and n2 != 0 + ), f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" + assert ( + self.shape[-1] == w.shape[-min(n2, 2)] + ), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" + x = self.reshape( + *self.shape[0:-1], *[1] * min(n1 - 1, n2 - 1, 1), self.shape[-1] + ) + w = w.reshape( + *w.shape[0:-2], *[1] * min(n1 - 1, n2 - 1, 1), *w.shape[-min(n2, 2) :] + ).transpose(-1, -min(n2, 2)) + return (x * w).sum(-1) + + def _cumsum(self, axis: int = 0, _first_zero=False) -> Tensor: + return ( + self.transpose(axis, -1) + .pad2d((self.shape[axis] - int(not _first_zero), 0)) + ._pool((self.shape[axis],)) + .sum(-1) + .transpose(axis, -1) + ) + + def cumsum(self, axis: int = 0) -> Tensor: + # TODO: someday the optimizer will find this on it's own + # for now this is a two stage cumsum + SPLIT = 256 + if self.shape[axis] <= SPLIT * 2: + return self._cumsum(axis) + ret = self.transpose(axis, -1).pad2d( + (round_up(self.shape[axis], SPLIT) - self.shape[axis], 0) + ) + ret = ret.reshape(*ret.shape[0:-1], ret.shape[-1] // SPLIT, SPLIT)._cumsum(-1) + base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1] + base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1]) + + def fix(x: Tensor): + return x.reshape(*ret.shape[0:-2], ret.shape[-2] * ret.shape[-1])[ + ..., -self.shape[axis] : + ].transpose(axis, -1) + + return fix(ret) + fix(base_add) + + @staticmethod + def _tri(r: int, c: int, k: int = 0, **kwargs) -> Tensor: + return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r, c) <= Tensor.arange( + -k, c - k, **kwargs + ).unsqueeze(0).expand(r, c) + + def triu(self, k: int = 0) -> Tensor: + assert all_int(self.shape), f"does not support symbolic shape {self.shape}" + return Tensor._tri( + self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device + ).where(self, Tensor.zeros_like(self)) + + def tril(self, k: int = 0) -> Tensor: + assert all_int(self.shape), f"does not support symbolic shape {self.shape}" + return Tensor._tri( + self.shape[-2], + self.shape[-1], + k=k + 1, + dtype=self.dtype, + device=self.device, + ).where(Tensor.zeros_like(self), self) + + # ***** mlops (unary) ***** + + def neg(self): + return mlops.Neg.apply(self) + + def contiguous(self): + return mlops.Contiguous.apply(self) + + def contiguous_backward(self): + return mlops.ContiguousBackward.apply(self) + + def log(self): + return mlops.Log.apply(self) + + def log2(self): + return mlops.Log.apply(self) / math.log(2) + + def exp(self): + return mlops.Exp.apply(self) + + def exp2(self): + return mlops.Exp.apply(self * math.log(2)) + + def relu(self): + return mlops.Relu.apply(self) + + def sigmoid(self): + return mlops.Sigmoid.apply(self) + + def sin(self): + return mlops.Sin.apply(self) + + def sqrt(self): + return mlops.Sqrt.apply(self) + + def rsqrt(self): + return (1 / self).sqrt() + + def cos(self): + return ((math.pi / 2) - self).sin() + + def tan(self): + return self.sin() / self.cos() + + # ***** math functions (unary) ***** + + def trunc(self: Tensor) -> Tensor: + return self.cast(dtypes.int32).contiguous().cast(self.dtype) + + def ceil(self: Tensor) -> Tensor: + return (self > (b := self.trunc())).where(b + 1, b) + + def floor(self: Tensor) -> Tensor: + return (self < (b := self.trunc())).where(b - 1, b) + + def square(self): + return self * self + + def clip(self, min_, max_): + return self.maximum(min_).minimum(max_) + + def abs(self): + return self.relu() + (-self).relu() + + def sign(self): + return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype) + + def reciprocal(self): + return 1.0 / self + + # ***** activation functions (unary) ***** + + def elu(self, alpha=1.0): + return self.relu() - alpha * (1 - self.exp()).relu() + + def celu(self, alpha=1.0): + return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0) + + def swish(self): + return self * self.sigmoid() + + def silu(self): + return self.swish() # The SiLU function is also known as the swish function. + + def relu6(self): + return self.relu() - (self - 6).relu() + + def hardswish(self): + return self * (self + 3).relu6() * (1 / 6) + + def tanh(self): + return 2.0 * ((2.0 * self).sigmoid()) - 1.0 + + def sinh(self): + return (self.exp() - self.neg().exp()) / 2 + + def cosh(self): + return (self.exp() + self.neg().exp()) / 2 + + def atanh(self): + return ((1 + self) / (1 - self)).log() / 2 + + def asinh(self): + return (self + (self.square() + 1).sqrt()).log() + + def acosh(self): + return (self + (self.square() - 1).sqrt()).log() + + def hardtanh(self, min_val=-1, max_val=1): + return self.clip(min_val, max_val) + + def gelu(self): + return ( + 0.5 + * self + * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh()) + ) + + def quick_gelu(self): + return self * (self * 1.702).sigmoid() + + def leakyrelu(self, neg_slope=0.01): + return self.relu() - (-neg_slope * self).relu() + + def mish(self): + return self * self.softplus().tanh() + + def softplus(self, beta=1): + return (1 / beta) * (1 + (self * beta).exp()).log() + + def softsign(self): + return self / (1 + self.abs()) + + # ***** broadcasted binary mlops ***** + + def _broadcasted( + self, y: Union[Tensor, float], reverse: bool = False + ) -> Tuple[Tensor, Tensor]: + x: Tensor = self + if not isinstance(y, Tensor): + if 0 in x.shape: + return x, x.full_like(y) + y = Tensor( + y, + device=self.device, + requires_grad=False, + dtype=self.dtype + if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType + else dtypes.float32, + ) + if reverse: + x, y = y, x + if (xshape := x.shape) == (yshape := y.shape): + return (x, y) + + shape_delta = len(xshape) - len(yshape) + if shape_delta > 0: + y = y.reshape((1,) * shape_delta + yshape) + elif shape_delta < 0: + x = x.reshape((1,) * -shape_delta + xshape) + if (xshape := x.shape) == (yshape := y.shape): + return (x, y) + + shape_ret = tuple([max(x, y) for x, y in zip(xshape, yshape)]) + if xshape != shape_ret: + x = x.expand(shape_ret) + if yshape != shape_ret: + y = y.expand(shape_ret) + return (x, y) + + def _to_float(self, x: Union[Tensor, float]): + return ( + x.lazydata.base.op.arg + if isinstance(x, Tensor) + and x.lazydata.is_unrealized_contiguous_const() + and not x.requires_grad + and self._broadcasted(x)[0].shape == self.shape + else x + ) + + def add(self, x: Union[Tensor, float], reverse=False) -> Tensor: + x = self._to_float(x) + return ( + mlops.Add.apply(*self._broadcasted(x, reverse)) + if x.__class__ is Tensor or x + else self + ) + + def sub(self, x: Union[Tensor, float], reverse=False) -> Tensor: + x = self._to_float(x) + return ( + mlops.Sub.apply(*self._broadcasted(x, reverse)) + if x.__class__ is Tensor or x + else (-self if reverse else self) + ) + + def mul(self, x: Union[Tensor, float], reverse=False) -> Tensor: + x = self._to_float(x) + if x.__class__ is not Tensor and x == 0.0: + return mlops.Zero.apply(self) + if x.__class__ is not Tensor and x == -1.0: + return -self + return ( + mlops.Mul.apply(*self._broadcasted(x, reverse)) + if x.__class__ is Tensor or x != 1.0 + else self + ) + + def div(self, x: Union[Tensor, float], reverse=False) -> Tensor: + x = self._to_float(x) + return ( + mlops.Div.apply(*self._broadcasted(x, reverse)) + if x.__class__ is Tensor + or reverse + or not x + or not dtypes.is_float(self.dtype) + else self.mul(1 / x) + ) + + def pow(self, x: Union[Tensor, float], reverse=False) -> Tensor: + x = self._to_float(x) + if x.__class__ is not Tensor and not reverse: + # simple pow identities + if x < 0: + return self.reciprocal().pow(-x) + if x == 3.0: + return self * self * self + if x == 2.0: + return self * self + if x == 1.0: + return self + if x == 0.5: + return self.sqrt() + if not isinstance(x, Tensor) and reverse and x > 0: + return self.mul(math.log(x)).exp() + ar = ( + self.abs().log().mul(x).exp() + if not reverse or isinstance(x, Tensor) + else self.mul(math.log(abs(x))).exp() + ) + # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power) + sign = ( + (x * math.pi).cos() + if isinstance(x, Tensor) + else math.cos(x * math.pi) + if not reverse + else (self * math.pi).cos() + ) + # we only need to correct the sign if the base is negative + base_sign = ( + ( + self.sign() + if not reverse + else x.sign() + if isinstance(x, Tensor) + else math.copysign(1, x) + ) + - 1 + ) / -2 + # we need 0 to be positive so we need to correct base_sign when the base is 0 + base_sign = base_sign - ( + 1.5 + * ( + 1 + - ( + self.sign().abs() + if not reverse + else x.sign().abs() + if isinstance(x, Tensor) + else abs(int(bool(x))) + ) + ) + ) + # inject nan if the base is negative and the power is not an integer + to_nan = ( + ((x - x.trunc()) * 1e10).abs().clip(0, 1) + if isinstance(x, Tensor) + else int(bool(x - int(x))) + if not reverse + else ((self - self.trunc()) * 1e10).abs().clip(0, 1) + ) * base_sign + inject_nan = ( + ((((-to_nan) * 2) + 1)).log().add(1) + if isinstance(to_nan, Tensor) + else 1 + if not to_nan + else float("nan") + ) + return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan) + + def matmul(self, x: Tensor, reverse=False) -> Tensor: + return x.dot(self) if reverse else self.dot(x) + + def maximum(self, x: Union[Tensor, float]) -> Tensor: + return ( + (self < x) + .detach() + .where(x, (self > x).detach().where(self, (self + x) / 2)) + ) + + def minimum(self, x: Union[Tensor, float]) -> Tensor: + return -((-self).maximum(-x)) + + def where(self: Tensor, input_: Union[Tensor, float], other: Union[Tensor, float]): + x_, y = self._broadcasted(input_) + x, z = x_._broadcasted(other) + return mlops.Where.apply(x, *y._broadcasted(z)) + + # ***** op wrappers (wasted lines to make the typechecker happy) ***** + + def __neg__(self) -> Tensor: + return self.neg() + + def __add__(self, x) -> Tensor: + return self.add(x) + + def __sub__(self, x) -> Tensor: + return self.sub(x) + + def __mul__(self, x) -> Tensor: + return self.mul(x) + + def __pow__(self, x) -> Tensor: + return self.pow(x) + + def __truediv__(self, x) -> Tensor: + return self.div(x) + + def __matmul__(self, x) -> Tensor: + return self.matmul(x) + + def __radd__(self, x) -> Tensor: + return self.add(x, True) + + def __rsub__(self, x) -> Tensor: + return self.sub(x, True) + + def __rmul__(self, x) -> Tensor: + return self.mul(x, True) + + def __rpow__(self, x) -> Tensor: + return self.pow(x, True) + + def __rtruediv__(self, x) -> Tensor: + return self.div(x, True) + + def __rmatmul__(self, x) -> Tensor: + return self.matmul(x, True) + + def __iadd__(self, x) -> Tensor: + return self.assign(self.add(x)) + + def __isub__(self, x) -> Tensor: + return self.assign(self.sub(x)) + + def __imul__(self, x) -> Tensor: + return self.assign(self.mul(x)) + + def __ipow__(self, x) -> Tensor: + return self.assign(self.pow(x)) + + def __itruediv__(self, x) -> Tensor: + return self.assign(self.div(x)) + + def __imatmul__(self, x) -> Tensor: + return self.assign(self.matmul(x)) + + def __lt__(self, x) -> Tensor: + return mlops.Less.apply(*self._broadcasted(x, False)) + + def __gt__(self, x) -> Tensor: + return mlops.Less.apply(*self._broadcasted(x, True)) + + def __ge__(self, x) -> Tensor: + return 1.0 - (self < x) + + def __le__(self, x) -> Tensor: + return 1.0 - (self > x) + + def __ne__(self, x) -> Tensor: + return (self < x) + (self > x) # type: ignore[override] + + def __eq__(self, x) -> Tensor: + return 1.0 - (self != x) # type: ignore[override] + + # ***** functional nn ops ***** + + def linear(self, weight: Tensor, bias: Optional[Tensor] = None): + x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight) + return x.add(bias) if bias is not None else x + + def sequential(self, ll: List[Callable[[Tensor], Tensor]]): + return reduce(lambda x, f: f(x), ll, self) + + def layernorm(self, axis=-1, eps: float = 1e-5) -> Tensor: + y = self - self.mean(axis, keepdim=True) + return y.mul((y * y).mean(axis, keepdim=True).add(eps).rsqrt()) + + def batchnorm( + self, + weight: Optional[Tensor], + bias: Optional[Tensor], + mean: Tensor, + invstd: Tensor, + ) -> Tensor: + x = self - mean.reshape(shape=[1, -1, 1, 1]) + if weight: + x = x * weight.reshape(shape=[1, -1, 1, 1]) + ret = x.mul( + invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd + ) + return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret + + def dropout(self, p=0.5) -> Tensor: + if not Tensor.training or p == 0: + return self + mask = ( + Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p + ).cast(dtypes.bool) + return self * mask * (1 / (1.0 - p)) + + def scaled_dot_product_attention( + self, + key: Tensor, + value: Tensor, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + ) -> Tensor: + # NOTE: it works if key, value have symbolic shape + assert all_int(self.shape), f"does not support symbolic shape {self.shape}" + if is_causal: + attn_mask = ( + Tensor.ones( + self.shape[-2], + key.shape[-2], + requires_grad=False, + device=self.device, + ) + .tril(0) + .cast(dtypes.bool) + ) + if attn_mask is not None and attn_mask.dtype == dtypes.bool: + attn_mask = (attn_mask == 0).where(-float("inf"), 0) + return ( + self @ key.transpose(-2, -1) / math.sqrt(self.shape[-1]) + attn_mask + ).softmax(-1).dropout(dropout_p) @ value + + def binary_crossentropy(self, y: Tensor) -> Tensor: + return (-y * self.log() - (1 - y) * (1 - self).log()).mean() + + def binary_crossentropy_logits(self, y: Tensor) -> Tensor: + return ( + self.maximum(0) - y * self + (1 + self.abs().__neg__().exp()).log() + ).mean() + + def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor: + # NOTE: self is a logits input + loss_mask = Y != ignore_index + y_counter = ( + Tensor.arange( + self.shape[-1], + dtype=dtypes.int32, + requires_grad=False, + device=self.device, + ) + .unsqueeze(0) + .expand(Y.numel(), self.shape[-1]) + ) + y = ( + (y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) + * loss_mask.reshape(-1, 1) + ).reshape(*Y.shape, self.shape[-1]) + return self.log_softmax().mul(y).sum() / loss_mask.sum() + + # ***** cast ops ***** + + def cast(self, dtype: DType) -> Tensor: + return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self + + def bitcast(self, dtype: DType) -> Tensor: + assert ( + self.dtype.itemsize == dtype.itemsize + ), "can't bitcast mismatched dtype itemsizes" + return ( + mlops.Cast.apply(self, dtype=dtype, bitcast=True) + if self.dtype != dtype + else self + ) + + def float(self) -> Tensor: + return self.cast(dtypes.float32) + + def half(self) -> Tensor: + return self.cast(dtypes.float16) + + # ***** convenience stuff ***** + + @property + def ndim(self) -> int: + return len(self.shape) + + def numel(self) -> sint: + return prod(self.shape) + + def element_size(self) -> int: + return self.dtype.itemsize + + def nbytes(self) -> int: + return self.numel() * self.element_size() + + def is_floating_point(self) -> bool: + return dtypes.is_float(self.dtype) + # register functions to move between devices -for device in Device._buffers: setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device)) +for device in Device._buffers: + setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device)) if IMAGE: - # if IMAGE>0 we install these replacement functions in Tensor (hack!) - from tinygrad.features.image import image_conv2d, image_dot - setattr(Tensor, "conv2d", image_conv2d) - setattr(Tensor, "dot", image_dot) + # if IMAGE>0 we install these replacement functions in Tensor (hack!) + from tinygrad.features.image import image_conv2d, image_dot + + setattr(Tensor, "conv2d", image_conv2d) + setattr(Tensor, "dot", image_dot) + # TODO: remove the custom op and replace with threefry -def custom_random(out:Buffer): - Tensor._seed += 1 - if DEBUG >= 2: print(f"*** rand {out.device} seed {Tensor._seed} size {out.size:<16d} dtype {out.dtype}") - rng = np.random.default_rng(Tensor._seed) - rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False) - out.copyin(rng_np_buffer.data) +def custom_random(out: Buffer): + Tensor._seed += 1 + if DEBUG >= 2: + print( + f"*** rand {out.device} seed {Tensor._seed} size {out.size:<16d} dtype {out.dtype}" + ) + rng = np.random.default_rng(Tensor._seed) + rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype( + dtype=out.dtype.np, copy=False + ) + out.copyin(rng_np_buffer.data)