Fork 0

53 lines
1.4 KiB

from functools import lru_cache
from .tensor import Device, Function, register
def compile_wrapper(ane, dat):
    """Compile ``dat`` via ``ane.compile`` and return whatever it produces."""
    compiled = ane.compile(dat)
    return compiled
def roundup(x, v):
    """Return ``x`` raised to the next multiple of ``v``.

    A value that is already a multiple of ``v`` is returned unchanged.
    """
    # Distance from x to the next multiple of v (0 when already aligned).
    padding = (v - x) % v
    return x + padding
def compile_relu(ane, sz):
    """Patch the ReLU ``.hwx`` template for width ``sz`` and compile it.

    The template at ``accel/ane/ops/relu.hwx`` is read as a list of byte
    values, stride and width fields are patched through ``ane.fill`` and
    ``ane.filln``, and the resulting bytes are passed to
    ``compile_wrapper``.

    Args:
        ane: Object providing ``fill``, ``filln`` and ``compile``.
        sz: Width written to ``InputWidth`` / ``OutputWidth``; also used to
            derive the strides.

    Returns:
        The value returned by ``compile_wrapper(ane, bytes(dat))``.
    """
    # Use a context manager so the template's file handle is closed right
    # after reading instead of being left for the garbage collector.
    with open("accel/ane/ops/relu.hwx", "rb") as hwx:
        dat = list(hwx.read())
    # TODO: make this all nice and once
    # number of engines? (max 0x100)
    # NOTE(review): sz * 2 presumably means 2 bytes per element -- confirm.
    l2_stride = max(0x100, roundup(sz * 2, 0x10))
    # 0x1ec = L2.SourceChannelStride.Stride, 0x1f0 = L2.SourceRowStride.Stride
    # 0x1f4, 0x1f8?
    # 0x214 = L2.ResultBase.Addr
    dat = ane.fill(dat, [0x1EC, 0x1F0, 0x1F4, 0x1F8, 0x214], "I", l2_stride)
    stride = roundup(sz * 2, 0x40)
    # NOTE(review): the original call was syntactically incomplete (no first
    # argument, no dict braces, no closing parenthesis). It is reconstructed
    # here as filln(dat, {...}) by analogy with ane.fill above -- confirm
    # against the actual filln signature.
    dat = ane.filln(
        dat,
        {
            "NeuronType": 0x11,  # 0x10 makes this a copy, 0x11 = ReLU, 0x12 = crash
            "InputWidth": sz,
            "OutputWidth": sz,
            "InputRowStride": stride,
            "InputPlaneStride": stride,
            "InputDepthStride": stride,
            "OutputRowStride": stride,
            "OutputPlaneStride": stride,
            "OutputDepthStride": stride,
        },
    )
    return compile_wrapper(ane, bytes(dat))
class ReLU(Function):
    """ReLU op that runs a compiled program through ``ctx.ane``."""

    def forward(ctx, input):
        """Allocate an output tensor shaped like ``input``, run ReLU, return it."""
        out = ctx.ane.tensor(input.shape)
        # The program is built for this input's size (``input.sz``).
        program = compile_relu(ctx.ane, input.sz)
        ctx.ane.run(program, input, out)
        return out

    def backward(ctx, grad_output):
        """Return the constant ``0``; ``grad_output`` is not used."""
        # No gradient is computed here.
        return 0
# Register the ReLU class under the name "relu" for the ANE device.
register("relu", ReLU, device=Device.ANE)