1
0
Fork 0

external_test_onnx_backend

pull/594/head
George Hotz 2023-02-23 21:55:07 -08:00
parent edaf878339
commit d8b6f241f1
2 changed files with 48 additions and 9 deletions

View File

@ -127,13 +127,7 @@ def get_run_onnx(onnx_model):
elif n.op_type == "Conv":
x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
assert 'dilations' not in opt or opt['dilations'] == (1,1)
if opt['pads'][0] == opt['pads'][2] and opt['pads'][1] == opt['pads'][3]:
# symmetric padding
# TODO: is this backward?
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1), padding=opt['pads'][0:2])
else:
x = x.pad2d((opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3]))
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1))
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1), padding=(opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3]) if 'pads' in opt else 0)
elif n.op_type in ["Add", "Sub", "Mul"]:
# TODO: add this to tinygrad? i don't think it's in torch
if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
@ -153,8 +147,8 @@ def get_run_onnx(onnx_model):
i = i+s
continue
elif n.op_type == "AveragePool":
assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1)
ret = inp[0].avg_pool2d(opt['kernel_shape'])
#assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1)
ret = inp[0].avg_pool2d(opt['kernel_shape'], opt['strides'])
elif n.op_type == "MaxPool":
#assert opt['kernel_shape'] == opt['strides'], f"kernel_shape and stride mismatch {opt}"
#opt['kernel_shape'] = opt['strides']

View File

@ -0,0 +1,45 @@
import unittest
from onnx.backend.base import Backend, BackendRep
import onnx.backend.test
from typing import Any, Tuple
# pip3 install tabulate
# Registers onnx's pytest report plugin so the backend coverage table is rendered.
pytest_plugins = 'onnx.backend.test.report',
from extra.onnx import get_run_onnx
class TinygradModel(BackendRep):
    """Executable model: adapts a tinygrad run function to the ONNX BackendRep interface."""

    def __init__(self, run_onnx, input_names):
        super().__init__()
        # Callable produced by get_run_onnx: maps {input_name: tensor} -> {output_name: tensor}.
        self.fxn = run_onnx
        # Graph input names in positional order, used to key the positional inputs.
        self.input_names = input_names

    def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]:
        """Execute the model on positional *inputs*; return a tuple of numpy outputs."""
        feed = dict(zip(self.input_names, inputs))
        outputs = self.fxn(feed, debug=True)
        # Realize the first (assumed only) output tensor as a numpy array.
        first_output = next(iter(outputs.values())).numpy()
        return (first_output,)
class TinygradBackend(Backend):
    """ONNX backend entry point that compiles models through tinygrad's get_run_onnx."""

    @classmethod
    def prepare(cls, onnx_model, device):
        """Compile *onnx_model* and return a TinygradModel ready to execute it."""
        names = [graph_input.name for graph_input in onnx_model.graph.input]
        print("prepare", cls, device, names)
        runner = get_run_onnx(onnx_model)
        return TinygradModel(runner, names)

    @classmethod
    def supports_device(cls, device: str) -> bool:
        # Only the (tinygrad-backed) CPU device is advertised to the test harness.
        return device == "CPU"
backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)

# only the node tests for now
for test_case in backend_test.test_suite:
    if 'OnnxBackendNodeModelTest' not in str(type(test_case)):
        continue
    backend_test.include(str(test_case).split(" ")[0])

# Expose the generated unittest cases at module scope so pytest/unittest discover them.
globals().update(backend_test.enable_report().test_cases)

if __name__ == '__main__':
    unittest.main()