1
0
Fork 0
tinygrab/examples/handcode_resnet50_opt.py

86 lines
3.0 KiB
Python
Raw Permalink Normal View History

from typing import List
from extra.models.resnet import ResNet50
from tinygrad.tensor import Tensor
2023-12-01 18:51:21 -07:00
from tinygrad.ops import LoadOps
from tinygrad.device import Device, Compiled
from tinygrad.codegen.linearizer import Linearizer
2023-12-01 18:51:21 -07:00
from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.helpers import ansilen, DEBUG, getenv
from tinygrad.lazy import vars_from_ast
from tinygrad.shape.symbolic import sym_infer
if __name__ == "__main__":
    # Benchmark every ResNet50 kernel under several optimization strategies
    # (hand-coded, tensor cores, optional beam search) and report the fastest
    # variant of each, plus the total time of the fastest variants.

    def _kernel_summary(i: int, lin: Linearizer, tm: float, gflops: float) -> str:
        """Return the fixed-width report text for kernel ``i``.

        ``tm`` is the measured runtime in seconds (printed as milliseconds) and
        ``gflops`` the achieved throughput. ``lin`` is read for its name and
        launch sizes; the callers below only pass linearized kernels.
        """
        # pad the (possibly colored) name to 37 columns as measured by ansilen
        padded_name = lin.name + " " * (37 - ansilen(lin.name))
        return (
            f"kernel {i:2d} {padded_name} {str(lin.global_size):18s} {str(lin.local_size):12s}"
            f" takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS"
        )

    mdl = ResNet50()
    seen = set()

    # the device we are optimizing for
    device: Compiled = Device[Device.DEFAULT]
    print(f"optimizing for {Device.DEFAULT}")

    # first model run to init the weights, they are saved in seen
    mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen)

    # run model again to get only what changes, these are the kernels of the model
    x = Tensor.empty(64, 3, 224, 224)
    out = mdl(x)
    sched = out.lazydata.schedule(seen)
    # keep only schedule items whose op is not a LoadOp; those are what get benchmarked
    sched = [si for si in sched if si.ast.op not in LoadOps]

    # focus on one kernel (KERNEL=<index>). Reject out-of-range indices up front:
    # slicing would otherwise yield an empty schedule and the summary would divide by zero.
    kernel = getenv("KERNEL", -1)
    if kernel >= 0:
        if kernel >= len(sched):
            raise SystemExit(
                f"KERNEL={kernel} is out of range: the schedule has {len(sched)} kernels"
            )
        sched = sched[kernel : kernel + 1]

    # beam width for the optional beam search; a falsy value skips it
    beam = getenv("BEAM")

    # work with the schedule
    total_tm = 0
    # sum of gflops * tm over the chosen kernels (i.e. GFLOP executed); divided by
    # total_tm at the end to give a time-weighted average throughput
    total_gflop = 0
    for i, si in enumerate(sched):
        rawbufs = bufs_from_lin(Linearizer(si.ast))

        # "linearize" the op into uops in different ways
        lins: List[Linearizer] = []

        # always try hand coded opt
        lin = Linearizer(si.ast, device.linearizer_opts)
        lin.hand_coded_optimizations()
        lins.append(lin)

        # maybe try tensor cores (only kept if they could be applied)
        lin = Linearizer(si.ast, device.linearizer_opts)
        if lin.apply_tensor_cores():
            lins.append(lin)

        # try a beam search
        if beam:
            lin = Linearizer(si.ast, device.linearizer_opts)
            lins.append(
                beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
            )

        # benchmark the programs
        choices = []
        for lin in lins:
            tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
            # evaluate the symbolic FLOP count with every variable at its minimum
            flops = sym_infer(lin.info.flops, {k: k.min for k in vars_from_ast(lin.ast)})
            gflops = flops * 1e-9 / tm
            choices.append((tm, gflops, lin.linearize()))

            # print all kernels
            if DEBUG >= 1:
                print(f" {_kernel_summary(i, lin, tm, gflops)}")

        # pick the fastest variant; min() returns the first of equally fast entries,
        # the same element a stable sort would put first
        best_tm, best_gflops, best_lin = min(choices, key=lambda choice: choice[0])
        # the leading time is the cumulative time before this kernel
        print(f"*** {total_tm*1000:7.2f} ms : {_kernel_summary(i, best_lin, best_tm, best_gflops)}")
        total_tm += best_tm
        total_gflop += best_gflops * best_tm

    avg_gflops = total_gflop / total_tm if total_tm else 0.0
    print(f"******* total {total_tm*1000:.2f} ms, {avg_gflops:6.0f} GFLOPS")