
cleanups before interpreted jit (#2306)

* jit mnist

* InterpretedFlopCounter doesn't rely on Interpreted

* allocator for cpu and torch

* types for exec_ast

* fix type issues

* fix onnx, remove print

* always self.from_underlying
George Hotz 2023-11-14 21:44:25 -08:00 committed by GitHub
parent 91546225f4
commit 4f7b1ac0d2
7 changed files with 72 additions and 48 deletions

@@ -2,31 +2,36 @@ import numpy as np
 from tqdm import trange
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import CI
+from tinygrad.jit import TinyJit
 
 def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
           transform=lambda x: x, target_transform=lambda x: x, noloss=False):
+  @TinyJit
+  def train_step(x, y):
+    # network
+    out = model.forward(x) if hasattr(model, 'forward') else model(x)
+    loss = lossfn(out, y)
+    optim.zero_grad()
+    loss.backward()
+    if noloss: del loss
+    optim.step()
+    if noloss: return (None, None)
+    cat = out.argmax(axis=-1)
+    accuracy = (cat == y).mean()
+    return loss.realize(), accuracy.realize()
+
   with Tensor.train():
     losses, accuracies = [], []
    for i in (t := trange(steps, disable=CI)):
      samp = np.random.randint(0, X_train.shape[0], size=(BS))
      x = Tensor(transform(X_train[samp]), requires_grad=False)
      y = Tensor(target_transform(Y_train[samp]))
-      # network
-      out = model.forward(x) if hasattr(model, 'forward') else model(x)
-      loss = lossfn(out, y)
-      optim.zero_grad()
-      loss.backward()
-      if noloss: del loss
-      optim.step()
+      loss, accuracy = train_step(x, y)
 
       # printing
       if not noloss:
-        cat = out.argmax(axis=-1)
-        accuracy = (cat == y).mean().numpy()
-        loss = loss.detach().numpy()
+        loss, accuracy = loss.numpy(), accuracy.numpy()
         losses.append(loss)
         accuracies.append(accuracy)
         t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
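The jitted train_step above follows the usual TinyJit contract of this era: input shapes stay fixed across calls and every Tensor returned from the jitted function is realized. A minimal sketch of that contract, assuming a Compiled backend (the interpreted backends are what this PR is preparing to JIT) and using a made-up step function:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x:Tensor) -> Tensor:
  # tensors returned from a jitted function must be realized
  return (x * 2 + 1).realize()

for _ in range(5):
  # input shapes must stay fixed across calls; only the buffers change
  out = step(Tensor.randn(4, 4))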

@@ -109,20 +109,20 @@ class dtypes:
   @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
   bool: Final[DType] = DType(0, 1, "bool", np.bool_)
-  float16: Final[DType] = DType(0, 2, "half", np.float16)
+  float16: Final[DType] = DType(8, 2, "half", np.float16)
   half = float16
-  float32: Final[DType] = DType(4, 4, "float", np.float32)
+  float32: Final[DType] = DType(9, 4, "float", np.float32)
   float = float32
-  float64: Final[DType] = DType(0, 8, "double", np.float64)
+  float64: Final[DType] = DType(10, 8, "double", np.float64)
   double = float64
   int8: Final[DType] = DType(0, 1, "char", np.int8)
-  int16: Final[DType] = DType(1, 2, "short", np.int16)
-  int32: Final[DType] = DType(2, 4, "int", np.int32)
-  int64: Final[DType] = DType(3, 8, "long", np.int64)
-  uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
-  uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
-  uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
-  uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)
+  int16: Final[DType] = DType(2, 2, "short", np.int16)
+  int32: Final[DType] = DType(4, 4, "int", np.int32)
+  int64: Final[DType] = DType(6, 8, "long", np.int64)
+  uint8: Final[DType] = DType(1, 1, "unsigned char", np.uint8)
+  uint16: Final[DType] = DType(3, 2, "unsigned short", np.uint16)
+  uint32: Final[DType] = DType(5, 4, "unsigned int", np.uint32)
+  uint64: Final[DType] = DType(7, 8, "unsigned long", np.uint64)
   # NOTE: bfloat16 isn't supported in numpy
   bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
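The renumbering gives each dtype a distinct priority (int8 0, uint8 1, ..., uint64 7, float16 8, float32 9, float64 10), so promotion reduces to a comparison: since DType compares by its first field, the priority, a plain max() picks the wider type (exactly what the FlopCounter below does with max(self.dtype, y.dtype)). A toy illustration with a stand-in class, not the real tinygrad one:

from typing import NamedTuple

class DType(NamedTuple):  # stand-in: compares priority first, like tinygrad's DType
  priority: int
  itemsize: int
  name: str

int8, uint8 = DType(0, 1, "char"), DType(1, 1, "unsigned char")
int32, float32 = DType(4, 4, "int"), DType(9, 4, "float")
assert max(int8, uint8) is uint8       # unsigned now outranks signed at equal size
assert max(int32, float32) is float32  # floats outrank all ints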

@@ -159,13 +159,16 @@ class Interpreted:
     self.method_cache: Dict[LazyOp, Callable] = {}
 
   def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
     if DEBUG >= 3:
       from tinygrad.graph import print_tree
       print_tree(ast)
     tglob: Dict[str, Any] = {"Variable": Variable}
     lines: List[str] = []
     f = self.fxn_for_op
 
     @functools.lru_cache(None)
     def gstr(x:Any, nm=None) -> str:
-      if self != InterpretedFlopCounter and ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
+      if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
        str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
        # TODO: (Variable - Variable) might create NumNode. can we remove it?
        return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
@@ -190,23 +193,23 @@ class Interpreted:
       return ret
     ret = _interpret_ast(ast)
-    src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f"  return {gstr(self.from_underlying, 'from_underlying')}({ret})" if self.from_underlying else f"  return {ret}"])
-    if DEBUG >= 4 and self != InterpretedFlopCounter: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
+    src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f"  return {gstr(self.from_underlying, 'from_underlying')}({ret})"])
+    if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
     exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
     return tglob['run']
 
-  def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, **kwargs):
+  def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
     if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
     ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
-    if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op:
-      ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.fxn_for_op[BufferOps.MEM](ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
-    # TODO: is this used?
-    if output is not None and output.output_buffer is not None:
+    assert ret.dtype == output.dtype, f"{ret.dtype} != {output.dtype}"
+    if output.output_buffer is not None:
       assert output.output_buffer.dtype == ret.dtype
       output.output_buffer._buf = ret._buf
       return output.output_buffer
     return ret
 
 # **************** independent FlopCounter ****************
 
 @dataclass
 class FlopCounter:
   shape: Tuple[int, ...]
@@ -218,16 +221,20 @@ class FlopCounter:
   def consume_flops(self):
     self.flops, ret = 0, self.flops
     return ret
 
-InterpretedFlopCounter = Interpreted(FlopCounter, {
+InterpretedFlopCounter: Dict[Op, Callable] = {
   BufferOps.MEM: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}), BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
   UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
   **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST},
   **{op:lambda self,y: FlopCounter(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps},
   **{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps},
-  TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})})
+  TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})}
 
-@functools.lru_cache(None)
-def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.exec_ast(ast)
+def get_lazyop_info(ast:LazyOp) -> FlopCounter:
+  @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
+  def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
+  return run_ast(ast)
 
 # **************** for Compiled Buffers ****************
@@ -319,7 +326,7 @@ class Compiled:
     assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
     return prg
 
-  def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
+  def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
     # check if we can reuse the output buffer
     # if it's aliased, don't use it
     # NOTE: this is pretty wrong actually, who knows where else this buffer is used?
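With InterpretedFlopCounter reduced to a plain Dict[Op, Callable], get_lazyop_info becomes a memoized bottom-up fold over the AST: evaluate each source node, then feed the results (plus arg, if any) to that op's lambda. A self-contained sketch of the pattern, with toy stand-ins for LazyOp and the op table rather than the real tinygrad classes:

import functools
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Node:  # toy stand-in for LazyOp
  op: str
  src: Tuple["Node", ...] = ()
  arg: object = None

table = {  # toy stand-in for InterpretedFlopCounter
  "CONST": lambda arg: 0,           # materializing a constant costs no flops
  "ADD":   lambda a, b: a + b + 1,  # children's flops plus one for this op
  "MUL":   lambda a, b: a + b + 1,
}

def flops(ast:Node) -> int:
  @functools.lru_cache(None)  # per-call cache, recreated for each new AST (as in the diff)
  def run(n:Node): return table[n.op](*[run(s) for s in n.src], *([n.arg] if n.arg is not None else []))
  return run(ast)

x = Node("CONST", arg=1.0)
assert flops(Node("ADD", (Node("MUL", (x, x)), x))) == 2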

@@ -27,7 +27,7 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
       del si.out.op
       for v in si.out.views: del v.op
     assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
-    assert si.out.realized.dtype == si.out.dtype, "realized dtype is incorrect"
+    assert si.out.realized.dtype == si.out.dtype, f"realized dtype is incorrect, {si.out.realized.dtype} != {si.out.dtype}"
 
 # *** zero op LoadOps ***

@@ -9,9 +9,9 @@ class RawBuffer: # pylint: disable=abstract-method
   def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
     self.size: int = size
     self.dtype: DType = dtype
-    self._buf = buf if buf is not None else (allocator.alloc(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
+    self._buf = buf if buf is not None else (allocator(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
     self._memsz: int = size*dtype.itemsize
-    self._allocator = allocator
+    self._allocator = allocator if allocator and hasattr(allocator, 'free') else None
     self._device = kwargs.get('device', None)
     GlobalCounters.mem_used += self._memsz
   def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
@@ -100,7 +100,7 @@ class LRUAllocator:
       self.buffer_info.pop(buf_to_free)
       self._do_free(buf_to_free)
 
-  def alloc(self, size, dtype, device='0', **kwargs):
+  def __call__(self, size, dtype, device='0', **kwargs): # allocate
     rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None)
     return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs)
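Together, these two hunks make the allocator duck-typed: RawBuffer calls allocator(...) instead of allocator.alloc(...), so any callable works as an allocator, and it only keeps a reference for freeing when the allocator actually has a free method. A sketch of the two shapes an allocator can now take (ToyLRUAllocator is a stand-in, not the real class):

import numpy as np

plain = lambda size, dtype: np.empty(size, dtype)  # allocate-only: nothing to free

class ToyLRUAllocator:
  def __call__(self, size, dtype): return np.empty(size, dtype)  # allocate (possibly from a cache)
  def free(self, buf): pass  # would return buf to the cache for reuse

for alloc in (plain, ToyLRUAllocator()):
  buf = alloc(16, np.float32)                        # identical call site either way
  keep = alloc if hasattr(alloc, 'free') else None   # mirrors the new self._allocator logic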

@@ -11,13 +11,15 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
 base_fxn_for_op: Dict[Op, Callable] = {
   BufferOps.MEM: lambda x: x._buf, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
-  ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
+  ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
   ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
   MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
 }
 
+# TODO: this should be global infrastructure
+def output_type(x, y): return x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
 def match_types(x, y):
-  up = x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
+  up = output_type(x, y)
   return x.astype(up, copy=False), y.astype(up, copy=False)
 
 def einsum_mulacc(einsum, get_strides, expand):
@@ -34,9 +36,9 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
   UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
   UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
+  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
   BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
-  BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
+  BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False), UnaryOps.SQRT: np.sqrt,
   MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
   MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
   MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])),
@@ -45,7 +47,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
 }}
 
 class RawNumpyBuffer(RawBuffer):
-  def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None): super().__init__(size, dtype, buf if buf is not None else np.empty([size], dtype.np))
+  def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None, allocator=lambda size, dtype: np.empty([size], dtype.np)): super().__init__(size, dtype, buf, allocator)
   @classmethod
   def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
   def toCPU(self): return self._buf
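The switch from np.promote_types to the priority-based output_type keeps results in one of the two input dtypes, where numpy's own rules would widen mixed signed/unsigned operands to a third type. A quick illustration of the divergence this avoids (pure numpy, no tinygrad needed):

import numpy as np

print(np.promote_types(np.int32, np.uint32))  # int64: numpy invents a wider type
x, y = np.arange(4, dtype=np.int32), np.arange(4, dtype=np.uint32)
print((x + y).dtype)                          # int64 under plain numpy rules
# with the priority table, uint32 (priority 5) beats int32 (priority 4),
# so output_type(x, y) keeps the result as uint32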

@@ -10,6 +10,12 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if geten
 type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
 inverse_type_map = {v:k for k,v in type_map.items()}
 
+def output_type(x, y): return x.dtype if type_map[x.dtype].priority > type_map[y.dtype].priority else y.dtype
+def match_types(x, y, disallow_bool=False):
+  up = output_type(x, y)
+  if disallow_bool and up == torch.bool: up = torch.float
+  return x.type(up), y.type(up)
+
 def as_strided(x, arg):
   if any(i < 0 for i in arg[1]):
     return torch.as_strided(x.contiguous(), arg[0], tuple(abs(i) for i in arg[1]),
@@ -22,9 +28,13 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).requires_grad_(False).to(device),
   UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
   UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
-  BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)), BinaryOps.SUB: lambda x,y: torch.logical_xor(x, y) if y.dtype is torch.bool else torch.sub(x, y),
+  BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)),
+  BinaryOps.ADD: lambda x,y: torch.add(*match_types(x, y)).type(output_type(x,y)),
+  BinaryOps.SUB: lambda x,y: torch.sub(*match_types(x, y, disallow_bool=True)).type(output_type(x,y)),
+  BinaryOps.MUL: lambda x,y: torch.mul(*match_types(x, y)).type(output_type(x,y)),
+  BinaryOps.DIV: lambda x,y: torch.div(*match_types(x, y)).type(torch.promote_types(x.dtype, y.dtype)),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]), # pylint: disable=E1102
-  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
+  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(output_type(a,b)), lambda x: x.stride(), lambda x,s: x.expand(s)),
   TernaryOps.WHERE: lambda x, y, z: torch.where(x != 0, y, z),
   MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
   MovementOps.EXPAND: lambda x, arg: x.expand(arg), MovementOps.PERMUTE: lambda x, arg: x.permute(arg),
@@ -32,7 +42,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
 }}
 
 class RawTorchBuffer(RawBuffer):
-  def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None): super().__init__(size, dtype, buf if buf is not None else torch.empty([size], device=device, dtype=inverse_type_map[dtype]))
+  def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None, allocator=lambda size, dtype: torch.empty([size], device=device, dtype=inverse_type_map[dtype])): super().__init__(size, dtype, buf, allocator)
   @classmethod
   def fromCPU(cls, x):
     buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
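The disallow_bool escape hatch in match_types exists because torch rejects the - operator on bool tensors (the old table worked around this with logical_xor); upcasting to float first and casting back via output_type preserves the xor-like result. A minimal demonstration of the underlying torch behavior, assuming torch is installed:

import torch

a, b = torch.tensor([True, False]), torch.tensor([True, True])
try:
  a - b                       # RuntimeError: bool subtraction is not supported
except RuntimeError as e:
  print("bool sub rejected:", e)
print(a.float() - b.float())  # tensor([ 0., -1.]) -> [False, True] when cast back to bool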