# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time
from functools import partialmethod, reduce
from itertools import accumulate, filterfalse
import operator
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
from math import ceil, pi, prod, sqrt, log, cos, copysign
from tinygrad.lazy import Device, LazyBuffer
from tinygrad.ops import LoadOps

# An instantiation of the Function is the Context
class Function:
  def __init__(self, device:str, *tensors:Tensor):
    self.device = device
    self.needs_input_grad = [t.requires_grad for t in tensors]
    self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
    if self.requires_grad: self.parents = tensors

  def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
  def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")

  @classmethod
  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
    ctx = fxn(x[0].device, *x)
    ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
    if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx  # used by autograd engine
    return ret

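# A minimal sketch of how apply() is used (illustrative values; mlops.Relu is a real
# Function subclass from the import below):
#   t = Tensor([-1.0, 2.0])
#   out = mlops.Relu.apply(t)  # builds the context, runs forward on the LazyBuffers,
#                              # and stores it in out._ctx for the backward pass
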
import tinygrad.mlops as mlops

# **** start with two base classes, Tensor and Function ****

class Tensor:
  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
  __deletable__ = ('_ctx',)
  training: ClassVar[bool] = False
  no_grad: ClassVar[bool] = False
  default_type: ClassVar[DType] = dtypes.float32

  def __init__(self, data:Union[int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
    assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
    device = Device.canonicalize(device)
    # tensors have gradients, buffers do not
    self.grad: Optional[Tensor] = None

    # NOTE: this can be in three states. False and None: no gradient, True: gradient
    # None (the default) will be updated to True if it's put in an optimizer
    self.requires_grad: Optional[bool] = requires_grad

    # internal variables used for autograd graph construction
    self._ctx: Optional[Function] = None
    if data.__class__ is LazyBuffer:
      data = cast(LazyBuffer, data)  # NOTE: this is a noop, it makes mypy happy
      assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
      self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
      return

    if isinstance(data, (int, float)):
      self.lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
      return

    if data.__class__ is list:
      data = np.array(data, dtype=(dtype or Tensor.default_type).np)

    if data.__class__ is np.ndarray:
      data = cast(np.ndarray, data)
      if data.size == 1:  # constant fold
        self.lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtypes.from_np(data.dtype), device, data.flat[0]).reshape(data.shape)
      else:
        data = LazyBuffer.fromCPU(data)
        self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
      return

    raise RuntimeError(f"can't create Tensor from {data}")

  def __repr__(self):
    return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"

  # Python has a non moving GC, so this should be okay
  def __hash__(self): return id(self)

  @property
  def device(self) -> str: return self.lazydata.device

  @property
  def shape(self) -> Tuple[int, ...]: return self.lazydata.shape

  @property
  def dtype(self) -> DType: return self.lazydata.dtype

  # ***** data handlers ****

  def realize(self) -> Tensor:
    self.lazydata.realize()
    return self

  def assign(self, x) -> Tensor:
    # TODO: this is a hack for writing to DISK
    if self.device.startswith("DISK"):
      if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
      self.lazydata.realize().realized._copyin(x.numpy())  # type: ignore
      return self
    if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
    assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
    assert not x.requires_grad  # self requires_grad is okay?
    if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
    if self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
    self.lazydata = x.lazydata
    return self

  def detach(self): return Tensor(self.lazydata, device=self.device, requires_grad=False)
  def numpy(self) -> np.ndarray: return self.lazydata.toCPU()

  # TODO: if things are realized this won't work
  def to_(self, device:str):
    assert self.lazydata.realized is None
    self.lazydata.device = device
    if self.grad: self.grad.to_(device)

  def to(self, device:str):
    ret = Tensor(self.lazydata, device)
    if self.grad: ret.grad = self.grad.to(device)
    return ret

  # ***** creation llop entrypoint *****

  @staticmethod
  def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
    return Tensor(LazyBuffer.loadop(op, [sz], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)

  @staticmethod
  def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)

  _seed: int = int(time.time())
  @staticmethod
  def manual_seed(seed=0): Tensor._seed = seed

  @staticmethod
  def rand(*shape, **kwargs):
    Tensor._seed += 1
    return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)

  # ***** creation helper functions *****

  @staticmethod
  def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)

  @staticmethod
  def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)

  @staticmethod
  def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)

  @staticmethod
  def arange(stop, start=0, step=1, **kwargs): return Tensor.full((ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)

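  # Example (illustrative): arange builds the range as a cumsum of `step`s, e.g.
  #   Tensor.arange(5)                    # -> [0., 1., 2., 3., 4.]
  #   Tensor.arange(10, start=2, step=3)  # -> [2., 5., 8.]  (ceil((10-2)/3) = 3 elements)
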
  @staticmethod
  def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
    return Tensor.full(tensor.shape, fill_value=fill_value, dtype=tensor.dtype if dtype is None else dtype, **kwargs)

  @staticmethod
  def zeros_like(tensor, **kwargs): return Tensor.full_like(tensor, 0, **kwargs)

  @staticmethod
  def ones_like(tensor, **kwargs): return Tensor.full_like(tensor, 1, **kwargs)

  @staticmethod
  def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)

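  # How the eye() trick works, sketched for dim=3 (illustrative):
  #   [1] is padded by slice to [1,0,0,0], expanded to 3 rows of [1,0,0,0] (12 values),
  #   flattened and shrunk to the first 9 values [1,0,0,0,1,0,0,0,1],
  #   which reshape to the 3x3 identity matrix.
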
  def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
    cond = (self != 0.0)
    return cond * input_ + (1.0 - cond) * other

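  # Example (illustrative): where() treats self as a boolean mask, e.g.
  #   Tensor([1., 0., 1.]).where(Tensor([1., 2., 3.]), Tensor([10., 20., 30.]))  # -> [1., 20., 3.]
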
  # ***** rng hlops *****

  @staticmethod
  def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
    # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
    src = Tensor.rand(2, *shape, **kwargs)
    return src[0].mul(2*pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)

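  # randn() above computes one half of the Box-Muller transform,
  #   z = cos(2*pi*u1) * sqrt(-2*ln(1 - u2))  with u1, u2 ~ Uniform[0, 1),
  # turning two uniform samples into one standard normal sample.
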
  @staticmethod
  def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low

  @staticmethod
  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5)

  # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
  @staticmethod
  def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)

  # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
  @staticmethod
  def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
    bound = sqrt(3.0) * sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
    return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

  # ***** toposort and backward pass *****
  def deepwalk(self):
    def _deepwalk(node, visited, nodes):
      visited.add(node)
      if node._ctx:
        for i in node._ctx.parents:
          if i not in visited: _deepwalk(i, visited, nodes)
        nodes.append(node)
      return nodes
    return _deepwalk(self, set(), [])

  def backward(self):
    assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape}"

    # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
    # this is "implicit gradient creation"
    self.grad = Tensor(1, device=self.device, requires_grad=False)

    for t0 in reversed(self.deepwalk()):
      if not t0.requires_grad:
        del t0._ctx  # TODO: does it help to delete this here ever?
        continue
      assert (t0.grad is not None)
      grads = t0._ctx.backward(t0.grad.lazydata)
      grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
               for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
      for t, g in zip(t0._ctx.parents, grads):
        if g is not None and t.requires_grad:
          assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
          t.grad = g if t.grad is None else (t.grad + g)
      del t0._ctx

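  # Minimal autograd sketch (illustrative values, assumes a working backend):
  #   x = Tensor([2.0, 3.0], requires_grad=True)
  #   y = (x * x).sum()   # scalar output, so backward() is allowed
  #   y.backward()
  #   x.grad.numpy()      # -> [4., 6.], since d(sum(x^2))/dx = 2x
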
  # ***** movement mlops *****
  def reshape(self, shape, *args) -> Tensor:
    new_shape = argfix(shape, *args)
    assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
  def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
  def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
  def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
  def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any(x != (0,0) for x in arg) else self
  def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self

  # ***** movement hlops *****

  # NOTE: using slice is discouraged and things should migrate to pad and shrink
  def slice(self, arg:Sequence[Optional[Tuple[int, int]]]) -> Tensor:
    arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
    padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
    return self.pad(padding).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))

  # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
  # - A slice i:j returns the elements with indices in [i, j)
  # - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
  # - Negative values for i and j are taken relative to the end of the sequence
  # - Both i and j will be clamped to the range (-N, N], where N is the length of the sequence
  # - Indexing with np.newaxis or None on a given axis will add a new dimension of size one before that axis
  # - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends).
  # - For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len).
  # - Strides > 1 and < 0 are now allowed!:
  #    - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
  #    - Idea of stride < 0 support:
  #        - Do the slice first, flip the axes where slice.step is negative, do slice.step -> -slice.step. Go to steps below.
  #    - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
  #        - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
  #        - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to
  #          [dim_sz_padded // s, s] is possible.
  #        - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
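  # - Worked example of the stride trick (illustrative): x[::2] on an axis of size 5
  #    - Shrink keeps all 5 elements, Pad appends 1 zero -> size 6
  #    - Reshape to (3, 2), then Shrink to the [:, 0] column keeps elements 0, 2, 4
  #    - the final Reshape drops the trailing size-1 axis, giving shape (3,)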
  def __getitem__(self, val):
    def normalize_int(e, i, dim_sz):
      if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
      raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
    val = list(val) if isinstance(val, tuple) else [val]
    if (num_slices := sum(isinstance(v, (slice, int)) for v in val)) > len(self.shape):
      raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
    orig_slices = list(val)
    ellipses_found = [i for i, v in enumerate(val) if v is Ellipsis]
    if len(ellipses_found) > 0:
      if len(ellipses_found) != 1:
        raise IndexError("an index can only have a single ellipsis ('...')")
      ellipsis_idx = ellipses_found[0]
      orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
    else:
      orig_slices += [slice(None)] * (len(self.shape) - num_slices)
    valid_slices = list(filterfalse(lambda x: x is None, orig_slices))
    valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
    start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
    new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides))
    new_shape = tuple(e - s for s, e in new_slice)
    # Shrink
    sliced_tensor = self.shrink(new_slice)
    # Flip
    if (flip_axes := tuple(i for i, s in enumerate(strides) if s < 0)):
      sliced_tensor = sliced_tensor.flip(axis=flip_axes)
    if any(s > 1 or s < 0 for s in strides):
      # normalize if negative strides
      strides = tuple(abs(s) for s in strides)
      def num_zeros(step, dim_sz): return 0 if step == 1 or (y := dim_sz % step) == 0 else (step - y)
      # Pad: add pad at the end: [dim_sz] -> [dim_sz_padded]
      paddings = tuple((0, num_zeros(s, dim_sz)) for s, dim_sz in zip(strides, sliced_tensor.shape))
      padded_tensor = sliced_tensor.pad(paddings)
      # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
      new_shape = reduce(operator.add, [[sh // s, s] for sh, s in zip(padded_tensor.shape, strides)], [])  # type: ignore
      reshaped_tensor = padded_tensor.reshape(new_shape)
      # Shrink: do [:, 0]
      new_shape = new_shape[::2]
      final_slice = reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
      sliced_tensor = reshaped_tensor.shrink(final_slice)
    final_shape = []
    it_shape = iter(new_shape)
    for i in orig_slices:
      if isinstance(i, (int, slice)):
        dim_shape = next(it_shape)
        if isinstance(i, slice): final_shape.append(dim_shape)
      else:  # i is None
        final_shape.append(1)
    return sliced_tensor.reshape(tuple(final_shape))  # Reshape

  def cat(self, *args, dim=0):
    dim = (dim + len(self.shape)) if dim < 0 else dim
    assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
    catargs = [self] + list(args)
    assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
    shape_cumsum = [0, *accumulate([y.shape[dim] for y in catargs])]
    slc = [[(0, s) for s in self.shape] for _ in catargs]
    for s,k in zip(slc, shape_cumsum):
      s[dim] = (-k, shape_cumsum[-1]-k)
    return reduce(Tensor.__add__, [arg.slice(s) for arg,s in zip(catargs, slc)])

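  # Example of the slot trick in cat() (illustrative): concatenating [1,2] and [3,4,5] on dim 0
  # gives slices (0, 5) and (-2, 3); each arg is zero-padded into its own slot of the
  # length-5 output, and the padded tensors are summed.
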
  @staticmethod
  def stack(tensors, dim=0):
    first = tensors[0].unsqueeze(dim)
    unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors[1:]]
    # checks for shapes and number of dimensions delegated to cat
    return first.cat(*unsqueezed_tensors, dim=dim)

  def repeat(self, repeats):
    base_shape = self.shape
    if len(repeats) > self.ndim:
      base_shape = (1,) * (len(repeats) - self.ndim) + base_shape
    new_shape = [x for i in range(len(base_shape)) for x in [1, base_shape[i]]]
    expand_shape = [x for r,s in zip(repeats, base_shape) for x in [r,s]]
    final_shape = [r*s for r,s in zip(repeats, base_shape)]
    return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)

  # TODO: make this nicer with syntactic sugar in slice
  def chunk(self, num, dim):
    slice_params = [[(0, s) for s in self.shape] for _ in range(num)]
    for i,k in enumerate(range(0, self.shape[dim], self.shape[dim]//num)):
      slice_params[i][dim] = (k, min(self.shape[dim], k+self.shape[dim]//num))
    return [self.slice(p) for p in slice_params]

  def unsqueeze(self, dim):
    if dim < 0: dim = len(self.shape) + dim + 1
    return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])

  # (padding_left, padding_right, padding_top, padding_bottom)
  def pad2d(self, padding:Union[List[int], Tuple[int, ...]]):
    slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
    return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc)

  @property
  def T(self) -> Tensor: return self.transpose()
  def transpose(self, ax1=1, ax2=0) -> Tensor:
    order = list(range(len(self.shape)))
    order[ax1], order[ax2] = order[ax2], order[ax1]
    return self.permute(order)
  def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))

  # ***** reduce ops *****

  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False):
    axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if axis.__class__ is int else list(axis))  # type: ignore
    axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
    shape = [self.shape[i] for i in range(len(self.shape)) if i not in axis_]
    ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else self.shape[i] for i in range(len(self.shape))]))
    return ret if keepdim else ret.reshape(shape=shape)

  def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
  def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
  def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))

  def mean(self, axis=None, keepdim=False):
    out = self.sum(axis=axis, keepdim=keepdim)
    return out * (prod(out.shape)/prod(self.shape))
  def std(self, axis=None, keepdim=False, correction=1):
    square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
    return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt()
  def _softmax(self, axis):
    m = self - self.max(axis=axis, keepdim=True)
    e = m.exp()
    return m, e, e.sum(axis=axis, keepdim=True)

  def softmax(self, axis=-1):
    _, e, ss = self._softmax(axis)
    return e.div(ss)

  def log_softmax(self, axis=-1):
    m, _, ss = self._softmax(axis)
    return m - ss.log()

  # ***** processing ops *****

  def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1, _insert_dims=tuple()) -> Tensor:
    assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
    s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
    assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
    slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
    if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
      o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
      e_ = [ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)]  # expands such that we don't need padding
      xup = self.reshape(*prefix, *([1]*len(_insert_dims)), *flatten((1,i) for i in i_)).expand(*prefix, *_insert_dims, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *_insert_dims, *[e*i for e,i in zip(e_, i_)])
      # NOTE: _insert_dims is required because reduces can't be merged (yet)
      prefix += _insert_dims
      slc_prefix += [(0,x) for x in _insert_dims]
      # slide by dilation
      xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
      xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
      xup = xup.slice(slc_prefix + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_)))
      # handle stride, and permute to move reduce to the end
      xup = xup.reshape(*prefix, *flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
      xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))
      xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_)))
      return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))])
    # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
    o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
    xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)])
    xup = xup.reshape(*prefix, *([1]*len(_insert_dims)), *flatten(((o, s) for o,s in zip(o_, s_))))
    if len(_insert_dims):
      xup = xup.expand(*prefix, *_insert_dims, *flatten(((o, s) for o,s in zip(o_, s_))))
      prefix += _insert_dims
      slc_prefix += [(0,x) for x in _insert_dims]
    xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
    return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))])

  # NOTE: these work for more than 2D
  def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
  def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))

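  # Shape example (illustrative): a (1, 1, 4, 4) input with the default 2x2 kernel and
  # matching stride pools into windows of shape (1, 1, 2, 2, 2, 2); the mean/max over the
  # last two axes then gives a (1, 1, 2, 2) output.
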
  def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
    HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
    x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
    stride = make_pair(stride, len(HW))
    if any(s>1 for s in stride):
      x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:]))
      x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
      x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
      x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
    padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
    return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)

  def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
    (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
    assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
    if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}"
    padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1])

    # conv2d is a pooling op (with padding)
    x = self.pad2d(padding_)._pool(HW, stride, dilation)  # (bs, groups*cin, oy, ox, H, W)
    rcout, oyx = cout//groups, x.shape[2:-len(HW)]
    x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])

    # expand the channels with the pool
    # TODO: this reduces the number of kernels, but it's slower!
    #x = self.pad2d(padding_)._pool((H,W), stride, dilation, _insert_dims=(cout//groups,))  # (bs, groups*cin, rcout, oy, ox, H, W)
    #rcout, oy, ox = x.shape[2:5]
    #x = x.reshape(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7)

    # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
    ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
    return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))

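  # Shape example (illustrative): x of shape (1, 3, 32, 32) with weight (16, 3, 3, 3),
  # stride 1 and padding 1 gives oy = ox = (32 + 2*1 - 3) + 1 = 32, so the result is
  # (1, 16, 32, 32); bias, if given, broadcasts over the channel axis.
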
  def dot(self, w:Tensor) -> Tensor:
    n1, n2 = len(self.shape), len(w.shape)
    assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
    x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
    w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
    return (x*w).sum(-1)

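  # Shape example (illustrative): (2, 3) @ (3, 4)
  #   x is reshaped to (2, 1, 3) and w (after the transpose) to (1, 4, 3); the product
  #   broadcasts to (2, 4, 3) and the sum over the last axis leaves (2, 4).
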
  def cumsum(self, axis=0):
    x = self.permute(*(i for i in range(self.ndim) if i != axis), axis)
    return x.reshape(1, 1, -1, self.shape[axis]).conv2d(Tensor.ones(1, 1, 1, self.shape[axis], dtype=self.dtype, device=self.device), padding=(self.shape[axis]-1, 0, 0, 0)).reshape(*x.shape).permute(*range(axis), self.ndim - 1, *range(axis, self.ndim-1))

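  # Example of the cumsum-as-convolution trick (illustrative): for [1., 2., 3.] the input is
  # reshaped to (1, 1, 1, 3) and convolved with a (1, 1, 1, 3) kernel of ones, left-padded
  # by 2, so each output position sums everything up to and including itself: [1., 3., 6.].
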
  # ***** mlops (unary) *****

  def contiguous(self): return mlops.Contiguous.apply(self)
  def log(self): return mlops.Log.apply(self)
  def log2(self): return mlops.Log.apply(self)/log(2)
  def exp(self): return mlops.Exp.apply(self)
  def relu(self): return mlops.Relu.apply(self)
  def sigmoid(self): return mlops.Sigmoid.apply(self)
  def sin(self): return mlops.Sin.apply(self)
  def sqrt(self): return mlops.Sqrt.apply(self)
  def rsqrt(self): return (1/self).sqrt()
  def cos(self): return ((pi/2)-self).sin()
  def tan(self): return self.sin() / self.cos()

  @staticmethod
  def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(c-k, start=-k, **kwargs).unsqueeze(0).expand(r,c)
  def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self))
  def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)

  # ***** math functions (unary) *****
  def ceil(self: Tensor) -> Tensor:
    b = self.cast(dtypes.int32).contiguous().cast(self.dtype)
    return (self > b).where(b+1, b)
  def floor(self: Tensor) -> Tensor:
    b = self.cast(dtypes.int32).contiguous().cast(self.dtype)
    return (self < b).where(b-1, b)

  def __neg__(self): return 0.0-self
  def square(self): return self*self
  def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
  def abs(self): return self.relu() + (-self).relu()
  def sign(self): return self / (self.abs() + 1e-10)
  def reciprocal(self): return 1.0/self

  # ***** activation functions (unary) *****
  def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
  def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
  def swish(self): return self * self.sigmoid()
  def silu(self): return self.swish()  # The SiLU function is also known as the swish function.
  def relu6(self): return self.relu() - (self-6).relu()
  def hardswish(self): return self * (self+3).relu6() * (1/6)
  def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
  def hardtanh(self, min_val=-1, max_val=1): return self.clip(min_val, max_val)
  def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
  def quick_gelu(self): return self * (self * 1.702).sigmoid()
  def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
  def mish(self): return self * self.softplus().tanh()
  def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
  def softsign(self): return self / (1 + self.abs())

  # ***** broadcasted binary mlops *****

  def _broadcasted(self, fxn:Type[Function], other:Union[Tensor, float], reverse:bool=False) -> Tensor:
    dtype = self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32
    x: Tensor = self
    y: Tensor = Tensor(cast(float, other), device=self.device, requires_grad=False, dtype=dtype) if other.__class__ is not Tensor else cast(Tensor, other)
    if reverse: x, y = y, x
    if x.shape == y.shape: return fxn.apply(x, y)

    len_x_shape, len_y_shape = len(x.shape), len(y.shape)
    max_shape = max(len_x_shape, len_y_shape)

    if len_x_shape != max_shape: x = x.reshape((1,) * (max_shape - len_x_shape) + x.shape)
    if len_y_shape != max_shape: y = y.reshape((1,) * (max_shape - len_y_shape) + y.shape)

    shape_ret = tuple([max(x, y) for x, y in zip(x.shape, y.shape)])
    if x.shape != shape_ret: x = x.expand(shape_ret)
    if y.shape != shape_ret: y = y.expand(shape_ret)

    return fxn.apply(x, y)

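  # Broadcasting example (illustrative): adding shapes (3, 1) and (4,)
  #   the (4,) operand is first reshaped to (1, 4), then both sides are expanded to (3, 4)
  #   before the underlying binary mlop runs; python scalars become 0-d tensors the same way.
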
  def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if x.__class__ is Tensor or x else self
  def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if x.__class__ is Tensor or x or reverse else self
  def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if x.__class__ is Tensor or x != 1.0 else self
  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
  def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
    if x.__class__ is not Tensor and not reverse:
      # simple pow identities
      if x < 0: return (1.0/self).pow(-x)
      if x == 2.0: return self*self
      if x == 1.0: return self
      if x == 0.5: return self.sqrt()
    ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(log(abs(x))).exp()
    # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
    sign = (x * pi).cos() if isinstance(x, Tensor) else cos(x * pi) if not reverse else (self * pi).cos()
    # we only need to correct the sign if the base is negative
    base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else copysign(1, x)) - 1) / -2
    return ar.mul(sign * base_sign + (1 - base_sign))
  def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)

  def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)
  def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
  def eq(self, x) -> Tensor: return self._broadcasted(mlops.Equal, x, False)

  # ***** binary op wrappers (18 wasted lines to make the typechecker happy) *****

  # NOTE: __pow__ and friends are broken in mypyc with the ** operator
  def __add__(self, x) -> Tensor: return self.add(x)
  def __sub__(self, x) -> Tensor: return self.sub(x)
  def __mul__(self, x) -> Tensor: return self.mul(x)
  def __pow__(self, x) -> Tensor: return self.pow(x)
  def __truediv__(self, x) -> Tensor: return self.div(x)
  def __matmul__(self, x) -> Tensor: return self.matmul(x)

  def __radd__(self, x) -> Tensor: return self.add(x, True)
  def __rsub__(self, x) -> Tensor: return self.sub(x, True)
  def __rmul__(self, x) -> Tensor: return self.mul(x, True)
  def __rpow__(self, x) -> Tensor: return self.pow(x, True)
  def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
  def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)

  def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
  def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
  def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
  def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
  def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
  def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))

  def __ge__(self, x) -> Tensor: return self.maximum(x).eq(self)
  def __le__(self, x) -> Tensor: return self.maximum(x).eq(x)
  def __lt__(self, x) -> Tensor: return 1.0-(self>=x)
  def __gt__(self, x) -> Tensor: return 1.0-(self<=x)
  def __eq__(self, x) -> Tensor: return self.eq(x)  # type: ignore # mypy thinks this should be a bool
  def __ne__(self, x) -> Tensor: return 1.0-self.eq(x)  # type: ignore

  # ***** functional nn ops *****

  def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
    x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
    return x.add(bias) if bias is not None else x

  def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return reduce(lambda x,f: f(x), ll, self)

  def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
    y = (self - self.mean(axis, keepdim=True))
    return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())

  def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
    x = (self - mean.reshape(shape=[1, -1, 1, 1]))
    if weight: x = x * weight.reshape(shape=[1, -1, 1, 1])
    ret = x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd)
    return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret

  def dropout(self, p=0.5) -> Tensor:
    if not Tensor.training: return self
    mask = (Tensor.rand(*self.shape, requires_grad=False) >= p).cast(dtypes.bool)
    return self * mask * (1/(1.0 - p))

  # ***** cast ops *****

  def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self
  def float(self) -> Tensor: return self.cast(dtypes.float32)
  def half(self) -> Tensor: return self.cast(dtypes.float16)

  # ***** Convenience stuff *****
  @property
  def ndim(self) -> int: return len(self.shape)
  def numel(self) -> int: return prod(self.shape)
  def element_size(self) -> int: return self.dtype.itemsize
  def nbytes(self) -> int: return self.numel() * self.element_size()
  def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)

# register functions to move between devices
for device in Device._buffers:
  setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
  setattr(Tensor, f"{device.lower()}_", partialmethod(Tensor.to_, device))

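# Example (illustrative; which methods exist depends on the backends in Device._buffers):
#   t = Tensor([1.0, 2.0])
#   t.cpu()   # same as t.to("CPU"), returning a new Tensor
#   t.gpu_()  # same as t.to_("GPU"), moving the lazy buffer in place
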
# if IMAGE>0 we install these replacement functions in Tensor (hack!)
from tinygrad.nn.image import image_conv2d, image_dot
if IMAGE:
  setattr(Tensor, "conv2d", image_conv2d)
  setattr(Tensor, "dot", image_dot)