improve typing
parent
2e1bdc889a
commit
45ce4de6f3
|
@ -1,5 +1,5 @@
|
|||
import numpy as np
|
||||
from typing import Final
|
||||
from typing import ClassVar
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericBufExecAST, base_fxn_for_op
|
||||
|
||||
specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
|
||||
|
@ -10,7 +10,7 @@ specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
|
|||
})
|
||||
|
||||
class CPUBuffer(GenericBufExecAST):
|
||||
fxn_for_op : Final = specialized_fxn_for_op
|
||||
fxn_for_op : ClassVar = specialized_fxn_for_op
|
||||
def __init__(self, lbuf:np.ndarray): self.buf, self.shape = lbuf, tuple(lbuf.shape)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
import hashlib
|
||||
import math
|
||||
import time
|
||||
from typing import Tuple, Union, Dict, Any, List
|
||||
from typing import Tuple, Union, Dict, Any, List, ClassVar
|
||||
from tinygrad.helpers import prod, getenv
|
||||
from tinygrad.shape import ShapeTracker, ZeroView
|
||||
from tinygrad.ops import LazyOp
|
||||
|
@ -68,9 +68,9 @@ def idx_deref(builder, buf, ptr, idx):
|
|||
return builder.load(builder.gep(ptr, [idx], inbounds=True))
|
||||
|
||||
class LLVM:
|
||||
target_machine = None
|
||||
engine = None
|
||||
optimizer = None
|
||||
target_machine : ClassVar[llvm.targets.TargetMachine] = None
|
||||
engine : ClassVar[llvm.executionengine.ExecutionEngine] = None
|
||||
optimizer : ClassVar[llvm.passmanagers.ModulePassManager] = None
|
||||
|
||||
def __init__(self):
|
||||
if LLVM.engine is not None:
|
||||
|
@ -104,7 +104,7 @@ class LLVM:
|
|||
backing_mod.triple = llvm.get_process_triple()
|
||||
LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine)
|
||||
|
||||
def exec(self, module, bufs, op_estimate=0, mem_estimate=0):
|
||||
def exec(self, module:ir.Module, bufs, op_estimate=0, mem_estimate=0):
|
||||
module.triple = llvm.get_process_triple()
|
||||
module.data_layout = self.engine.target_data
|
||||
llvm_ir = str(module)
|
||||
|
@ -146,7 +146,7 @@ class LLVM:
|
|||
|
||||
# TODO: Refactor LLVMBuffer and GPUBuffer into ShapeTrackedBuffer
|
||||
class LLVMBuffer(ExplicitExecAST):
|
||||
op_lookup = {
|
||||
op_lookup : ClassVar = {
|
||||
UnaryOps.NOOP: lambda builder,x: x,
|
||||
UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
|
||||
UnaryOps.RELU: lambda builder,x: builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x, flags=('fast',)), x, ir.Constant(ir.FloatType(), 0)),
|
||||
|
@ -161,7 +161,7 @@ class LLVMBuffer(ExplicitExecAST):
|
|||
BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)),
|
||||
BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType())
|
||||
}
|
||||
start_for_op = {
|
||||
start_for_op : ClassVar = {
|
||||
ReduceOps.SUM: ir.Constant(ir.FloatType(), 0),
|
||||
ReduceOps.MAX: ir.Constant(ir.FloatType(), -math.inf)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from typing import Final
|
||||
from typing import ClassVar
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericBufExecAST, base_fxn_for_op
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
|
@ -11,7 +11,7 @@ specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
|
|||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
|
||||
class TorchBuffer(GenericBufExecAST):
|
||||
fxn_for_op : Final = specialized_fxn_for_op
|
||||
fxn_for_op : ClassVar = specialized_fxn_for_op
|
||||
def __init__(self, lbuf:torch.Tensor): self.buf, self.shape = lbuf, tuple(lbuf.shape)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -67,6 +67,7 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
|
|||
|
||||
# used in CPUBuffer and TorchBuffer
|
||||
class GenericBufExecAST(GenericExecAST): # pylint: disable=abstract-method
|
||||
fxn_for_op : ClassVar
|
||||
# TODO: use generic types here to remove __init__ in specialized classes
|
||||
def __init__(self, lbuf:Any): self.buf, self.shape = lbuf, tuple(lbuf.shape)
|
||||
def contiguous(self): return self.unary_op(UnaryOps.NOOP)
|
||||
|
|
Loading…
Reference in New Issue