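# Source-page metadata captured with this file (commented out so the module parses):
# 1
# 0
# Fork 0
# tinygrab/examples/handcode_resnet50_opt.py
#
# 86 lines
# 3.0 KiB
# Python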
from typing import List
from extra.models.resnet import ResNet50
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.device import Device, Compiled
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.helpers import ansilen, DEBUG, getenv
from tinygrad.lazy import vars_from_ast
from tinygrad.shape.symbolic import sym_infer
def _format_kernel(i: int, lin: "Linearizer", tm: float, gflops: float) -> str:
    """Return the one-line report for kernel ``i`` shared by both progress prints.

    ``tm`` is in seconds (it is printed multiplied by 1000 as ms). The kernel
    name is padded to 37 columns as measured by ``ansilen`` (presumably so
    ANSI color codes in the name don't count toward the width).
    """
    return (
        f"kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} "
        f"{str(lin.global_size):18s} {str(lin.local_size):12s} "
        f"takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS"
    )


def _candidate_linearizers(ast, rawbufs, device: "Compiled") -> List["Linearizer"]:
    """Build the linearizers to benchmark for one scheduled kernel ``ast``.

    Always includes the hand-coded optimizations. It adds a tensor-core
    variant when ``apply_tensor_cores()`` returns truthy. It adds a
    beam-search result when the ``BEAM`` env var is set; ``BEAM`` is passed
    straight to ``beam_search`` (presumably the beam width), and
    ``BEAM_ESTIMATE`` defaults to 1.
    """
    lins: List["Linearizer"] = []
    # always try hand coded opt
    lin = Linearizer(ast, device.linearizer_opts)
    lin.hand_coded_optimizations()
    lins.append(lin)
    # maybe try tensor cores
    lin = Linearizer(ast, device.linearizer_opts)
    if lin.apply_tensor_cores():
        lins.append(lin)
    # try a beam search
    if getenv("BEAM"):
        lin = Linearizer(ast, device.linearizer_opts)
        lin = beam_search(lin, rawbufs, getenv("BEAM"), bool(getenv("BEAM_ESTIMATE", 1)))
        lins.append(lin)
    return lins


def _benchmark(i: int, lins: List["Linearizer"], rawbufs) -> List[tuple]:
    """Time every candidate for kernel ``i`` and return ``(tm, gflops, linearized)`` tuples.

    When ``DEBUG >= 1``, every candidate's report line is printed as well.
    """
    choices = []
    for lin in lins:
        tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
        # symbolic FLOP counts are evaluated with every variable at its minimum
        gflops = sym_infer(lin.info.flops, {k: k.min for k in vars_from_ast(lin.ast)}) * 1e-9 / tm
        # NOTE(review): assumes linearize() returns the linearizer itself, since the
        # winner's .name/.global_size/.local_size are read from this slot later
        choices.append((tm, gflops, lin.linearize()))
        # print all kernels
        if DEBUG >= 1:
            print(" " + _format_kernel(i, lin, tm, gflops))
    return choices


def main() -> None:
    """Benchmark candidate kernel optimizations for ResNet50 on a 64x3x224x224 input.

    For each non-LoadOps kernel in the schedule, time every candidate
    linearizer, print the fastest one with the running total time, and finish
    with the total time and the time-weighted average GFLOPS.

    Env vars: ``KERNEL`` (>= 0) restricts the run to that single schedule
    index; ``BEAM`` / ``BEAM_ESTIMATE`` enable and tune beam search;
    ``DEBUG >= 1`` prints every candidate, not just the winner.
    """
    mdl = ResNet50()
    seen = set()
    # the device we are optimizing for
    device: Compiled = Device[Device.DEFAULT]
    print(f"optimizing for {Device.DEFAULT}")
    # first model run to init the weights, they are saved in seen
    mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen)
    # run model again to get only what changes, these are the kernels of the model
    out = mdl(Tensor.empty(64, 3, 224, 224))
    sched = [si for si in out.lazydata.schedule(seen) if si.ast.op not in LoadOps]
    # focus on one kernel
    kernel = getenv("KERNEL", -1)
    if kernel >= 0:
        sched = sched[kernel : kernel + 1]
    # work with the schedule
    total_tm = 0
    running_gflops = 0
    for i, si in enumerate(sched):
        rawbufs = bufs_from_lin(Linearizer(si.ast))
        # "linearize" the op into uops in different ways, then benchmark them
        choices = _benchmark(i, _candidate_linearizers(si.ast, rawbufs, device), rawbufs)
        # fastest wins; min keeps the first candidate on ties, like sorted(...)[0]
        tm, gflops, lin = min(choices, key=lambda c: c[0])
        print(f"*** {total_tm*1000:7.2f} ms : {_format_kernel(i, lin, tm, gflops)}")
        total_tm += tm
        running_gflops += gflops * tm
    # an empty schedule (e.g. KERNEL past the end) would otherwise divide by zero
    avg_gflops = running_gflops / total_tm if total_tm else 0
    print(f"******* total {total_tm*1000:.2f} ms, {avg_gflops:6.0f} GFLOPS")


if __name__ == "__main__":
    main()