1
0
Fork 0

fix handcode_resnet50_opt.py (#2558)

pull/2555/head^2
chenyu 2023-12-01 20:51:21 -05:00 committed by GitHub
parent 86fbd413f3
commit 05a5357dd9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 5 deletions

View File

@ -1,9 +1,10 @@
from typing import List
from extra.models.resnet import ResNet50
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps, Device, Compiled
from tinygrad.ops import LoadOps
from tinygrad.device import Device, Compiled
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import time_linearizer, beam_search
from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.helpers import ansilen, DEBUG, getenv
from tinygrad.lazy import vars_from_ast
from tinygrad.shape.symbolic import sym_infer
@ -33,9 +34,7 @@ if __name__ == "__main__":
total_tm = 0
running_gflops = 0
for i,si in enumerate(sched):
# create output/input buffers (NOTE: bufs_from_lin is slower, so we don't use it. TODO: fix)
rawbufs = [device.buffer(si.out.st.size(), si.out.dtype)] + [device.buffer(x.st.size(), x.dtype) for x in si.inputs]
#rawbufs = bufs_from_lin(lin)
rawbufs = bufs_from_lin(Linearizer(si.ast))
# "linearize" the op into uops in different ways
lins:List[Linearizer] = []