op logger + replay (#2021)
* logops * fix dtype printing * needs inf * ops dataset * minor improvements * 12k kernels * opt can compile * graph flops
parent
46f354b49f
commit
16ca8410f8
|
@ -40,3 +40,4 @@ temp
|
|||
.coverage
|
||||
coverage.xml
|
||||
htmlcov
|
||||
outputs_yolov8
|
||||
|
|
|
@ -29,14 +29,14 @@ def spec_unet3d():
|
|||
# 3D UNET
|
||||
from models.unet3d import UNet3D
|
||||
mdl = UNet3D()
|
||||
mdl.load_from_pretrained()
|
||||
#mdl.load_from_pretrained()
|
||||
img = Tensor.randn(1, 1, 128, 128, 128)
|
||||
test_model(mdl, img)
|
||||
|
||||
def spec_rnnt():
|
||||
from models.rnnt import RNNT
|
||||
mdl = RNNT()
|
||||
mdl.load_from_pretrained()
|
||||
#mdl.load_from_pretrained()
|
||||
x = Tensor.randn(220, 1, 240)
|
||||
y = Tensor.randn(1, 220)
|
||||
test_model(mdl, x, y)
|
||||
|
@ -44,7 +44,7 @@ def spec_rnnt():
|
|||
def spec_bert():
|
||||
from models.bert import BertForQuestionAnswering
|
||||
mdl = BertForQuestionAnswering()
|
||||
mdl.load_from_pretrained()
|
||||
#mdl.load_from_pretrained()
|
||||
x = Tensor.randn(1, 384)
|
||||
am = Tensor.randn(1, 384)
|
||||
tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
|
||||
|
@ -53,7 +53,7 @@ def spec_bert():
|
|||
def spec_mrcnn():
|
||||
from models.mask_rcnn import MaskRCNN, ResNet
|
||||
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
||||
mdl.load_from_pretrained()
|
||||
#mdl.load_from_pretrained()
|
||||
x = Tensor.randn(3, 224, 224)
|
||||
test_model(mdl, [x])
|
||||
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
#!/bin/bash
# Build a dataset of tinygrad kernel ASTs: run many workloads with LOGOPS set
# so every scheduled op is appended to the log, then sort/dedupe the result.
export LOGOPS=/tmp/ops
# -f: don't abort the run if the log doesn't exist yet; quote to survive odd paths
rm -f "$LOGOPS"

# generate many kernels
PYTHONPATH="." OPT=2 GPU=1 python3 test/external/external_test_opt.py
PYTHONPATH="." OPT=3 GPU=1 python3 test/external/external_test_opt.py
GPU=1 IMAGE=1 python3 test/test_ops.py
FORWARD_ONLY=1 GPU=1 IMAGE=2 python3 test/test_ops.py
STEPS=3 python3 examples/hlb_cifar10.py
WINO=1 STEPS=3 python3 examples/hlb_cifar10.py
python3 examples/stable_diffusion.py --noshow
python3 examples/llama.py --prompt "hello" --count 5
python3 examples/gpt2.py --count 5
python3 examples/mlperf/model_spec.py
python3 examples/yolov8.py ./test/models/efficientnet/Chicken.jpg
openpilot/go.sh
BIG=1 MPS=1 pytest test/

# sort and uniq
sort -u /tmp/ops > /tmp/sops
ls -lh /tmp/ops /tmp/sops
|
|
@ -0,0 +1,87 @@
|
|||
import sys
import random
from collections import defaultdict
from tqdm import tqdm
from tinygrad.helpers import dedup, ImageDType, getenv
from tinygrad.graph import print_tree  # kept for the commented-out debug print below
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.lazy import var_vals_from_ast
from tinygrad.shape.symbolic import sym_infer
from tinygrad.ops import Device, Compiled, MemBuffer

# stuff needed to unpack a kernel: these names appear in the repr strings eval'd below
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, ConstBuffer
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
inf, nan = float('inf'), float('nan')

if __name__ == "__main__":
  # Replay logged kernel ASTs (one repr per line, produced by LOGOPS) and time them.
  # Use a context manager so the log file handle is closed promptly, not leaked.
  with open(sys.argv[1]) as f:
    ast_strs = dedup(f.read().strip().split("\n"))

  # reduce kernels only
  ast_strs = [x for x in ast_strs if "ReduceOps" in x]

  # the device we are optimizing for
  device: Compiled = Device[Device.DEFAULT]
  print(f"optimizing for {Device.DEFAULT}")

  # random first kernels
  random.seed(1337)
  random.shuffle(ast_strs)
  ast_strs = ast_strs[:1000]

  print(f"loaded {len(ast_strs)} kernels")

  atm = []      # per-kernel best runtime (seconds)
  agflops = []  # per-kernel achieved GFLOPS
  for ast_str in tqdm(ast_strs):
    # SECURITY NOTE: eval of the logged reprs is intentional here -- only replay
    # logs you generated yourself; never feed this script untrusted input.
    ast = eval(ast_str)
    lin = Linearizer(ast)

    # skip image textures
    if any(isinstance(x.dtype, ImageDType) for x in lin.bufs): continue

    # create output/input buffers: one per unique buffer idx,
    # sized for the largest shapetracker view that touches it
    bufsts = defaultdict(list)
    for x in lin.bufs:
      if isinstance(x, MemBuffer):
        bufsts[x.idx].append(x)
    rawbufs = [None]*len(bufsts)
    for k,x in bufsts.items():
      rawbufs[k] = device.buffer(max(y.st.size() for y in x), x[0].dtype)
    assert all(x is not None for x in rawbufs)

    # linearize (capture the shape before and after hand-coded opts for the report)
    preopt = lin.colored_shape()
    lin.hand_coded_optimizations()
    postopt = lin.colored_shape()
    lin.linearize()

    # example var vals: pin every symbolic variable to its minimum
    var_vals = {k:k.min for k in var_vals_from_ast(ast)}

    # time: best of 10 runs (generator, no throwaway list)
    prg = device.to_program(lin)
    tm = min(prg(rawbufs, var_vals, force_wait=True) for _ in range(10))
    atm.append(tm)

    # print
    #print_tree(ast)
    #for u in lin.uops: print(u)
    gflops = sym_infer(lin.info.flops, var_vals)*1e-9/tm
    agflops.append(gflops)
    if tm*1e6 > 100:
      print(f"{len(lin.uops)} uops, {lin.global_size} {lin.local_size}, {tm*1e6:.2f} us {gflops:.2f} GFLOPS", preopt, "->", postopt)

  print(f"all kernels ran in {sum(atm)*1e3:.2f} ms")

  if getenv("SHOW"):
    import matplotlib.pyplot as plt
    #plt.hist(agflops, bins=100)
    #plt.yscale('log')
    plt.scatter(atm, agflops)
    plt.xscale('log')
    plt.show()
|
|
@ -6,7 +6,7 @@ except ImportError:
|
|||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
|
||||
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters
|
||||
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv
|
||||
|
||||
# **** debugging and graphing ****
|
||||
|
||||
|
@ -45,7 +45,9 @@ def str_dtype(dtyp):
|
|||
ret = str(dtyp)[7:]
|
||||
return "" if ret == 'float' else f"\n{ret}"
|
||||
|
||||
logops = open(getenv("LOGOPS", ""),"a") if getenv("LOGOPS", "") else None
|
||||
def log_schedule_item(si: ScheduleItem):
|
||||
if logops and si.ast.op not in LoadOps: logops.write(str(si.ast)+"\n")
|
||||
show_graph = bool(GRAPH)
|
||||
if not DEBUG and not show_graph: return
|
||||
if si.ast.op == LoadOps.CONTIGUOUS: setattr(si.out, 'node_id', nm(si.inputs[0].base))
|
||||
|
|
|
@ -77,7 +77,7 @@ class DType(NamedTuple):
|
|||
name: str
|
||||
np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
|
||||
sz: int = 1
|
||||
def __repr__(self): return f"dtypes.{self.name}"
|
||||
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}"
|
||||
|
||||
# dependent typing?
|
||||
class ImageDType(DType):
|
||||
|
@ -137,6 +137,7 @@ class dtypes:
|
|||
|
||||
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
|
||||
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
|
||||
|
||||
class GlobalCounters:
|
||||
global_ops: ClassVar[int] = 0
|
||||
|
|
|
@ -38,7 +38,8 @@ class Node:
|
|||
def key(self) -> str: return self.render(ctx="DEBUG")
|
||||
@functools.cached_property
|
||||
def hash(self) -> int: return hash(self.key)
|
||||
def __repr__(self): return "<"+self.key+">"
|
||||
def __repr__(self): return self.render(ctx="REPR")
|
||||
def __str__(self): return "<"+self.key+">"
|
||||
def __hash__(self): return self.hash
|
||||
def __bool__(self): return not (self.max == self.min == 0)
|
||||
def __eq__(self, other:object) -> bool:
|
||||
|
@ -315,7 +316,7 @@ def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if
|
|||
def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
|
||||
if isinstance(a, int): return a
|
||||
ret = a.substitute({k:Variable.num(v) for k, v in var_vals.items()})
|
||||
assert isinstance(ret, NumNode)
|
||||
assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
|
||||
return ret.b
|
||||
|
||||
# symbolic int
|
||||
|
@ -323,7 +324,7 @@ sint = Union[Node, int]
|
|||
VariableOrNum = Union[Variable, NumNode]
|
||||
|
||||
render_python: Dict[Type, Callable] = {
|
||||
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}]" if ctx == "DEBUG" else f"{self.expr}",
|
||||
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})" if ctx == "REPR" else f"{self.expr}"),
|
||||
NumNode: lambda self,ops,ctx: f"{self.b}",
|
||||
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
|
||||
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
|
||||
|
|
Loading…
Reference in New Issue