
improve typing

pull/545/head
George Hotz 2023-02-08 12:48:21 -06:00
parent 2e1bdc889a
commit 45ce4de6f3
4 changed files with 12 additions and 11 deletions
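
The commit replaces typing.Final with typing.ClassVar and adds ClassVar annotations to other class-level attributes. As a rough sketch of the distinction (the names below are made up, not from the commit): Final tells a type checker a name will never be reassigned, while ClassVar marks an attribute as belonging to the class rather than to instances, which is the better fit for shared lookup tables like fxn_for_op.

from typing import ClassVar, Dict, Final

MAX_DIMS: Final = 4                      # Final: this constant is never rebound

class Buffer:
  # ClassVar: one table shared by every instance, never shadowed per object
  fxn_for_op: ClassVar[Dict[str, object]] = {"neg": lambda x: -x}

  def __init__(self, buf):
    self.buf = buf                       # instance state stays separate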

View File

@@ -1,5 +1,5 @@
import numpy as np
-from typing import Final
+from typing import ClassVar
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericBufExecAST, base_fxn_for_op
specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
@@ -10,7 +10,7 @@ specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
})
class CPUBuffer(GenericBufExecAST):
-fxn_for_op : Final = specialized_fxn_for_op
+fxn_for_op : ClassVar = specialized_fxn_for_op
def __init__(self, lbuf:np.ndarray): self.buf, self.shape = lbuf, tuple(lbuf.shape)
@staticmethod
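
For context on the table that annotation covers: specialized_fxn_for_op above is built by merging the shared base_fxn_for_op into a backend-specific dict via the d.update(...) or d idiom (dict.update returns None, so the or falls through to the now-merged dict). A tiny standalone sketch of that pattern, with invented op names:

base_fxn_for_op = {"add": lambda x, y: x + y}

# dict.update() returns None, so `or d` evaluates to the merged dict in one expression
specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
  "neg": lambda x: -x,                   # backend-specific entry
})

assert set(specialized_fxn_for_op) == {"add", "neg"}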

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import hashlib
import math
import time
-from typing import Tuple, Union, Dict, Any, List
+from typing import Tuple, Union, Dict, Any, List, ClassVar
from tinygrad.helpers import prod, getenv
from tinygrad.shape import ShapeTracker, ZeroView
from tinygrad.ops import LazyOp
@@ -68,9 +68,9 @@ def idx_deref(builder, buf, ptr, idx):
return builder.load(builder.gep(ptr, [idx], inbounds=True))
class LLVM:
-target_machine = None
-engine = None
-optimizer = None
+target_machine : ClassVar[llvm.targets.TargetMachine] = None
+engine : ClassVar[llvm.executionengine.ExecutionEngine] = None
+optimizer : ClassVar[llvm.passmanagers.ModulePassManager] = None
def __init__(self):
if LLVM.engine is not None:
@@ -104,7 +104,7 @@ class LLVM:
backing_mod.triple = llvm.get_process_triple()
LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine)
-def exec(self, module, bufs, op_estimate=0, mem_estimate=0):
+def exec(self, module:ir.Module, bufs, op_estimate=0, mem_estimate=0):
module.triple = llvm.get_process_triple()
module.data_layout = self.engine.target_data
llvm_ir = str(module)
@@ -146,7 +146,7 @@ class LLVM:
# TODO: Refactor LLVMBuffer and GPUBuffer into ShapeTrackedBuffer
class LLVMBuffer(ExplicitExecAST):
-op_lookup = {
+op_lookup : ClassVar = {
UnaryOps.NOOP: lambda builder,x: x,
UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
UnaryOps.RELU: lambda builder,x: builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x, flags=('fast',)), x, ir.Constant(ir.FloatType(), 0)),
@@ -161,7 +161,7 @@ class LLVMBuffer(ExplicitExecAST):
BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)),
BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType())
}
-start_for_op = {
+start_for_op : ClassVar = {
ReduceOps.SUM: ir.Constant(ir.FloatType(), 0),
ReduceOps.MAX: ir.Constant(ir.FloatType(), -math.inf)
}

View File

@@ -1,5 +1,5 @@
import torch
-from typing import Final
+from typing import ClassVar
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericBufExecAST, base_fxn_for_op
from tinygrad.helpers import getenv
@@ -11,7 +11,7 @@ specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
class TorchBuffer(GenericBufExecAST):
-fxn_for_op : Final = specialized_fxn_for_op
+fxn_for_op : ClassVar = specialized_fxn_for_op
def __init__(self, lbuf:torch.Tensor): self.buf, self.shape = lbuf, tuple(lbuf.shape)
@staticmethod
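
The Torch backend chooses its device once at import time: CUDA if available, otherwise MPS when the MPS environment variable is set, otherwise CPU. A hedged sketch of how such a module-level device constant is typically consumed (to_device below is a hypothetical helper, not part of this diff):

import os
import torch

# same selection logic as above, with getenv reduced to a plain environment read
device = torch.device("cuda:0" if torch.cuda.is_available()
                      else ("mps" if int(os.getenv("MPS", "0")) else "cpu"))

def to_device(x):                        # hypothetical helper, not from the commit
  return torch.tensor(x).to(device)      # data lands on the selected device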

View File

@@ -67,6 +67,7 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
# used in CPUBuffer and TorchBuffer
class GenericBufExecAST(GenericExecAST): # pylint: disable=abstract-method
+fxn_for_op : ClassVar
# TODO: use generic types here to remove __init__ in specialized classes
def __init__(self, lbuf:Any): self.buf, self.shape = lbuf, tuple(lbuf.shape)
def contiguous(self): return self.unary_op(UnaryOps.NOOP)
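
GenericBufExecAST now only declares fxn_for_op : ClassVar and leaves the value to each backend, while the TODO hints at using generic types so that subclasses no longer need their own __init__. One possible direction for that TODO, purely as a sketch and not what this commit implements:

from typing import Any, ClassVar, Dict, Generic, Protocol, Tuple, TypeVar

class HasShape(Protocol):
  shape: Tuple[int, ...]

T = TypeVar("T", bound=HasShape)         # np.ndarray, torch.Tensor, ... all expose .shape

class GenericBufExecAST(Generic[T]):
  fxn_for_op: ClassVar[Dict[Any, Any]]   # each backend supplies its own op table

  def __init__(self, lbuf: T):
    self.buf, self.shape = lbuf, tuple(lbuf.shape)   # one shared constructor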