cleanups before interpreted jit (#2306)
* jit mnist
* InterpretedFlopCounter doesn't rely on Interpreted
* allocator for cpu and torch
* types for exec_ast
* fix type issues
* fix onnx, remove print
* always self.from_underlying

pull/1971/head
parent 91546225f4
commit 4f7b1ac0d2
|
@ -2,31 +2,36 @@ import numpy as np
|
|||
from tqdm import trange
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
|
||||
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
|
||||
transform=lambda x: x, target_transform=lambda x: x, noloss=False):
|
||||
|
||||
@TinyJit
|
||||
def train_step(x, y):
|
||||
# network
|
||||
out = model.forward(x) if hasattr(model, 'forward') else model(x)
|
||||
loss = lossfn(out, y)
|
||||
optim.zero_grad()
|
||||
loss.backward()
|
||||
if noloss: del loss
|
||||
optim.step()
|
||||
if noloss: return (None, None)
|
||||
cat = out.argmax(axis=-1)
|
||||
accuracy = (cat == y).mean()
|
||||
return loss.realize(), accuracy.realize()
|
||||
|
||||
with Tensor.train():
|
||||
losses, accuracies = [], []
|
||||
for i in (t := trange(steps, disable=CI)):
|
||||
samp = np.random.randint(0, X_train.shape[0], size=(BS))
|
||||
x = Tensor(transform(X_train[samp]), requires_grad=False)
|
||||
y = Tensor(target_transform(Y_train[samp]))
|
||||
|
||||
# network
|
||||
out = model.forward(x) if hasattr(model, 'forward') else model(x)
|
||||
|
||||
loss = lossfn(out, y)
|
||||
optim.zero_grad()
|
||||
loss.backward()
|
||||
if noloss: del loss
|
||||
optim.step()
|
||||
|
||||
loss, accuracy = train_step(x, y)
|
||||
# printing
|
||||
if not noloss:
|
||||
cat = out.argmax(axis=-1)
|
||||
accuracy = (cat == y).mean().numpy()
|
||||
|
||||
loss = loss.detach().numpy()
|
||||
loss, accuracy = loss.numpy(), accuracy.numpy()
|
||||
losses.append(loss)
|
||||
accuracies.append(accuracy)
|
||||
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
|
||||
|
|
|
@ -109,20 +109,20 @@ class dtypes:
|
|||
@staticmethod
|
||||
def fields() -> Dict[str, DType]: return DTYPES_DICT
|
||||
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
|
||||
float16: Final[DType] = DType(0, 2, "half", np.float16)
|
||||
float16: Final[DType] = DType(8, 2, "half", np.float16)
|
||||
half = float16
|
||||
float32: Final[DType] = DType(4, 4, "float", np.float32)
|
||||
float32: Final[DType] = DType(9, 4, "float", np.float32)
|
||||
float = float32
|
||||
float64: Final[DType] = DType(0, 8, "double", np.float64)
|
||||
float64: Final[DType] = DType(10, 8, "double", np.float64)
|
||||
double = float64
|
||||
int8: Final[DType] = DType(0, 1, "char", np.int8)
|
||||
int16: Final[DType] = DType(1, 2, "short", np.int16)
|
||||
int32: Final[DType] = DType(2, 4, "int", np.int32)
|
||||
int64: Final[DType] = DType(3, 8, "long", np.int64)
|
||||
uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
|
||||
uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
|
||||
uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
|
||||
uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)
|
||||
int16: Final[DType] = DType(2, 2, "short", np.int16)
|
||||
int32: Final[DType] = DType(4, 4, "int", np.int32)
|
||||
int64: Final[DType] = DType(6, 8, "long", np.int64)
|
||||
uint8: Final[DType] = DType(1, 1, "unsigned char", np.uint8)
|
||||
uint16: Final[DType] = DType(3, 2, "unsigned short", np.uint16)
|
||||
uint32: Final[DType] = DType(5, 4, "unsigned int", np.uint32)
|
||||
uint64: Final[DType] = DType(7, 8, "unsigned long", np.uint64)
|
||||
|
||||
# NOTE: bfloat16 isn't supported in numpy
|
||||
bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
|
||||
|
|
|
@ -159,13 +159,16 @@ class Interpreted:
|
|||
self.method_cache: Dict[LazyOp, Callable] = {}
|
||||
|
||||
def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
|
||||
if DEBUG >= 3:
|
||||
from tinygrad.graph import print_tree
|
||||
print_tree(ast)
|
||||
tglob: Dict[str, Any] = {"Variable": Variable}
|
||||
lines: List[str] = []
|
||||
f = self.fxn_for_op
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def gstr(x:Any, nm=None) -> str:
|
||||
if self != InterpretedFlopCounter and ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
|
||||
if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
|
||||
str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
|
||||
# TODO: (Variable - Variable) might create NumNode. can we remove it?
|
||||
return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
|
||||
|
@ -190,23 +193,23 @@ class Interpreted:
|
|||
return ret
|
||||
|
||||
ret = _interpret_ast(ast)
|
||||
src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(self.from_underlying, 'from_underlying')}({ret})" if self.from_underlying else f" return {ret}"])
|
||||
if DEBUG >= 4 and self != InterpretedFlopCounter: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
|
||||
src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(self.from_underlying, 'from_underlying')}({ret})"])
|
||||
if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
|
||||
exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
|
||||
return tglob['run']
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, **kwargs):
|
||||
def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
|
||||
if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
|
||||
ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
|
||||
if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op:
|
||||
ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.fxn_for_op[BufferOps.MEM](ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
|
||||
# TODO: is this used?
|
||||
if output is not None and output.output_buffer is not None:
|
||||
assert ret.dtype == output.dtype, f"{ret.dtype} != {output.dtype}"
|
||||
if output.output_buffer is not None:
|
||||
assert output.output_buffer.dtype == ret.dtype
|
||||
output.output_buffer._buf = ret._buf
|
||||
return output.output_buffer
|
||||
return ret
|
||||
|
||||
# **************** independent FlopCounter ****************
|
||||
|
||||
@dataclass
|
||||
class FlopCounter:
|
||||
shape: Tuple[int, ...]
|
||||
|
@ -218,16 +221,20 @@ class FlopCounter:
|
|||
def consume_flops(self):
|
||||
self.flops, ret = 0, self.flops
|
||||
return ret
|
||||
InterpretedFlopCounter = Interpreted(FlopCounter, {
|
||||
|
||||
InterpretedFlopCounter: Dict[Op, Callable] = {
|
||||
BufferOps.MEM: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.size()}), BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
|
||||
UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
|
||||
**{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST},
|
||||
**{op:lambda self,y: FlopCounter(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps},
|
||||
**{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps},
|
||||
TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})})
|
||||
TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})}
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.exec_ast(ast)
|
||||
def get_lazyop_info(ast:LazyOp) -> FlopCounter:
|
||||
@functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs
|
||||
def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
|
||||
return run_ast(ast)
|
||||
|
||||
# **************** for Compiled Buffers ****************
|
||||
|
||||
|
@ -319,7 +326,7 @@ class Compiled:
|
|||
assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
|
||||
return prg
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
|
||||
def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
|
||||
# check if we can reuse the output buffer
|
||||
# if it's aliased, don't use it
|
||||
# NOTE: this is pretty wrong actually, who knows where else this buffer is used?
|
||||
|
|
|
@ -27,7 +27,7 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
|
|||
del si.out.op
|
||||
for v in si.out.views: del v.op
|
||||
assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
|
||||
assert si.out.realized.dtype == si.out.dtype, "realized dtype is incorrect"
|
||||
assert si.out.realized.dtype == si.out.dtype, f"realized dtype is incorrect, {si.out.realized.dtype} != {si.out.dtype}"
|
||||
|
||||
# *** zero op LoadOps ***
|
||||
|
||||
|
|
|
@ -9,9 +9,9 @@ class RawBuffer: # pylint: disable=abstract-method
|
|||
def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
|
||||
self.size: int = size
|
||||
self.dtype: DType = dtype
|
||||
self._buf = buf if buf is not None else (allocator.alloc(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
|
||||
self._buf = buf if buf is not None else (allocator(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
|
||||
self._memsz: int = size*dtype.itemsize
|
||||
self._allocator = allocator
|
||||
self._allocator = allocator if allocator and hasattr(allocator, 'free') else None
|
||||
self._device = kwargs.get('device', None)
|
||||
GlobalCounters.mem_used += self._memsz
|
||||
def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
|
||||
|
@ -100,7 +100,7 @@ class LRUAllocator:
|
|||
self.buffer_info.pop(buf_to_free)
|
||||
self._do_free(buf_to_free)
|
||||
|
||||
def alloc(self, size, dtype, device='0', **kwargs):
|
||||
def __call__(self, size, dtype, device='0', **kwargs): # allocate
|
||||
rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None)
|
||||
return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs)
|
||||
|
||||
|
|
|
@ -11,13 +11,15 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
|
|||
|
||||
base_fxn_for_op: Dict[Op, Callable] = {
|
||||
BufferOps.MEM: lambda x: x._buf, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
|
||||
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
|
||||
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
|
||||
ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
|
||||
MovementOps.RESHAPE: lambda x, arg: x.reshape(arg), MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
|
||||
}
|
||||
|
||||
# TODO: this should be global infrastructure
|
||||
def output_type(x, y): return x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
|
||||
def match_types(x, y):
|
||||
up = x.dtype if dtypes.from_np(x.dtype).priority > dtypes.from_np(y.dtype).priority else y.dtype
|
||||
up = output_type(x, y)
|
||||
return x.astype(up, copy=False), y.astype(up, copy=False)
|
||||
|
||||
def einsum_mulacc(einsum, get_strides, expand):
|
||||
|
@ -34,9 +36,9 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
|||
BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
|
||||
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin,
|
||||
UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False), UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
|
||||
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
|
||||
BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
|
||||
BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
|
||||
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
|
||||
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False), UnaryOps.SQRT: np.sqrt,
|
||||
MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
|
||||
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
|
||||
MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])),
|
||||
|
@ -45,7 +47,7 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
|||
}}
|
||||
|
||||
class RawNumpyBuffer(RawBuffer):
|
||||
def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None): super().__init__(size, dtype, buf if buf is not None else np.empty([size], dtype.np))
|
||||
def __init__(self, size:int, dtype:DType, buf:Optional[np.ndarray]=None, allocator=lambda size, dtype: np.empty([size], dtype.np)): super().__init__(size, dtype, buf, allocator)
|
||||
@classmethod
|
||||
def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
|
||||
def toCPU(self): return self._buf
|
||||
|
|
|
@ -10,6 +10,12 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if geten
|
|||
type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
|
||||
inverse_type_map = {v:k for k,v in type_map.items()}
|
||||
|
||||
def output_type(x, y): return x.dtype if type_map[x.dtype].priority > type_map[y.dtype].priority else y.dtype
|
||||
def match_types(x, y, disallow_bool=False):
|
||||
up = output_type(x, y)
|
||||
if disallow_bool and up == torch.bool: up = torch.float
|
||||
return x.type(up), y.type(up)
|
||||
|
||||
def as_strided(x, arg):
|
||||
if any(i < 0 for i in arg[1]):
|
||||
return torch.as_strided(x.contiguous(), arg[0], tuple(abs(i) for i in arg[1]),
|
||||
|
@ -22,9 +28,13 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
|||
BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=dtype.np)).requires_grad_(False).to(device),
|
||||
UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.SQRT: lambda x: x.sqrt(), UnaryOps.EXP2: lambda x: x.exp2(), UnaryOps.LOG2: lambda x: x.log2(), UnaryOps.SIN: torch.sin,
|
||||
UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(next(k for k,v in type_map.items() if v==y[0])), UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
|
||||
BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)), BinaryOps.SUB: lambda x,y: torch.logical_xor(x, y) if y.dtype is torch.bool else torch.sub(x, y),
|
||||
BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)),
|
||||
BinaryOps.ADD: lambda x,y: torch.add(*match_types(x, y)).type(output_type(x,y)),
|
||||
BinaryOps.SUB: lambda x,y: torch.sub(*match_types(x, y, disallow_bool=True)).type(output_type(x,y)),
|
||||
BinaryOps.MUL: lambda x,y: torch.mul(*match_types(x, y)).type(output_type(x,y)),
|
||||
BinaryOps.DIV: lambda x,y: torch.div(*match_types(x, y)).type(torch.promote_types(x.dtype, y.dtype)),
|
||||
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]), # pylint: disable=E1102
|
||||
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(torch.promote_types(a.dtype, b.dtype)), lambda x: x.stride(), lambda x,s: x.expand(s)),
|
||||
TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(output_type(a,b)), lambda x: x.stride(), lambda x,s: x.expand(s)),
|
||||
TernaryOps.WHERE: lambda x, y, z: torch.where(x != 0, y, z),
|
||||
MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
|
||||
MovementOps.EXPAND: lambda x, arg: x.expand(arg), MovementOps.PERMUTE: lambda x, arg: x.permute(arg),
|
||||
|
@ -32,7 +42,7 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
|||
}}
|
||||
|
||||
class RawTorchBuffer(RawBuffer):
|
||||
def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None): super().__init__(size, dtype, buf if buf is not None else torch.empty([size], device=device, dtype=inverse_type_map[dtype]))
|
||||
def __init__(self, size:int, dtype:DType, buf:Optional[torch.Tensor]=None, allocator=lambda size, dtype: torch.empty([size], device=device, dtype=inverse_type_map[dtype])): super().__init__(size, dtype, buf, allocator)
|
||||
@classmethod
|
||||
def fromCPU(cls, x):
|
||||
buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
|
||||
|
|
Loading…
Reference in New Issue