
add typing

pull/618/head
George Hotz 2023-02-28 10:54:46 -08:00
parent 922f96e527
commit 3c8da6bd03
8 changed files with 49 additions and 47 deletions

View File

@@ -40,7 +40,7 @@ jobs:
- name: Lint tinygrad with pylint
run: pylint tinygrad/
- name: Run mypy
run: mypy tinygrad/ --ignore-missing-imports --check-untyped-defs
run: mypy tinygrad/ --ignore-missing-imports --check-untyped-defs --warn-unreachable
testcpu:
name: CPU Tests
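
Note on the new CI flag: --warn-unreachable makes mypy report statements it can prove will never execute. A minimal sketch of the kind of dead branch it catches (hypothetical code, not part of this commit):

def clamp_positive(x: int) -> int:
    if isinstance(x, int):
        return max(x, 0)
    # x is already declared as int, so mypy --warn-unreachable
    # reports the next line as unreachable
    return 0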

View File

@@ -15,7 +15,7 @@ repos:
pass_filenames: false
- id: mypy
name: mypy
entry: mypy tinygrad/ --ignore-missing-imports --check-untyped-defs
entry: mypy tinygrad/ --check-untyped-defs --explicit-package-bases --warn-unreachable # --warn-return-any
language: system
always_run: true
pass_filenames: false

View File

@@ -1,5 +1,6 @@
#!/bin/bash
mypyc tinygrad/llops/ops_gpu.py tinygrad/shape/__init__.py tinygrad/ops.py tinygrad/ast.py \
mypyc --explicit-package-bases \
tinygrad/llops/ops_gpu.py tinygrad/shape/__init__.py tinygrad/ops.py tinygrad/ast.py \
tinygrad/helpers.py tinygrad/mlops.py tinygrad/nn/__init__.py tinygrad/graph.py tinygrad/lazy.py \
tinygrad/tensor.py tinygrad/llops/ops_cpu.py tinygrad/llops/ops_torch.py tinygrad/nn/optim.py

View File

@@ -1,9 +1,9 @@
from collections import namedtuple
import os, math, functools, time
from typing import Tuple, Union
from typing import Tuple, Union, List
def dedup(x): return list(dict.fromkeys(x)) # retains list order
def prod(x) -> int: return math.prod(x)
def prod(x:Union[List[int], Tuple[int, ...]]) -> int: return math.prod(x)
def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], (tuple, list)) else tuple(x)
def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
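
prod is now annotated to accept either a List[int] or a Tuple[int, ...]. A quick usage sketch (illustrative, not from the commit):

from tinygrad.helpers import prod

assert prod((2, 3, 4)) == 24    # Tuple[int, ...], e.g. a tensor shape
assert prod([5, 7]) == 35       # List[int] also matches the new Union annotation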

View File

@@ -1,4 +1,4 @@
from typing import Callable, List, Tuple, Any, Dict
from typing import Callable, List, Tuple, Any, Dict, cast
import itertools
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor
@@ -14,7 +14,8 @@ class TinyJit:
def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT != "GPU": return self.fxn(*args, **kwargs) # only jit on the GPU
input_tensors = {k:v.realize().lazydata.realized._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
# NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
input_tensors = {k:cast(DeviceBuffer, v.realize().lazydata.realized)._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
assert len(input_tensors) != 0, "no inputs to JIT"
if self.cnt >= 2:
for a,idx in self.input_replace.items(): a._buf = input_tensors[idx]
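
The new cast() call only narrows the type for the checker; it performs no runtime check or conversion. A standalone sketch of the pattern (names here are hypothetical, not tinygrad code):

from typing import Optional, cast

class Buffer:
    def __init__(self) -> None:
        self._buf = bytearray(16)

def lookup() -> Optional[Buffer]:
    return Buffer()

# the checker only sees Optional[Buffer]; cast() tells it to treat the value
# as Buffer, mirroring the cast(DeviceBuffer, ...) in the diff above
buf = cast(Buffer, lookup())._buf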

View File

@@ -169,7 +169,7 @@ class LazyBuffer:
return self.realized
@staticmethod
def fromCPU(x, device): return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()))
def fromCPU(x, device) -> LazyBuffer: return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()))
def toCPU(self): return self.realize().toCPU()
def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self)
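
fromCPU now carries an explicit LazyBuffer return annotation. A hedged round-trip sketch, assuming the CPU backend is available and device names are plain strings like "CPU":

import numpy as np
from tinygrad.lazy import LazyBuffer

x = np.arange(6, dtype=np.float32).reshape(2, 3)
lb = LazyBuffer.fromCPU(x, "CPU")           # annotated to return LazyBuffer
np.testing.assert_allclose(lb.toCPU(), x)   # toCPU() realizes the buffer and copies it back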

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union, Tuple
from tinygrad.tensor import Tensor
class BatchNorm2d:
@@ -59,7 +59,7 @@ class Linear:
return x.linear(self.weight.transpose(), self.bias)
class GroupNorm:
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
self.weight : Optional[Tensor] = Tensor.ones(num_channels) if affine else None
self.bias : Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
@@ -74,7 +74,7 @@ class GroupNorm:
return x * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1)
class LayerNorm:
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(normalized_shape))), eps, elementwise_affine
self.weight, self.bias = (Tensor.ones(*normalized_shape), Tensor.zeros(*normalized_shape)) if elementwise_affine else (None, None)
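
The GroupNorm and LayerNorm constructors are now fully annotated; in particular LayerNorm's normalized_shape accepts either an int or a tuple. A small usage sketch (illustrative only):

from tinygrad.nn import GroupNorm, LayerNorm

gn = GroupNorm(num_groups=4, num_channels=16)   # eps and affine keep their defaults
ln_int = LayerNorm(16)                          # an int is promoted to the tuple (16,)
ln_tup = LayerNorm((8, 16))                     # a tuple normalizes over the last two axes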

View File

@@ -140,11 +140,11 @@ class Tensor:
def manual_seed(seed=None): Tensor._rng = np.random.default_rng(seed=seed)
@staticmethod
def rand(*shape, **kwargs): return Tensor(Tensor._rng.random(size=shape, dtype=np.float32), **kwargs)
def rand(*shape, **kwargs) -> Tensor: return Tensor(Tensor._rng.random(size=shape, dtype=np.float32), **kwargs)
# TODO: replace with a transformation from uniform -> gaussian
@staticmethod
def randn(*shape, **kwargs): return Tensor(Tensor._rng.standard_normal(size=shape, dtype=np.float32), **kwargs)
def randn(*shape, **kwargs) -> Tensor: return Tensor(Tensor._rng.standard_normal(size=shape, dtype=np.float32), **kwargs)
# ***** rng hlops *****
@@ -152,11 +152,11 @@ class Tensor:
def uniform(*shape, **kwargs) -> Tensor: return Tensor.rand(*shape, **kwargs) * 2 - 1
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs) * (prod(shape)**-0.5)
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5)
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs) * ((6/(shape[0]+prod(shape[1:])))**0.5)
def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)
# ***** toposort and backward pass *****
@@ -191,11 +191,11 @@ class Tensor:
# ***** movement mlops *****
def reshape(self, shape, *args): return mlops.Reshape.apply(self, shape=argfix(shape, *args))
def expand(self, shape, *args): return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
def permute(self, order, *args): return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args): return mlops.Flip.apply(self, axis=argfix(axis, *args))
def slice(self, arg): return mlops.Slice.apply(self, arg=arg)
def reshape(self, shape, *args) -> Tensor: return mlops.Reshape.apply(self, shape=argfix(shape, *args))
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=argfix(axis, *args))
def slice(self, arg) -> Tensor: return mlops.Slice.apply(self, arg=arg)
# ***** movement hlops *****
@@ -250,7 +250,7 @@ class Tensor:
# (padding_left, padding_right, padding_top, padding_bottom)
def pad2d(self, padding:Tuple[int, ...]): return self.slice(((0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1])))
# TODO: this is totally not transpose
def transpose(self, order=(1,0)): return self.permute(order=order)
def transpose(self, order=(1,0)) -> Tensor: return self.permute(order=order)
def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
# ***** reduce ops *****
@@ -331,7 +331,7 @@ class Tensor:
ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1)).reshape(bs, cout, oy, ox)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
def dot(self:Tensor, w:Tensor):
def dot(self, w:Tensor) -> Tensor:
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
cin, cout = w.shape[-2], w.shape[-1]
@@ -389,39 +389,39 @@ class Tensor:
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
return fxn.apply(x.expand(shape_ret), y.expand(shape_ret))
def add(self, x, reverse=False): return self._broadcasted(mlops.Add, x, reverse) if isinstance(x, Tensor) or x != 0.0 else self
def sub(self, x, reverse=False): return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self
def mul(self, x, reverse=False): return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self
def pow(self, x, reverse=False): return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
def div(self, x, reverse=False): return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
def matmul(self, x:Tensor, reverse=False): return x.dot(self) if reverse else self.dot(x)
def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if isinstance(x, Tensor) or x != 0.0 else self
def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self
def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self
def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
def maximum(self, x): return self._broadcasted(mlops.Maximum, x)
def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)
def minimum(self, x): return -((-self).maximum(-x))
# ***** binary op wrappers (18 wasted lines to make the typechecker happy) *****
# NOTE: __pow__ and friends are broken in mypyc with the ** operator
def __add__(self, x): return self.add(x)
def __sub__(self, x): return self.sub(x)
def __mul__(self, x): return self.mul(x)
def __pow__(self, x): return self.pow(x)
def __truediv__(self, x): return self.div(x)
def __matmul__(self, x): return self.matmul(x)
def __add__(self, x) -> Tensor: return self.add(x)
def __sub__(self, x) -> Tensor: return self.sub(x)
def __mul__(self, x) -> Tensor: return self.mul(x)
def __pow__(self, x) -> Tensor: return self.pow(x)
def __truediv__(self, x) -> Tensor: return self.div(x)
def __matmul__(self, x) -> Tensor: return self.matmul(x)
def __radd__(self, x): return self.add(x, True)
def __rsub__(self, x): return self.sub(x, True)
def __rmul__(self, x): return self.mul(x, True)
def __rpow__(self, x): return self.pow(x, True)
def __rtruediv__(self, x): return self.div(x, True)
def __rmatmul__(self, x): return self.matmul(x, True)
def __radd__(self, x) -> Tensor: return self.add(x, True)
def __rsub__(self, x) -> Tensor: return self.sub(x, True)
def __rmul__(self, x) -> Tensor: return self.mul(x, True)
def __rpow__(self, x) -> Tensor: return self.pow(x, True)
def __rtruediv__(self, x) -> Tensor: return self.div(x, True)
def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
def __iadd__(self, x): return self.assign(self.add(x))
def __isub__(self, x): return self.assign(self.sub(x))
def __imul__(self, x): return self.assign(self.mul(x))
def __ipow__(self, x): return self.assign(self.pow(x))
def __itruediv__(self, x): return self.assign(self.div(x))
def __imatmul__(self, x): return self.assign(self.matmul(x))
def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
# ***** functional nn ops *****
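
The binary-op wrappers and their reverse/in-place variants now all declare a Tensor return type. A short sketch of how scalar-on-the-left arithmetic routes through the reverse wrappers (illustrative usage, assuming a small Tensor t):

from tinygrad.tensor import Tensor

t = Tensor.ones(2, 2)
a = t * 3.0                 # Tensor.__mul__    -> t.mul(3.0)
b = 3.0 * t                 # Tensor.__rmul__   -> t.mul(3.0, reverse=True)
c = t @ Tensor.ones(2, 2)   # Tensor.__matmul__ -> t.dot(...)
t += 1.0                    # Tensor.__iadd__   -> t.assign(t.add(1.0))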