external_test_onnx_backend
parent
edaf878339
commit
d8b6f241f1
|
@ -127,13 +127,7 @@ def get_run_onnx(onnx_model):
|
|||
elif n.op_type == "Conv":
|
||||
x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
|
||||
assert 'dilations' not in opt or opt['dilations'] == (1,1)
|
||||
if opt['pads'][0] == opt['pads'][2] and opt['pads'][1] == opt['pads'][3]:
|
||||
# symmetric padding
|
||||
# TODO: is this backward?
|
||||
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1), padding=opt['pads'][0:2])
|
||||
else:
|
||||
x = x.pad2d((opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3]))
|
||||
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1))
|
||||
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1), padding=(opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3]) if 'pads' in opt else 0)
|
||||
elif n.op_type in ["Add", "Sub", "Mul"]:
|
||||
# TODO: add this to tinygrad? i don't think it's in torch
|
||||
if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
|
||||
|
@ -153,8 +147,8 @@ def get_run_onnx(onnx_model):
|
|||
i = i+s
|
||||
continue
|
||||
elif n.op_type == "AveragePool":
|
||||
assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1)
|
||||
ret = inp[0].avg_pool2d(opt['kernel_shape'])
|
||||
#assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1)
|
||||
ret = inp[0].avg_pool2d(opt['kernel_shape'], opt['strides'])
|
||||
elif n.op_type == "MaxPool":
|
||||
#assert opt['kernel_shape'] == opt['strides'], f"kernel_shape and stride mismatch {opt}"
|
||||
#opt['kernel_shape'] = opt['strides']
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
import unittest
|
||||
from onnx.backend.base import Backend, BackendRep
|
||||
import onnx.backend.test
|
||||
from typing import Any, Tuple
|
||||
|
||||
# pip3 install tabulate
|
||||
pytest_plugins = 'onnx.backend.test.report',
|
||||
|
||||
from extra.onnx import get_run_onnx
|
||||
|
||||
class TinygradModel(BackendRep):
|
||||
def __init__(self, run_onnx, input_names):
|
||||
super().__init__()
|
||||
self.fxn = run_onnx
|
||||
self.input_names = input_names
|
||||
|
||||
def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]:
|
||||
real_inputs = {k:v for k,v in zip(self.input_names, inputs)}
|
||||
ret = self.fxn(real_inputs, debug=True)
|
||||
ret = next(iter(ret.values())).numpy()
|
||||
return (ret,)
|
||||
|
||||
class TinygradBackend(Backend):
|
||||
@classmethod
|
||||
def prepare(cls, onnx_model, device):
|
||||
input_names = [inp.name for inp in onnx_model.graph.input]
|
||||
print("prepare", cls, device, input_names)
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
return TinygradModel(run_onnx, input_names)
|
||||
|
||||
@classmethod
|
||||
def supports_device(cls, device: str) -> bool:
|
||||
return device == "CPU"
|
||||
|
||||
backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
|
||||
|
||||
# only the node tests for now
|
||||
for x in backend_test.test_suite:
|
||||
if 'OnnxBackendNodeModelTest' in str(type(x)):
|
||||
backend_test.include(str(x).split(" ")[0])
|
||||
|
||||
globals().update(backend_test.enable_report().test_cases)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue