refactor ops_cpu and ops_torch to not share code
parent
ee18420c13
commit
1029deccb1
|
@ -1,5 +1,5 @@
|
|||
#!/bin/bash
|
||||
mypyc tinygrad/llops/ops_gpu.py tinygrad/shape/__init__.py tinygrad/ops.py tinygrad/ast.py \
|
||||
tinygrad/helpers.py tinygrad/mlops.py tinygrad/nn/__init__.py tinygrad/graph.py tinygrad/lazy.py \
|
||||
tinygrad/tensor.py
|
||||
tinygrad/tensor.py tinygrad/llops/ops_cpu.py tinygrad/llops/ops_torch.py
|
||||
|
||||
|
|
|
@ -1,36 +1,22 @@
|
|||
import operator
|
||||
import numpy as np
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, GenericExecAST
|
||||
from tinygrad.helpers import shape_to_axis
|
||||
from typing import Final
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericBufExecAST, base_fxn_for_op
|
||||
|
||||
base_fxn_for_op = {
|
||||
UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
|
||||
BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
|
||||
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
|
||||
ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
|
||||
MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
|
||||
}
|
||||
|
||||
class CPUBuffer(GenericExecAST):
|
||||
fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
|
||||
UnaryOps.RELU: lambda x: np.maximum(x, 0), UnaryOps.EXP: lambda x: np.exp(x), UnaryOps.LOG: lambda x: np.log(x), BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32),
|
||||
MovementOps.FLIP: lambda x, axis: np.flip(x, axis), MovementOps.PERMUTE: lambda x, order: x.transpose(order),
|
||||
MovementOps.PAD: lambda x, padding: np.pad(x, padding), MovementOps.EXPAND: lambda x, new_shape: np.broadcast_to(x, new_shape),
|
||||
MovementOps.STRIDED: lambda x, arg: np.lib.stride_tricks.as_strided(x.ravel().reshape(x.shape), shape=[y[0] for y in arg], strides=[y[1]*x.dtype.itemsize for y in arg])
|
||||
})
|
||||
specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
|
||||
UnaryOps.RELU: lambda x: np.maximum(x, 0), UnaryOps.EXP: lambda x: np.exp(x), UnaryOps.LOG: lambda x: np.log(x), BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32),
|
||||
MovementOps.FLIP: lambda x, axis: np.flip(x, axis), MovementOps.PERMUTE: lambda x, order: x.transpose(order),
|
||||
MovementOps.PAD: lambda x, padding: np.pad(x, padding), MovementOps.EXPAND: lambda x, new_shape: np.broadcast_to(x, new_shape),
|
||||
MovementOps.STRIDED: lambda x, arg: np.lib.stride_tricks.as_strided(x.ravel().reshape(x.shape), shape=[y[0] for y in arg], strides=[y[1]*x.dtype.itemsize for y in arg])
|
||||
})
|
||||
|
||||
class CPUBuffer(GenericBufExecAST):
|
||||
fxn_for_op : Final = specialized_fxn_for_op
|
||||
def __init__(self, lbuf:np.ndarray): self.buf, self.shape = lbuf, tuple(lbuf.shape)
|
||||
|
||||
@staticmethod
|
||||
def fromCPU(x): return CPUBuffer(x)
|
||||
def toCPU(x): return x.buf
|
||||
|
||||
def contiguous(x): return x.unary_op(UnaryOps.NOOP)
|
||||
def unary_op(x, op): return type(x)(x.fxn_for_op[op](x.buf))
|
||||
def binary_op(x, op, y): return type(x)(x.fxn_for_op[op](x.buf, y.buf))
|
||||
def reduce_op(x, op, new_shape): return type(x)(x.fxn_for_op[op](x.buf, new_shape))
|
||||
def movement_op(x, op, arg=None): return type(x)(x.fxn_for_op[op](x.buf, arg)) if op in x.fxn_for_op else type(x)(getattr(x.buf, op.name.lower())(arg))
|
||||
|
||||
def processing_op(x,op,w,C):
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
assert C.px == 0 and C.px_ == 0 and C.py == 0 and C.py_ == 0, "padding in conv is not supported"
|
||||
|
|
|
@ -1,24 +1,23 @@
|
|||
import torch
|
||||
from tinygrad.llops.ops_cpu import base_fxn_for_op, CPUBuffer # type: ignore
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST
|
||||
from typing import Final
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericBufExecAST, base_fxn_for_op
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
|
||||
class TorchBuffer(GenericExecAST):
|
||||
fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
|
||||
UnaryOps.RELU: lambda x: x.relu(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
|
||||
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
|
||||
MovementOps.STRIDED: lambda x, arg: x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg])
|
||||
})
|
||||
specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
|
||||
UnaryOps.RELU: lambda x: x.relu(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
|
||||
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
|
||||
MovementOps.STRIDED: lambda x, arg: x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg])
|
||||
})
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
|
||||
class TorchBuffer(GenericBufExecAST):
|
||||
fxn_for_op : Final = specialized_fxn_for_op
|
||||
def __init__(self, lbuf:torch.Tensor): self.buf, self.shape = lbuf, tuple(lbuf.shape)
|
||||
|
||||
@staticmethod
|
||||
def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False).to(device))
|
||||
def toCPU(x): return x.buf.cpu().numpy()
|
||||
|
||||
contiguous, unary_op, binary_op, reduce_op, movement_op = CPUBuffer.contiguous, CPUBuffer.unary_op, CPUBuffer.binary_op, CPUBuffer.reduce_op, CPUBuffer.movement_op
|
||||
|
||||
SUPPORTS_SIMPLE_PADDING = True
|
||||
def processing_op(x,op,w,C):
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
|
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
from enum import Enum, auto
|
||||
from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar
|
||||
import functools, operator
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.helpers import prod, shape_to_axis
|
||||
from tinygrad.shape import ShapeTracker
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
|
@ -45,7 +45,6 @@ class DeviceBuffer:
|
|||
def exec_ast(cls, ast:LazyOp): raise NotImplementedError("must be implemented")
|
||||
|
||||
# extend this if you don't have an exec_ast function
|
||||
# used in CPUBuffer and TorchBuffer
|
||||
class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
|
||||
@classmethod
|
||||
def exec_ast(cls, ast:LazyOp, preprocess=lambda x: x):
|
||||
|
@ -66,6 +65,22 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
|
|||
raise TypeError("unknown op")
|
||||
return ret
|
||||
|
||||
# used in CPUBuffer and TorchBuffer
|
||||
class GenericBufExecAST(GenericExecAST): # pylint: disable=abstract-method
|
||||
def contiguous(self): return self.unary_op(UnaryOps.NOOP)
|
||||
def unary_op(self, op): return type(self)(self.fxn_for_op[op](self.buf))
|
||||
def binary_op(self, op, y): return type(self)(self.fxn_for_op[op](self.buf, y.buf))
|
||||
def reduce_op(self, op, new_shape): return type(self)(self.fxn_for_op[op](self.buf, new_shape))
|
||||
def movement_op(self, op, arg=None): return type(self)(self.fxn_for_op[op](self.buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self.buf, op.name.lower())(arg))
|
||||
|
||||
base_fxn_for_op = {
|
||||
UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
|
||||
BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
|
||||
ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
|
||||
ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
|
||||
MovementOps.SHRINK: lambda x, arg: x[tuple(slice(p[0], p[1], None) for p in arg)],
|
||||
}
|
||||
|
||||
class GlobalCounters:
|
||||
global_ops : ClassVar[int] = 0
|
||||
global_mem : ClassVar[int] = 0
|
||||
|
|
Loading…
Reference in New Issue