From 3c8da6bd031b71bb1c78c6f8ecf7735d2c4605ea Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 28 Feb 2023 10:54:46 -0800 Subject: [PATCH] add typing --- .github/workflows/test.yml | 2 +- .pre-commit-config.yaml | 2 +- compile.sh | 3 +- tinygrad/helpers.py | 4 +-- tinygrad/jit.py | 5 +-- tinygrad/lazy.py | 2 +- tinygrad/nn/__init__.py | 6 ++-- tinygrad/tensor.py | 72 +++++++++++++++++++------------------- 8 files changed, 49 insertions(+), 47 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 09b9306f9..107019254 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -40,7 +40,7 @@ jobs: - name: Lint tinygrad with pylint run: pylint tinygrad/ - name: Run mypy - run: mypy tinygrad/ --ignore-missing-imports --check-untyped-defs + run: mypy tinygrad/ --ignore-missing-imports --check-untyped-defs --warn-unreachable testcpu: name: CPU Tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b35accc7d..7871e4593 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,7 @@ repos: pass_filenames: false - id: mypy name: mypy - entry: mypy tinygrad/ --ignore-missing-imports --check-untyped-defs + entry: mypy tinygrad/ --check-untyped-defs --explicit-package-bases --warn-unreachable # --warn-return-any language: system always_run: true pass_filenames: false diff --git a/compile.sh b/compile.sh index 875b5f8c6..0ac471ed1 100755 --- a/compile.sh +++ b/compile.sh @@ -1,5 +1,6 @@ #!/bin/bash -mypyc tinygrad/llops/ops_gpu.py tinygrad/shape/__init__.py tinygrad/ops.py tinygrad/ast.py \ +mypyc --explicit-package-bases \ + tinygrad/llops/ops_gpu.py tinygrad/shape/__init__.py tinygrad/ops.py tinygrad/ast.py \ tinygrad/helpers.py tinygrad/mlops.py tinygrad/nn/__init__.py tinygrad/graph.py tinygrad/lazy.py \ tinygrad/tensor.py tinygrad/llops/ops_cpu.py tinygrad/llops/ops_torch.py tinygrad/nn/optim.py diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 3ed45f417..33f812567 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -1,9 +1,9 @@ from collections import namedtuple import os, math, functools, time -from typing import Tuple, Union +from typing import Tuple, Union, List def dedup(x): return list(dict.fromkeys(x)) # retains list order -def prod(x) -> int: return math.prod(x) +def prod(x:Union[List[int], Tuple[int, ...]]) -> int: return math.prod(x) def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], (tuple, list)) else tuple(x) def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True diff --git a/tinygrad/jit.py b/tinygrad/jit.py index 57c83d375..37eedf9d4 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Tuple, Any, Dict +from typing import Callable, List, Tuple, Any, Dict, cast import itertools from tinygrad.lazy import Device from tinygrad.tensor import Tensor @@ -14,7 +14,8 @@ class TinyJit: def __call__(self, *args, **kwargs) -> Any: if Device.DEFAULT != "GPU": return self.fxn(*args, **kwargs) # only jit on the GPU - input_tensors = {k:v.realize().lazydata.realized._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)} + # NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't + input_tensors = {k:cast(DeviceBuffer, 
v.realize().lazydata.realized)._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)} assert len(input_tensors) != 0, "no inputs to JIT" if self.cnt >= 2: for a,idx in self.input_replace.items(): a._buf = input_tensors[idx] diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index a7c0eb2cb..9a6dcb75d 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -169,7 +169,7 @@ class LazyBuffer: return self.realized @staticmethod - def fromCPU(x, device): return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy())) + def fromCPU(x, device) -> LazyBuffer: return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy())) def toCPU(self): return self.realize().toCPU() def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self) diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index 6d203da21..31b919304 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union, Tuple from tinygrad.tensor import Tensor class BatchNorm2d: @@ -59,7 +59,7 @@ class Linear: return x.linear(self.weight.transpose(), self.bias) class GroupNorm: - def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True): self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps self.weight : Optional[Tensor] = Tensor.ones(num_channels) if affine else None self.bias : Optional[Tensor] = Tensor.zeros(num_channels) if affine else None @@ -74,7 +74,7 @@ class GroupNorm: return x * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1) class LayerNorm: - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True): normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape) self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(normalized_shape))), eps, elementwise_affine self.weight, self.bias = (Tensor.ones(*normalized_shape), Tensor.zeros(*normalized_shape)) if elementwise_affine else (None, None) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d9ddc1ffd..94371bec8 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -140,11 +140,11 @@ class Tensor: def manual_seed(seed=None): Tensor._rng = np.random.default_rng(seed=seed) @staticmethod - def rand(*shape, **kwargs): return Tensor(Tensor._rng.random(size=shape, dtype=np.float32), **kwargs) + def rand(*shape, **kwargs) -> Tensor: return Tensor(Tensor._rng.random(size=shape, dtype=np.float32), **kwargs) # TODO: replace with a transformation from uniform -> gaussian @staticmethod - def randn(*shape, **kwargs): return Tensor(Tensor._rng.standard_normal(size=shape, dtype=np.float32), **kwargs) + def randn(*shape, **kwargs) -> Tensor: return Tensor(Tensor._rng.standard_normal(size=shape, dtype=np.float32), **kwargs) # ***** rng hlops ***** @@ -152,11 +152,11 @@ class Tensor: def uniform(*shape, **kwargs) -> Tensor: return Tensor.rand(*shape, **kwargs) * 2 - 1 @staticmethod - def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs) * (prod(shape)**-0.5) + def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5) # 
https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform @staticmethod - def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs) * ((6/(shape[0]+prod(shape[1:])))**0.5) + def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5) # ***** toposort and backward pass ***** @@ -191,11 +191,11 @@ class Tensor: # ***** movement mlops ***** - def reshape(self, shape, *args): return mlops.Reshape.apply(self, shape=argfix(shape, *args)) - def expand(self, shape, *args): return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args)))) - def permute(self, order, *args): return mlops.Permute.apply(self, order=argfix(order, *args)) - def flip(self, axis, *args): return mlops.Flip.apply(self, axis=argfix(axis, *args)) - def slice(self, arg): return mlops.Slice.apply(self, arg=arg) + def reshape(self, shape, *args) -> Tensor: return mlops.Reshape.apply(self, shape=argfix(shape, *args)) + def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args)))) + def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args)) + def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=argfix(axis, *args)) + def slice(self, arg) -> Tensor: return mlops.Slice.apply(self, arg=arg) # ***** movement hlops ***** @@ -250,7 +250,7 @@ class Tensor: # (padding_left, padding_right, padding_top, padding_bottom) def pad2d(self, padding:Tuple[int, ...]): return self.slice(((0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1]))) # TODO: this is totally not transpose - def transpose(self, order=(1,0)): return self.permute(order=order) + def transpose(self, order=(1,0)) -> Tensor: return self.permute(order=order) def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1])) # ***** reduce ops ***** @@ -331,7 +331,7 @@ class Tensor: ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1)).reshape(bs, cout, oy, ox) return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1)) - def dot(self:Tensor, w:Tensor): + def dot(self, w:Tensor) -> Tensor: # NOTE: we use a 1x1 conv2d to do the matmul. 
mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1) bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2]) cin, cout = w.shape[-2], w.shape[-1] @@ -389,39 +389,39 @@ class Tensor: shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape)) return fxn.apply(x.expand(shape_ret), y.expand(shape_ret)) - def add(self, x, reverse=False): return self._broadcasted(mlops.Add, x, reverse) if isinstance(x, Tensor) or x != 0.0 else self - def sub(self, x, reverse=False): return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self - def mul(self, x, reverse=False): return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self - def pow(self, x, reverse=False): return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self - def div(self, x, reverse=False): return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self - def matmul(self, x:Tensor, reverse=False): return x.dot(self) if reverse else self.dot(x) + def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if isinstance(x, Tensor) or x != 0.0 else self + def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self + def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self + def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self + def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self + def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x) - def maximum(self, x): return self._broadcasted(mlops.Maximum, x) + def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x) def minimum(self, x): return -((-self).maximum(-x)) # ***** binary op wrappers (18 wasted lines to make the typechecker happy) ***** # NOTE: __pow__ and friends are broken in mypyc with the ** operator - def __add__(self, x): return self.add(x) - def __sub__(self, x): return self.sub(x) - def __mul__(self, x): return self.mul(x) - def __pow__(self, x): return self.pow(x) - def __truediv__(self, x): return self.div(x) - def __matmul__(self, x): return self.matmul(x) + def __add__(self, x) -> Tensor: return self.add(x) + def __sub__(self, x) -> Tensor: return self.sub(x) + def __mul__(self, x) -> Tensor: return self.mul(x) + def __pow__(self, x) -> Tensor: return self.pow(x) + def __truediv__(self, x) -> Tensor: return self.div(x) + def __matmul__(self, x) -> Tensor: return self.matmul(x) - def __radd__(self, x): return self.add(x, True) - def __rsub__(self, x): return self.sub(x, True) - def __rmul__(self, x): return self.mul(x, True) - def __rpow__(self, x): return self.pow(x, True) - def __rtruediv__(self, x): return self.div(x, True) - def __rmatmul__(self, x): return self.matmul(x, True) + def __radd__(self, x) -> Tensor: return self.add(x, True) + def __rsub__(self, x) -> Tensor: return self.sub(x, True) + def __rmul__(self, x) -> Tensor: return self.mul(x, True) + def __rpow__(self, x) -> Tensor: return self.pow(x, True) + def __rtruediv__(self, x) -> 
Tensor: return self.div(x, True) + def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True) - def __iadd__(self, x): return self.assign(self.add(x)) - def __isub__(self, x): return self.assign(self.sub(x)) - def __imul__(self, x): return self.assign(self.mul(x)) - def __ipow__(self, x): return self.assign(self.pow(x)) - def __itruediv__(self, x): return self.assign(self.div(x)) - def __imatmul__(self, x): return self.assign(self.matmul(x)) + def __iadd__(self, x) -> Tensor: return self.assign(self.add(x)) + def __isub__(self, x) -> Tensor: return self.assign(self.sub(x)) + def __imul__(self, x) -> Tensor: return self.assign(self.mul(x)) + def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x)) + def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x)) + def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x)) # ***** functional nn ops *****
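
Note on the cast added in tinygrad/jit.py above: mypy only sees the declared type of lazydata.realized (presumably Optional[DeviceBuffer]), not the runtime guarantee that realize() has already filled it in, so the dict comprehension narrows it explicitly with typing.cast. Below is a minimal, self-contained sketch of the same pattern; the class names and attributes are illustrative stand-ins, not tinygrad's real definitions.

from typing import Optional, cast

class DeviceBuffer:
  def __init__(self, buf): self._buf = buf

class Lazy:
  def __init__(self):
    self.realized: Optional[DeviceBuffer] = None   # unrealized until realize() runs
  def realize(self) -> "Lazy":
    if self.realized is None: self.realized = DeviceBuffer([0.0])  # realize() always populates .realized
    return self

lt = Lazy()
# buf = lt.realize().realized._buf                    # mypy: Item "None" of "Optional[DeviceBuffer]" has no attribute "_buf"
buf = cast(DeviceBuffer, lt.realize().realized)._buf  # ok: tell the checker the Optional is populated at this point

cast() is a no-op at runtime; it only records for the type checker what the code already guarantees, which is why the patch adds the cast (and the stricter --check-untyped-defs / --warn-unreachable flags) instead of loosening the declared types.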