# tinygrad/extra/onnx_ops.py
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes, ImageDType, flatten
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
from onnx import TensorProto
import io
import os
import numpy as np
import functools
from typing import Union, Tuple, Optional, List, Any
import math
tensor_methods = {
"Neg",
"Reciprocal",
"Sqrt",
"Sign",
"Abs",
"Exp",
"Log",
"Mish",
"Sin",
"Cos",
"Tan",
"Relu",
"Sigmoid",
"Tanh",
"MatMul",
"Floor",
"Ceil",
"Tanh",
"Softplus",
"HardSwish",
"Where",
"Mul",
"Sinh",
"Cosh",
"Softsign",
"Asinh",
"Acosh",
"Atanh",
}
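# These ops have no wrapper here: they are assumed to be dispatched straight to the Tensor method of the same
# name (lowercased) by extra/onnx.py, e.g. "Sqrt" -> Tensor.sqrt.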
# **************** Free Ops ****************
def Identity(input: Tensor):
return input
def Add(input: Tensor, other: Tensor, broadcast=None):
return (
input + other
if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType)
else (input + other).cast(input.dtype)
)
def Sub(input: Union[Tensor, int], other: Tensor):
return input - other # some test has input as int
def Div(input: Tensor, other: Tensor):
return (
input / other
if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType)
else input.div(other).floor()
) # TODO: this has dtype issues
def Pow(input: Tensor, other: Tensor):
return (input.float() ** other.float()).cast(
input.dtype
) # TODO: this has dtype issues
def Less(x: Tensor, y: Tensor):
return (x < y).cast(dtypes.bool)
def LessOrEqual(x: Tensor, y: Tensor):
return (x <= y).cast(dtypes.bool)
def Greater(x: Tensor, y: Tensor):
return (x > y).cast(dtypes.bool)
def GreaterOrEqual(x: Tensor, y: Tensor):
return (x >= y).cast(dtypes.bool)
def Equal(x: Tensor, y: Tensor):
return (x == y).cast(dtypes.bool)
def Max(*data_0):
return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0):
return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0):
return functools.reduce(Tensor.__add__, data_0)
def Mean(*data_0):
return functools.reduce(Tensor.__add__, data_0) / len(data_0)
def Cast(input: Tensor, to):
return input.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to)))
# **************** Simple Ops ****************
def Constant(
value: Tensor = None,
value_float=None,
value_floats=None,
value_int=None,
value_ints=None,
value_string=None,
value_strings=None,
):
if value:
return value
elif value_float:
return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
elif value_floats:
return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
elif value_int:
return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
elif value_ints:
return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
elif value_string or value_strings:
raise NotImplementedError(
f"value_string or value_strings not implemented for Constant op"
)
def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5):
return (alpha * input + beta).clip(0, 1)
def Gelu(x: Tensor, approximate=None):
return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + Erf(x / math.sqrt(2)))
def Celu(x: Tensor, alpha=1.0):
return x.celu(alpha)
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875):
return gamma * (X.relu() - (-alpha * X.exp() + alpha).relu())
def PRelu(X: Tensor, slope: Tensor):
slope = (
slope[0] if slope.shape[-1] != X.shape[-1] else slope
) # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
return X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope
def LeakyRelu(X: Tensor, alpha=0.01):
return X.leakyrelu(alpha)
def ThresholdedRelu(X: Tensor, alpha=1.0):
return (X - alpha).relu() + (X - alpha).relu().sign() * alpha
def Softmax_1(input: Tensor, axis=1):
return input.softmax(axis)
def Softmax_13(input: Tensor, axis=-1):
return input.softmax(axis)
Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
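# keyed by opset version: the default axis changed from 1 to -1 in opset 13, and the dispatcher in extra/onnx.py
# is expected to pick the entry matching the model's opset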
def LogSoftmax(input: Tensor, axis=-1):
return input.log_softmax(axis)
def Clip(input: Tensor, min=None, max=None):
return input.clip(
float("-inf") if min is None else min, float("inf") if max is None else max
)
# NOTE ReduceProd would require a new llop
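# _axes normalizes the ONNX axes input: a tensor/list becomes a list of Python ints, while a missing or empty
# axes input becomes [] (treat the reduce as a no-op) or None (reduce over every dim) depending on noop_with_empty_axes.
# For illustration (assumed values): _axes(Tensor([1, -1]), 0) -> [1, -1]; _axes(None, 1) -> []; _axes(None, 0) -> None.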
def _axes(axes, noop_with_empty_axes):
return (
[int(x) for x in safe_numpy(axes)]
if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,))
else ([] if noop_with_empty_axes else None)
)
def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
return data.abs().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt()
def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()
def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
return data.exp().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()
def GlobalAveragePool(X: Tensor):
return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
def GlobalMaxPool(X: Tensor):
return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
def OptionalHasElement(x: Tensor = None):
return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool)
def OptionalGetElement(x: Tensor = None):
return x if x is not None else Tensor([], dtype=dtypes.float32)
def Tile(input: Tensor, repeats):
return input.repeat([int(x) for x in safe_numpy(repeats)])
def Range(start: Tensor, limit, delta):
return Tensor.arange(
start=int(safe_numpy(start).item()),
stop=int(safe_numpy(limit).item()),
step=int(safe_numpy(delta).item()),
).cast(dtype=start.dtype)
def Shape(data: Tensor, end=None, start=0):
return Tensor(
list(data.shape)[start:end],
dtype=dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64,
) # TODO: really?
def Size(data: Tensor):
return prod(data if isinstance(data, list) else data.shape)
def Flatten(input: Tensor, axis=1):
return input.reshape(prod((1,) + input.shape[0:axis]), -1)
def Reshape(data: Tensor, shape: Tensor, allowzero=None):
return data.reshape(
[int(x) if x != 0 else data.shape[i] for i, x in enumerate(safe_numpy(shape))]
)
def Shrink(input: Tensor, bias=0.0, lambd=0.5):
return (input < -lambd) * (input + bias) + (input > lambd) * (input - bias)
def And(x: Tensor, y: Tensor):
return (x == y).where(x, Tensor.zeros(*x.shape)).cast(dtypes.bool)
def Or(x: Tensor, y: Tensor):
return (x == y).where(x, Tensor.ones(*x.shape)).cast(dtypes.bool)
def Xor(x: Tensor, y: Tensor):
return (
(x == y).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
)
def Not(x: Tensor):
return (
(x == 1).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
)
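# Asin/Acos/Atan below are polynomial (handbook-style minimax) approximations rather than dedicated llops;
# Asin is derived from Atan via asin(x) = atan(x / sqrt(1 - x^2)).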
def Asin(x):
return Atan(x / Tensor.sqrt(1 - x * x))
def Acos(x: Tensor):
negate = x < 0
x = x.abs()
ret = (
(((-0.0187293 * x) + 0.0742610) * x - 0.2121144) * x + 1.5707288
) * Tensor.sqrt(1.0 - x)
ret = ret - 2 * negate * ret
return negate * 3.14159265358979 + ret
def Atan(y: Tensor):
x = Tensor.ones(y.shape)
t3 = x
t1 = y.abs()
t0 = (t3 > t1).where(t3, t1)
t1 = (t3 < t1).where(t3, t1)
t3 = t1 / t0
t4 = t3 * t3
t0 = (
(((-0.013480470 * t4 + 0.057477314) * t4 - 0.121239071) * t4 + 0.195635925) * t4
- 0.332994597
) * t4 + 0.999995630
t3 = t0 * t3
t3 = (y.abs() > x.abs()).where(1.570796327 - t3, t3)
return (y < 0).where(-t3, t3)
def Trilu(x: Tensor, k: Union[Tensor, int] = 0, upper=1):
k = (
int(safe_numpy(k).item()) if isinstance(k, Tensor) else int(k)
) # onnx passes k as an int64 tensor with one element; the default is 0
return x.triu(k) if upper else x.tril(k)
def Squeeze(input: Tensor, axes):
if isinstance(axes, Tensor):
axes = safe_numpy(axes)
axes = [int(x) if x >= 0 else int(x + input.ndim) for x in axes]
return input.reshape([s for i, s in enumerate(input.shape) if i not in axes])
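# Unsqueeze inserts size-1 dims at the normalized axes. For illustration (assumed shapes): data.shape (3, 4)
# with axes [0, 2] builds new_shape [1, 3, 1, 4], filling the non-axis slots from data.shape in order.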
def Unsqueeze(data: Tensor, axes):
axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
new_shape = [1] * (len(data.shape) + len(axes))
ptr = iter(data.shape)
for i in range(len(new_shape)):
if i not in axes:
new_shape[i] = next(ptr)
return data.reshape(new_shape)
def Binarizer(input, threshold=0.0):
return input > threshold
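# ArgMax marks the max positions with a boolean mask, multiplies by an arange along `axis`, and reduces with max;
# note that on ties this returns the last matching index, so select_last_index is effectively ignored here.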
def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
axis = axis + x.ndim if axis < 0 else axis
m = x == (
x.max(axis=axis, keepdim=keepdims)
if keepdims
else x.max(axis=axis, keepdim=keepdims).unsqueeze(axis)
)
c = (
Tensor.arange(x.shape[axis]).reshape(
*[1] * (axis), x.shape[axis], *[1] * (x.ndim - axis - 1)
)
* m
)
return c.max(axis=axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis=0, keepdims=1, select_last_index=0):
return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
def Elu(input: Tensor, alpha=1.0):
return input.elu(alpha=alpha)
def Concat(*inputs: List[Tensor], axis):
return inputs[0].cat(*inputs[1:], dim=axis)
def Transpose(input: Tensor, perm=None):
return input.permute(
order=list(range(len(input.shape))[::-1]) if perm is None else perm
)
# NOTE: since we only have one type, this is valid!
def CastLike(input, target_type):
assert isinstance(target_type, Tensor), "can only CastLike Tensor"
return input
def ConstantOfShape(input, value: Tensor = None):
if value is None:
value = Tensor([0.0])
shape = [int(x) for x in safe_numpy(input)]
return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0] != 0 else 1)
# TODO: abstract out the broadcast logic in tensor
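# Expand left-pads the shorter shape with 1s and takes the per-dim max, mirroring numpy broadcasting.
# For illustration (assumed shapes): input.shape (3, 1) with shape [2, 1, 4] pads to [1, 3, 1] vs [2, 1, 4]
# and expands to (2, 3, 4).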
def Expand(input: Tensor, shape):
x_shape, y_shape = input.shape, [int(x) for x in safe_numpy(shape)]
# copied from _broadcasted
x_shape, y_shape = [
([1] * (max(len(x_shape), len(y_shape)) - len(t_shape)) + list(t_shape))
for t_shape in [x_shape, y_shape]
]
shape_ret = tuple(max(sx, sy) for sx, sy in zip(x_shape, y_shape))
return input.reshape(x_shape).expand(shape_ret)
# **************** Complex Ops ****************
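# Gemm computes alpha * A' @ B' (+ beta * C) with optional transposes; the legacy `broadcast` attribute reshapes C
# with leading size-1 dims so it broadcasts against the result.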
def Gemm(
A: Tensor,
B: Tensor,
C: Tensor = None,
alpha=1.0,
beta=1.0,
transA=0,
transB=0,
broadcast=0,
):
ret = alpha * (A.transpose(transA) @ B.transpose(transB))
if C is not None:
ret += beta * (
C
if broadcast == 0
else C.reshape(
[-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1]
)
)
return ret
# works with Tensors.ndim != 4
def _batchnorm(
self: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
mean: Tensor,
invstd: Tensor,
):
shape = [1, -1] + [1] * (self.ndim - 2)
x = self - mean.reshape(shape=shape)
if weight:
x = x * weight.reshape(shape=shape)
ret = x.mul(invstd.reshape(shape=shape) if len(invstd.shape) == 1 else invstd)
return (ret + bias.reshape(shape=shape)) if bias else ret
# TODO: this is copied from tinygrad/nn/__init__.py
# spatial is from opset 7 and has since been removed
def BatchNormalization(
X: Tensor,
scale,
B,
input_mean,
input_var,
epsilon=1e-05,
momentum=0.9,
training_mode=0,
spatial=1,
is_test=0,
):
if training_mode:
x_detached = X.detach()
current_mean = x_detached.mean(axis=(0, 2, 3))
y = x_detached - current_mean.reshape(shape=[1, -1, 1, 1])
current_var = (y * y).mean(axis=(0, 2, 3))
current_invstd = current_var.add(epsilon).pow(-0.5)
running_mean = input_mean * momentum + current_mean * (1 - momentum)
running_var = input_var * momentum + current_var * (1 - momentum)
return (
_batchnorm(X, scale, B, current_mean, current_invstd),
running_mean,
running_var,
)
else:
invstd = (input_var + epsilon) ** -0.5
return _batchnorm(X, scale, B, input_mean, invstd)
def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
axis = tuple(range(2, len(x.shape)))
mean = x.mean(axis=axis, keepdim=True)
invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5)
return (
x.sub(mean)
.mul(scale.reshape(shape=[-1, 1, 1]))
.mul(invstd)
.add(bias.reshape(shape=[-1, 1, 1]))
)
def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
assert stash_type == 1, "only float32 is supported"
axis = tuple(
i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape))
)
mean = x.mean(axis=axis, keepdim=True)
return (
x.layernorm(axis, epsilon).mul(scale).add(bias),
mean,
(x.sub(mean))
.pow(2)
.mean(axis=axis, keepdim=True)
.add(epsilon)
.sqrt()
.reciprocal(),
)
def GroupNormalization(
x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05
):
return (
x.reshape(x.shape[0], num_groups, -1)
.layernorm(axis=-1, eps=epsilon)
.mul(scale.unsqueeze(-1))
.add(bias.unsqueeze(-1))
.reshape(x.shape)
)
# onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
# numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...)
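# For illustration (assumed values): onnx_pads [1, 2, 3, 4] over 2 dims becomes ((1, 3), (2, 4)); the first half
# of the ONNX list holds the per-axis begins and the second half the per-axis ends.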
def _format_padding(onnx_pads, ndims=None, axes=None):
if ndims and len(onnx_pads) // 2 != ndims:
onnx_pads = (
onnx_pads * ndims
) # for OnnxBackendPyTorchConvertedModelTest the len(onnx_pads) == 2
if ndims is None:
ndims = len(onnx_pads) // 2
if axes is None:
axes = list(range(ndims))
num_axes = len(axes)
np_pads = [(0, 0)] * ndims
for i in range(num_axes):
np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes])
return np_pads
def _padding(
X: Tensor,
pads=None,
auto_pad="NOTSET",
axes=None,
constant_value=0.0,
strides=None,
kernel_shape=None,
dilations=None,
ceil_mode=0,
):
if auto_pad != "NOTSET":
pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
elif ceil_mode and auto_pad == "NOTSET": # stupid ceil_mode case
if strides is not None:
strides = (
[strides] * len(kernel_shape)
if isinstance(strides, int)
else strides
if strides
else [1] * len(kernel_shape)
)
if dilations is not None:
dilations = [1] * len(kernel_shape) if dilations == 1 else dilations
out_spatial_shape = [
math.ceil((sh - dil * (ker - 1) - 1) / st + 1)
if ceil_mode
else math.floor((sh - dil * (ker - 1) - 1) / st + 1)
for sh, st, ker, dil in zip(
X.shape[-len(kernel_shape) :], strides, kernel_shape, dilations
)
]
pad_shape = [
(osh - 1) * st + ((ks - 1) * dil + 1) - ish
for osh, st, ks, dil, ish in zip(
out_spatial_shape,
strides,
kernel_shape,
dilations,
X.shape[-len(kernel_shape) :],
)
]
pad_shape = flatten([[sh // 2, sh - sh // 2] for sh in pad_shape])
pads = pad_shape[::2] + pad_shape[1::2]
if pads is None:
return X
pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
return X.pad(tuple(pads), value=constant_value)
def _auto_pad(X: Tensor, auto_pad, strides, kernel_shape, dilations):
strides = (
[strides] * len(kernel_shape)
if isinstance(strides, int)
else strides
if strides
else [1] * len(kernel_shape)
)
dilations = [1] * len(kernel_shape) if dilations == 1 else dilations
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
pad_shape = [
(math.ceil(sh / st) - 1) * st + ((ks - 1) * di + 1) - sh
for sh, st, ks, di in zip(
X.shape[-len(kernel_shape) :], strides, kernel_shape, dilations
)
]
pad_shape = flatten([[sh // 2, sh - sh // 2] for sh in pad_shape])
return (
pad_shape[::2] + pad_shape[1::2]
if auto_pad == "SAME_UPPER"
else pad_shape[1::2] + pad_shape[::2]
)
else:
raise NotImplementedError(f"auto_pad={auto_pad} not implemented")
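# For illustration (assumed values): a 5x5 spatial input with kernel_shape (3, 3), strides (2, 2) and dilations
# (1, 1) needs 2 extra pixels per axis, so SAME_UPPER yields pads [1, 1, 1, 1] (x1_begin, x2_begin, x1_end, x2_end).
# Pad below emulates the "wrap", "reflect" and "edge" modes by composing repeat/flip/shrink/expand with Tensor.pad
# per axis; only "constant" goes through _padding directly.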
def Pad(
x: Tensor,
pads: Union[Tensor, Tuple[int, ...]],
constant_value: Tensor = None,
axes: Tensor = None,
mode="constant",
value: float = 0.0,
):
constant_value = (
value if constant_value is None else float(safe_numpy(constant_value)[0])
)
seq_pads = list(pads) if isinstance(pads, tuple) else safe_numpy(pads)
seq_pads = [math.ceil(i) for i in seq_pads]
seq_axes = safe_numpy(axes).astype(np.int32).tolist() if axes is not None else None
base_shape = x.shape
pads = _format_padding(seq_pads, ndims=len(x.shape), axes=seq_axes)
if mode == "wrap":
repeat_args = [
math.ceil(dim[0] / sh) + math.ceil(dim[1] / sh) + 1
for dim, sh in zip(pads, base_shape)
]
new_shape = [s * r for s, r in zip(base_shape, repeat_args)]
shrink_args = [
(
sh - dim[0] % sh if dim[0] % sh != 0 else 0,
nsh - (sh - dim[1] % sh) if dim[1] % sh != 0 else nsh,
)
for dim, sh, nsh in zip(pads, base_shape, new_shape)
]
return x.repeat(tuple(repeat_args)).shrink(tuple(shrink_args))
elif mode == "reflect":
for i, s in enumerate(x.shape):
if pads[i] == (0, 0):
continue
elif pads[i][0] and not pads[i][1]:
x = x.flip(i).shrink(
tuple(
[
(0, s_) if i_ != i else (s - pads[i][0] - 1, s_ - 1)
for i_, s_ in enumerate(x.shape)
]
)
).pad(
tuple([(0, 0) if i_ != i else (0, s) for i_ in range(x.ndim)])
) + x.pad(
tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)])
)
elif not pads[i][0] and pads[i][1]:
x = x.flip(i).shrink(
tuple(
[
(0, s_) if i_ != i else (1, pads[i][1] + 1)
for i_, s_ in enumerate(x.shape)
]
)
).pad(
tuple([(0, 0) if i_ != i else (s, 0) for i_ in range(x.ndim)])
) + x.pad(
tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)])
)
else:
x = (
x.flip(i)
.shrink(
tuple(
[
(0, s_) if i_ != i else (s - pads[i][0] - 1, s_ - 1)
for i_, s_ in enumerate(x.shape)
]
)
)
.pad(
tuple(
[
(0, 0) if i_ != i else (0, s + pads[i][1])
for i_ in range(x.ndim)
]
)
)
+ x.flip(i)
.shrink(
tuple(
[
(0, s_) if i_ != i else (1, pads[i][1] + 1)
for i_, s_ in enumerate(x.shape)
]
)
)
.pad(
tuple(
[
(0, 0) if i_ != i else (s + pads[i][0], 0)
for i_ in range(x.ndim)
]
)
)
+ x.pad(
tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)])
)
)
return x
elif mode == "edge":
for i, s in enumerate(x.shape):
if pads[i] == (0, 0):
continue
elif pads[i][0] and not pads[i][1]:
x = x.shrink(
tuple(
[
(0, s_) if i_ != i else (0, 1)
for i_, s_ in enumerate(x.shape)
]
)
).expand(
[pads[i][0] if i_ == i else s_ for i_, s_ in enumerate(x.shape)]
).pad(
tuple([(0, 0) if i_ != i else (0, s) for i_ in range(x.ndim)])
) + x.pad(
tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)])
)
elif not pads[i][0] and pads[i][1]:
x = x.shrink(
tuple(
[
(0, s_) if i_ != i else (s_ - 1, s_)
for i_, s_ in enumerate(x.shape)
]
)
).expand(
[pads[i][0] if i_ == i else s_ for i_, s_ in enumerate(x.shape)]
).pad(
tuple(
[
(0, 0) if i_ != i else (s + pads[i][0], 0)
for i_ in range(x.ndim)
]
)
) + x.pad(
tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)])
)
else:
x = (
x.shrink(
tuple(
[
(0, s_) if i_ != i else (0, 1)
for i_, s_ in enumerate(x.shape)
]
)
)
.expand(
[pads[i][0] if i_ == i else s_ for i_, s_ in enumerate(x.shape)]
)
.pad(
tuple(
[
(0, 0) if i_ != i else (0, s + pads[i][1])
for i_ in range(x.ndim)
]
)
)
+ x.shrink(
tuple(
[
(0, s_) if i_ != i else (s_ - 1, s_)
for i_, s_ in enumerate(x.shape)
]
)
)
.expand(
[pads[i][1] if i_ == i else s_ for i_, s_ in enumerate(x.shape)]
)
.pad(
tuple(
[
(0, 0) if i_ != i else (s + pads[i][0], 0)
for i_ in range(x.ndim)
]
)
)
+ x.pad(
tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)])
)
)
return x
elif mode == "constant":
return _padding(x, seq_pads, axes=seq_axes, constant_value=constant_value)
def AveragePool(
X: Tensor,
kernel_shape,
auto_pad="NOTSET",
ceil_mode=0,
count_include_pad=0,
dilations=1,
pads=None,
strides=1,
):
pixel_axes = tuple(range(len(X.shape)))[2:]
ret = _padding(
X,
pads,
auto_pad,
axes=pixel_axes,
strides=strides,
kernel_shape=kernel_shape,
dilations=dilations,
ceil_mode=ceil_mode,
).avg_pool2d(kernel_shape, stride=strides, dilation=dilations)
if count_include_pad:
return ret
else:
div = _padding(
Tensor.ones(*X.shape),
pads,
auto_pad,
axes=pixel_axes,
strides=strides,
kernel_shape=kernel_shape,
dilations=dilations,
ceil_mode=ceil_mode,
).avg_pool2d(kernel_shape, stride=strides, dilation=dilations)
return ret / div
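# MaxPool also returns int64 indices: every pooled value is compared against the flattened input and its flat
# position is recovered with an arange-weighted sum, which costs O(len(out) * len(in)) comparisons.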
def MaxPool(
X: Tensor,
kernel_shape,
auto_pad="NOTSET",
ceil_mode=0,
dilations=1,
pads=None,
storage_order=0,
strides=1,
):
ret = _padding(
X,
pads,
auto_pad,
constant_value=float("-inf"),
axes=tuple(range(len(X.shape)))[2:],
strides=strides,
kernel_shape=kernel_shape,
dilations=dilations,
ceil_mode=ceil_mode,
)
ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations)
ret_len, X_len = ret.numel(), X.numel()
indices = (
(
(
ret.flatten().unsqueeze(1).expand(ret_len, X_len)
== X.flatten().reshape(1, X_len).expand(ret_len, X_len)
)
* Tensor.arange(X_len).reshape(1, X_len).expand(ret_len, X_len)
)
.sum(1)
.reshape(ret.shape)
.cast(dtypes.int64)
)
if storage_order:
indices = indices.transpose(indices.ndim - 2, indices.ndim - 1)
return ret, indices
def MaxUnpool(
xT: Tensor,
xI: Tensor,
outshape: Tensor = None,
kernel_shape=None,
pads=None,
strides=None,
):
out_sh = [
(ks // 2) * 2 + st * inps
for inps, st, ks in zip(xI.shape, strides, kernel_shape)
]
outlength = prod(out_sh)
xI = xI.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
arange = (
Tensor.arange(outlength, requires_grad=False)
.reshape(1, outlength)
.expand(xI.shape)
)
xT = xT.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh)
if outshape is not None:
outshape = safe_numpy(outshape).tolist()
if outshape != ret.shape:
diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]]
pad_args = [
diff[0] // 2,
diff[1] // 2,
diff[0] - diff[0] // 2,
diff[1] - diff[1] // 2,
]
ret = ret.pad2d((pad_args[1], pad_args[3], pad_args[0], pad_args[2]))
return ret
def Conv(
X: Tensor,
W: Tensor,
B=None,
auto_pad="NOTSET",
dilations=1,
group=1,
kernel_shape=None,
pads=None,
strides=1,
):
if auto_pad != "NOTSET":
padding = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
else:
padding = (
[
p
for ps in zip(
pads[: len(pads) // 2][::-1], pads[len(pads) // 2 :][::-1]
)
for p in ps
]
if pads is not None
else 0
) # reorder padding
return X.conv2d(
W, B, stride=strides, groups=group, dilation=dilations, padding=padding
)
def ConvTranspose(
X: Tensor,
W: Tensor,
B=None,
auto_pad="NOTSET",
dilations=1,
group=1,
kernel_shape=None,
pads=None,
output_shape=None,
output_padding=0,
strides=1,
):
if kernel_shape is None:
kernel_shape = W.shape[2:]
if isinstance(strides, int):
strides = [strides] * (W.ndim - 2)
if isinstance(dilations, int):
dilations = [dilations] * (W.ndim - 2)
if isinstance(output_padding, int):
output_padding = [output_padding] * (W.ndim - 2)
out_sh = (
[
st * (xs - 1) + (ks - 1) * di + 1
if n < 2
else st * (xs - 1) + (ks - 1) * di + 1 - pads[n - 2] - pads[n - 1]
for n, (st, xs, ks, di) in enumerate(
zip(strides, X.shape[2:], kernel_shape, dilations)
)
]
if output_shape is not None or auto_pad != "NOTSET"
else []
)
if pads is None:
if output_shape is None:
output_shape = [xs * st for xs, st in zip(X.shape[2:], strides)]
if auto_pad == "NOTSET":
pads = [0, 0] * (X.ndim - 2)
else:
total_padding = [
st * (ish - 1) + pad + ((ks - 1) * dil + 1) - osh
for st, ish, pad, ks, dil, osh in zip(
strides,
X.shape[2:],
output_padding,
kernel_shape,
dilations,
output_shape,
)
]
pad_shape = flatten([[sh // 2, sh - sh // 2] for sh in total_padding])
pads = (
pad_shape[::2] + pad_shape[1::2]
if auto_pad == "SAME_UPPER"
else pad_shape[1::2] + pad_shape[::2]
)
else:
if output_shape is None:
output_shape = [
st * (xs - 1) + (ks - 1) * di + 1
if n < 2
else st * (xs - 1) + (ks - 1) * di + 1 - pads[n - 2] - pads[n - 1]
for n, (st, xs, ks, di) in enumerate(
zip(strides, X.shape[2:], kernel_shape, dilations)
)
]
if out_sh:
output_padding = [os - rs for os, rs in zip(output_shape, out_sh)]
return X.conv_transpose2d(
W,
B,
stride=strides,
groups=group,
dilation=dilations,
padding=pads if pads is not None else 0,
output_padding=output_padding,
)
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
if isinstance(ratio, Tensor) and not ratio.shape:
ratio = safe_numpy(
ratio
) # ratio is sometimes passed in as a Tensor with shape ()
if isinstance(training_mode, Tensor) and not training_mode.shape:
training_mode = safe_numpy(training_mode)
if not training_mode:
return data, Tensor.ones(
*data.shape, dtype=dtypes.bool
) # if the mask is requested as an output it will contain all True values
rng = np.random.RandomState(seed)
ratio = ratio.lazydata.realize().toCPU()[0] if isinstance(ratio, Tensor) else ratio
mask = Tensor(
(rng.random(data.shape) >= ratio), requires_grad=False, device=data.device
)
return data * mask * (1 / (1.0 - ratio)), mask
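# LRN squares the input, folds the channel dim into a spatial dim so avg_pool2d can average over a window of
# neighbouring channels, and divides by the resulting (bias + alpha * mean)^beta term.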
def LRN(input: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0):
bs, c, iy, ix = input.shape
return input / input.mul(input).reshape(bs, 1, c, iy * ix).pad2d(
(0, 0, (size - 1) // 2, size // 2)
).avg_pool2d((size, 1), 1).reshape(bs, c, iy, ix).mul(alpha).add(bias).pow(beta)
def MeanVarianceNormalization(input: Tensor, axis=(0, 2, 3)):
data_mean = input.mean(axis=axis, keepdim=True)
std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt()
return (input - data_mean) / (std + 1e-9)
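# The two loss ops below build a one-hot mask by comparing the integer targets with an arange over the class dim,
# then reduce -mask * log_probs; class weights and ignore_index are folded into that mask before the reduction.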
def NegativeLogLikelihoodLoss(
input: Tensor, target: Tensor, weight=None, ignore_index=None, reduction="mean"
):
target = target.cast(dtypes.float32)
N, C, i_shape = input.shape[0], input.shape[1], input.shape
t_shape = target.shape
if len(input.shape) != 3:
input = input.reshape((N, C, -1))
target = target.reshape((N, -1))
if weight is not None:
mask = target.unsqueeze(-1) == Tensor.arange(C).repeat((N, 1, 1))
weight = (mask * weight).sum(axis=-1)
if ignore_index is not None:
cond = target == ignore_index
weight = (
cond.where(0, weight)
if weight is not None
else cond.where(Tensor.zeros(*target.shape), 1)
)
mask = target[:, None, :] == Tensor.arange(C).reshape(
[1, C] + [1] * (len(input.shape) - 2)
)
loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight)
if reduction == "mean":
return loss.mean() if weight is None else loss.sum() / weight.sum()
elif reduction == "sum":
return loss.sum()
return loss.reshape(t_shape) if len(i_shape) != 3 else loss
def SoftmaxCrossEntropyLoss(
scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"
):
N, C, *s_dimensions = scores.shape
if ignore_index is not None:
labels = (labels == ignore_index).where(C + 1, labels)
mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(
1, C, *[1] * len(s_dimensions)
)
y = scores.log_softmax(axis=1)
if weights is not None:
weights = weights.__getitem__(
tuple([labels, *[slice(None)] * (weights.ndim - 1)])
)
loss = (mask * -y).sum(1) if weights is None else (mask * -y).sum(1) * weights
if reduction == "mean":
loss = (
loss.sum() / (loss == 0).where(0, 1).sum()
if weights is None
else loss.sum() / weights.sum()
)
elif reduction == "sum":
loss = loss.sum()
return loss, y
def ArrayFeatureExtractor(input: Tensor, indices: Tensor):
return input.__getitem__(
tuple(
[
slice(None) if i != (input.ndim - 1) else indices
for i in range(input.ndim)
]
)
)
def Gather(input: Tensor, indices: Tensor, axis=0):
if (
indices.numel() < 9
): # NOTE: fewer kernels for small index tensors, but the kernel count grows with the number of indices
input_sh = list(input.shape)
ret_shape = input_sh[:axis] + list(indices.shape) + input_sh[axis + 1 :]
if indices.ndim > 1:
indices = indices.flatten()
indices = (
[int(safe_numpy(indices))]
if indices.shape == ()
else [
input_sh[axis] + int(x) if x < 0 else int(x)
for x in safe_numpy(indices)
]
)
args = [
[(0, x) if j != axis else (i, i + 1) for j, x in enumerate(input_sh)]
for i in indices
]
return (
input.shrink(arg=tuple(args[0]))
.cat(*[input.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis)
.reshape(ret_shape)
)
else: # NOTE: faster gather with a fixed number of kernels, but it exceeds the kernel limit for openpilot
return input.__getitem__(
tuple([slice(None) if i != axis else indices for i in range(input.ndim)])
)
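# GatherElements first maps negative indices into range: sign().__neg__().relu() is 1 exactly where an index is
# negative, so input.shape[axis] gets added only there. For illustration (assumed values): indices [-1, 2] with
# shape[axis] == 4 become [3, 2].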
def GatherElements(input: Tensor, indices: Tensor, axis):
indices = (
indices.sign().contiguous().__neg__().contiguous().relu() * input.shape[axis]
+ indices
)
return input.gather(indices, axis)
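# _round rounds with a configurable tie-break. For illustration (assumed values), with n=0.5 and "round_to_even":
# [1.5, 2.5, -1.5] rounds to [2.0, 2.0, -2.0], i.e. banker's rounding as required by ONNX Round.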
def _round(x: Tensor, n: float, equidistant_case="round_down") -> Tensor:
def _and(cond1, cond2):
return ((cond1 + cond2) == 2).where(1, 0)
assert n <= 1, f"n:{n} shouldn't be larger than 1"
b = x.cast(dtypes.int32).contiguous().cast(x.dtype)
b = (b >= 0).where(b + n, b - n)
if equidistant_case == "round_down":
return (x > b).where(b + 1 - n, b - n)
elif equidistant_case == "round_up":
return (x >= b).where(b + 1 - n, b - n)
elif equidistant_case == "round_to_even":
x_ceil_fraction = x.ceil() / 2
cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
x = (_and(x == b, cond_ceil_even)).where(x + 1 - n, x)
x = (x > b).where(b + 1 - n, b - n)
return x
def Round(X: Tensor):
return _round(X, 0.5, "round_to_even")
# TODO clean this up, it's taking the longest in CI
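# Rough flow: resolve scales/sizes (optionally per-axis), map output pixel coordinates back into the input with the
# requested coordinate_transformation_mode, then gather (nearest) or bilinearly interpolate via numpy round-trips
# (linear). Only the last two (spatial) dims are resized here.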
def Resize(
X: Tensor,
roi=None,
scales=None,
sizes=None,
antialias=0,
axes=None,
coordinate_transformation_mode="half_pixel",
cubic_coeff_a=-0.75,
exclude_outside=0,
extrapolation_value=0.0,
keep_aspect_ratio_policy="stretch",
mode="nearest",
nearest_mode="round_prefer_floor",
):
def _nearest_gather(X: Tensor, x_out, y_out):
return X[:, :, y_out, :][:, :, :, x_out]
def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len):
if nearest_mode == "round_prefer_floor":
ret = _round(x_resized, 0.5, "round_down")
elif nearest_mode == "round_prefer_ceil":
ret = _round(x_resized, 0.5, "round_up")
elif nearest_mode == "floor":
ret = x_resized.floor()
elif nearest_mode == "ceil":
ret = x_resized.ceil()
return ret.clip(0, x_len - 1)
def _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi=None):
if coordinate_transformation_mode == "half_pixel":
x_out = (x_out + 0.5) / Tensor(
scales_lol[-1]
) - 0.5 # TODO: wrap in Tensor() because ((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5) is inaccurate on LLVM and METAL
y_out = (y_out + 0.5) / Tensor(scales_lol[-2]) - 0.5
elif coordinate_transformation_mode == "align_corners":
x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1)
y_out = y_out * (X.shape[-2] - 1) / (output_shape[-2] - 1)
elif coordinate_transformation_mode == "asymmetric":
x_out = x_out / scales_lol[-1]
y_out = y_out / scales_lol[-2]
elif coordinate_transformation_mode == "half_pixel_symmetric":
x_out = (
X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1])
+ (x_out + 0.5) / scales_lol[-1]
- 0.5
)
y_out = (
X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2])
+ (y_out + 0.5) / scales_lol[-2]
- 0.5
)
elif coordinate_transformation_mode == "pytorch_half_pixel":
x_out = (
(x_out + 0.5) / scales_lol[-1] - 0.5
if output_shape[-1] > 1
else Tensor([0])
)
y_out = (
(y_out + 0.5) / scales_lol[-2] - 0.5
if output_shape[-2] > 1
else Tensor([0])
)
elif coordinate_transformation_mode == "tf_crop_and_resize":
x_out = (
roi[-1][0] * (X.shape[-1] - 1)
+ x_out
* (
(roi[-1][1] - roi[-1][0])
* (X.shape[-1] - 1)
/ (output_shape[-1] - 1)
)
if output_shape[-1] > 1
else Tensor([0.5 * (roi[-1][0] + roi[-1][1]) * (X.shape[-1] - 1)])
)
y_out = (
roi[-2][0] * (X.shape[-2] - 1)
+ y_out
* (
(roi[-2][1] - roi[-2][0])
* (X.shape[-2] - 1)
/ (output_shape[-2] - 1)
)
if output_shape[-2] > 1
else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)])
)
return x_out.clip(0, X.shape[-1] - 1), y_out.clip(0, X.shape[-2] - 1)
if roi is not None:
roi = safe_numpy(roi)
roi = [(st, ed) for st, ed in zip(roi[: len(roi) // 2], roi[len(roi) // 2 :])]
roi_ = [(1, 1)] * 4
if axes is not None:
for a, r in zip(axes, roi):
roi_[a] = r
roi = roi_
if scales is not None:
scales = safe_numpy(scales).tolist()
if axes is not None:
scales_ = [1] * X.ndim
for a, s in zip(axes, scales):
scales_[a] = s
scales = scales_
elif sizes is not None:
sizes = [int(i) for i in safe_numpy(sizes)]
scales = []
if axes is not None:
sizes_ = [1] * X.ndim
for a, s in zip(axes, sizes):
sizes_[a] = s
scales.append(s / X.shape[a])
sizes = sizes_
else:
scales = [si / xs for xs, si in zip(X.shape, sizes)]
if keep_aspect_ratio_policy == "not_larger":
scale = min(scales)
sizes = _round(Tensor(list(X.shape[-2:])) * scale, 0.5, "round_up")
sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
elif keep_aspect_ratio_policy == "not_smaller":
scale = max(scales)
sizes = _round(Tensor(list(X.shape[-2:])) * scale, 0.5, "round_up")
sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
output_shape = (
sizes if sizes else [math.floor(x * s) for x, s in zip(X.shape, scales)]
)
output_shape_ = sizes if sizes else [x * s for x, s in zip(X.shape, scales)]
scales_lol = [os / xs for xs, os in zip(X.shape, output_shape)]
x_out = Tensor.arange(output_shape[-1])
y_out = Tensor.arange(output_shape[-2])
if mode == "nearest":
x_out, y_out = _coordinate_transformation(
x_out, y_out, output_shape, scales_lol, roi
)
x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
y_out = _nearest_mode(y_out, nearest_mode, X.shape[-2])
return _nearest_gather(X, x_out, y_out)
elif mode == "linear":
x_out, y_out = _coordinate_transformation(
x_out, y_out, output_shape_, scales, roi
)
ret = []
for y in safe_numpy(y_out):
for x in safe_numpy(x_out):
x_floor, y_floor = int(x), int(y)
y_shrink = (
(0, X.shape[2])
if X.shape[2] == 1
else (y_floor, y_floor + 2)
if y != y_floor
else (y_floor, y_floor + 1)
)
x_shrink = (
(x_floor, x_floor + 2) if x != x_floor else (x_floor, x_floor + 1)
)
shrink_args = ((0, X.shape[0]), (0, X.shape[1]), y_shrink, x_shrink)
corners = safe_numpy(X.shrink(shrink_args))
x1, x2, y1, y2 = x_floor, x_floor + 1, y_floor, y_floor + 1
if (
x == x_floor and y == y_floor
): # TODO https://en.wikipedia.org/wiki/Bilinear_interpolation#Weighted_mean maybe do weighted mean?
ret.append(corners[0, 0, 0, 0])
elif x == x_floor:
ret.append(
(
corners[0, 0, 0, 0] * (y2 - y)
+ corners[0, 0, 1, 0] * (y - y1)
)
/ (y2 - y1)
)
elif y == y_floor:
ret.append(
(
corners[0, 0, 0, 0] * (x2 - x)
+ corners[0, 0, 0, 1] * (x - x1)
)
/ (x2 - x1)
)
else:
ret.append(
(
corners[0, 0, 0, 0] * (x2 - x) * (y2 - y)
+ corners[0, 0, 0, 1] * (x - x1) * (y2 - y)
+ corners[0, 0, 1, 0] * (x2 - x) * (y - y1)
+ corners[0, 0, 1, 1] * (x - x1) * (y - y1)
)
/ ((x2 - x1) * (y2 - y1))
)
return Tensor(ret).reshape(output_shape)
elif mode == "cubic":
raise NotImplementedError("cubic interpolation is not implemented")
def CenterCropPad(input: Tensor, shape: Tensor, axes=None):
if not axes:
axes = list(range(input.ndim))
shrink_arg = [(0, i) for i in input.shape]
pad_arg = [(0, 0) for _ in range(input.ndim)]
shape = safe_numpy(shape).tolist()
for s, x in zip(shape, axes):
if s < input.shape[x]:
shrink_arg[x] = (
(input.shape[x] // 2 - s // 2, input.shape[x] // 2 + s // 2)
if s % 2 == 0
else (input.shape[x] // 2 - s // 2 - 1, input.shape[x] // 2 + s // 2)
)
elif s > input.shape[x]:
pad_arg[x] = (
((s - input.shape[x]) // 2, (s - input.shape[x]) // 2)
if (s - input.shape[x]) % 2 == 0
else ((s - input.shape[x]) // 2, (s - input.shape[x]) // 2 + 1)
)
return input.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
depth = int(safe_numpy(depth).item())
indices, rank = (indices < 0).where(indices + depth, indices), len(indices.shape)
if axis < 0:
axis += rank + 1
ls, rs = indices.shape[0:axis], indices.shape[axis:rank]
cond = indices[:, None] == Tensor.arange(depth).reshape(
(1,) * len(ls) + (depth,) + (1,) * len(rs)
)
return cond.where(values[1], values[0]).cast(values.dtype)
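# Erf uses the classic Abramowitz & Stegun 7.1.26 polynomial approximation (max error around 1.5e-7)
# instead of a dedicated llop.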
def Erf(x: Tensor):
sign = x.sign()
x = x.abs()
t = 1.0 / (1.0 + 0.3275911 * x)
term1 = 0.254829592 * t
term2 = -0.284496736 * t**2
term3 = 1.421413741 * t**3
term4 = -1.453152027 * t**4
term5 = 1.061405429 * t**5
y = term1 + term2 + term3 + term4 + term5
return sign * (1.0 - y * Tensor.exp(-x * x))
def Compress(inp: Tensor, condition: Tensor, axis=None):
if axis is None:
inp = inp.flatten()
axis = 0
axis = axis + inp.ndim if axis < 0 else axis
con_np = safe_numpy(condition)
con = Tensor(np.arange(condition.shape[0])[con_np]) # no boolean indexing in Tensor
return inp.__getitem__(
tuple([slice(None) if i != axis else con for i in range(inp.ndim)])
)
type_map = {TensorProto.DOUBLE: dtypes.double, TensorProto.FLOAT: dtypes.float32}
def EyeLike(x: Tensor, dtype=None, k=0):
if dtype is None:
dtype = x.dtype
else:
dtype = type_map[dtype]
shape = x.shape
dim = min(x.shape)
if shape[0] == shape[1]:
return Tensor.eye(dim=dim, dtype=dtype)
else:
diff = (shape[0] - dim, shape[1] - dim)
padarg = tuple([(d, d) if d == 0 else (k, d - k) for d in diff])
return Tensor.eye(dim=dim, dtype=dtype).pad(padarg)
def Upsample(X, scales, mode):
return Resize(X=X, scales=scales, mode=mode)
# Needs work
def IsInf(x: Tensor, detect_negative=1, detect_positive=1):
ret = (
(x == float("inf")) * detect_positive
+ (x == float("-inf")) * detect_negative
+ Tensor.zeros(*x.shape)
)
return ret.cast(dtypes.bool)
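# DequantizeLinear reshapes the per-axis scale (and zero point) with size-1 dims around `axis` so they broadcast
# against x, then computes (x - zero_point) * scale in float and casts back to the scale dtype.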
def DequantizeLinear(
x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int] = 0, axis=1
):
axis = axis + x.ndim if axis < 0 else axis
x = x.cast(dtypes.float)
if isinstance(x_zero_point, Tensor):
x_zero_point = x_zero_point.cast(dtypes.float)
x_sc = x_scale.reshape(
*[1] * axis, *x_scale.shape, *[1] * (x.ndim - axis - x_scale.ndim)
)
x_zer = (
x_zero_point.reshape(
*[1] * axis, *x_scale.shape, *[1] * (x.ndim - axis - x_scale.ndim)
)
if isinstance(x_zero_point, Tensor)
else x_zero_point
)
return ((x - x_zer) * x_sc).cast(x_scale.dtype)
# Needs work
def IsNaN(x: Tensor):
return (x < float("-inf")).cast(dtypes.bool)
# copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py
# without importing PIL we'll have to manually decode a bunch of image formats like PNG, JPEG, WebP, etc
def ImageDecoder(encoded_stream: Tensor, pixel_format="RGB"):
try:
import PIL.Image
except ImportError as e:
raise ImportError(
"Pillow must be installed to use the reference implementation of the ImageDecoder operator"
) from e
img = PIL.Image.open(io.BytesIO(safe_numpy(encoded_stream).tobytes()))
if pixel_format == "BGR":
return Tensor(np.array(img))[:, :, ::-1]
elif pixel_format == "RGB":
return Tensor(np.array(img))
elif pixel_format == "Grayscale":
img = img.convert("L")
decoded = Tensor(np.array(img))
return decoded.unsqueeze(-1) # (H, W) to (H, W, 1)
else:
raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
def AffineGrid(theta: Tensor, size: Tensor, align_corners=0):
_, _, *data_sz = safe_numpy(size).tolist()
size_zeros, original_grid = Tensor.zeros(data_sz), Tensor.ones(data_sz)
stackable = [original_grid]
for dim, dim_sz in enumerate(data_sz):
a = (
Tensor.arange(-1, 1.0001, 2 / (dim_sz - 1))
if align_corners == 1
else Tensor.arange(-1 + 1 / dim_sz, 1, 2 / dim_sz)
)
if dim == 0:
stackable = [
a.reshape(dim_sz, *[1] * (len(data_sz) - 1)) + size_zeros,
*stackable,
]
elif dim == 1:
stackable = [
a.reshape(1, dim_sz, *[1] * (len(data_sz) - 2)) + size_zeros,
*stackable,
]
else:
stackable = [a.reshape(1, dim_sz) + size_zeros, *stackable]
original_grid = Tensor.stack(stackable, dim=len(data_sz))
if original_grid.ndim == 3:
N, dim_2d, dim_homo = theta.shape
assert dim_2d == 2 and dim_homo == 3
H, W, dim_homo = original_grid.shape
assert dim_homo == 3
original_grid = original_grid.reshape(H * W, dim_homo).transpose()
return theta.matmul(original_grid).permute(0, 2, 1).reshape(N, H, W, dim_2d)
else:
assert original_grid.ndim == 4
N, dim_3d, dim_homo = theta.shape
assert dim_3d == 3 and dim_homo == 4
D, H, W, dim_homo = original_grid.shape
assert dim_homo == 4
original_grid = original_grid.reshape(D * H * W, dim_homo).transpose()
return theta.matmul(original_grid).permute(0, 2, 1).reshape(N, D, H, W, dim_3d)
# **************** com.microsoft Ops ****************
def SkipLayerNormalization(
input: Tensor,
skip: Tensor,
gamma,
beta: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
epsilon=None,
):
if epsilon is None:
epsilon = 1e-12
x = input + skip + (bias if bias is not None else 0)
return x.layernorm(eps=epsilon) * gamma + (beta if beta is not None else 0), None, None, x
def FastGelu(x: Tensor, bias: Optional[Tensor] = None):
x = x + bias if bias is not None else x
return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x**3).tanh())
def EmbedLayerNormalization(
input_ids: Tensor,
segment_ids: Optional[Tensor] = None,
word_embedding: Tensor = None,
position_embedding: Tensor = None,
segment_embedding: Optional[Tensor] = None,
gamma=None,
beta=None,
mask: Optional[Tensor] = None,
position_ids: Optional[Tensor] = None,
epsilon=None,
mask_index_type=None,
):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
assert (segment_ids is None) is (segment_embedding is None)
assert (mask is None) is (mask_index_type is None)
assert mask is None, "functionality not supported yet" # TODO
input_shape = input_ids.shape
bsz, seq_length = input_shape[0], input_shape[1]
compute_seg_emb = segment_embedding is not None and segment_ids is not None
vocab_size, max_position_embeddings, type_vocab_size = (
word_embedding.shape[0],
position_embedding.shape[0],
(segment_embedding.shape[0] if compute_seg_emb else None),
)
def embedding(
x: Tensor, vocab_size, weight: Tensor
) -> Tensor: # TODO from nn.Embedding. Could probably upstream this to Tensor
vocab_counter = (
Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False)
.reshape(1, 1, vocab_size)
.expand(*x.shape, vocab_size)
)
return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight
# bert embedding layer
if epsilon is None:
epsilon = 1e-12
if position_ids is None:
position_ids = (
Tensor.arange(seq_length, requires_grad=False)
.unsqueeze(0)
.expand(*input_shape)
)
wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
pos_embedding_res = embedding(
position_ids, max_position_embeddings, position_embedding
)
seg_embedding_res = (
embedding(segment_ids, type_vocab_size, segment_embedding)
if compute_seg_emb
else None
)
embedding_sum = wrd_embedding_res + pos_embedding_res + (seg_embedding_res if compute_seg_emb else 0)
out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
return out, None, embedding_sum
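# Attention: the packed projection weights are split into Q/K/V blocks of (hidden_size, hidden_size, v_hidden_size)
# columns for the bert-style path, while the gpt-style (unidirectional) path projects once and slices the result;
# heads are then split with the usual reshape + transpose before the masked softmax attention.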
def Attention(
input: Tensor,
weights,
bias: Optional[Tensor] = None,
mask_index: Optional[Tensor] = None,
past: Optional[Tensor] = None,
relative_position_bias: Optional[Tensor] = None,
past_sequence_length: Optional[Tensor] = None,
do_rotary=None,
mask_filter_value=None,
num_heads=None,
past_present_share_buffer=None,
qkv_hidden_sizes=None,
scale=None,
unidirectional=None,
):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
assert num_heads is not None # required
assert (qkv_hidden_sizes is None and past is not None) or (
qkv_hidden_sizes is not None
)
assert (
relative_position_bias
== do_rotary
== past_sequence_length
== mask_filter_value
== past_present_share_buffer
== scale
== None
), "functionality not supported yet" # TODO strange params
hidden_size, v_hidden_size = (
qkv_hidden_sizes[1:]
if qkv_hidden_sizes is not None
else 2 * (weights.shape[1] // 3,)
)
if unidirectional: # gpt-style
assert hidden_size == v_hidden_size
xqkv = input.linear(weights, bias)
xq, xk, xv = [
xqkv.slice([None, None, (i * hidden_size, (i + 1) * hidden_size)])
for i in range(3)
]
else: # bert-style
wq, wk, wv = (
weights[:, :hidden_size],
weights[:, hidden_size : hidden_size + v_hidden_size],
weights[:, hidden_size + v_hidden_size :],
)
bq, bk, bv = (
(
bias[:hidden_size],
bias[hidden_size : hidden_size + v_hidden_size],
bias[hidden_size + v_hidden_size :],
)
if bias is not None
else (None, None, None)
)
xq, xk, xv = [input.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
xq, xk, xv = [
x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2)
for x in (xq, xk, xv)
]
if past is not None:
xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
def attn(query, key, value, attn_mask):
query_length, key_length = query.shape[-2], key.shape[-2]
cdim = max(query_length, key_length) + 1
attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
# This is where Tensor.scaled_dot_product_attention differs:
causal_mask = (
Tensor.ones((cdim, cdim), requires_grad=False)
.cast(dtypes.bool)
.tril(0)[key_length - query_length : key_length, :key_length]
.cast(dtypes.bool)
)
return (
Tensor.where(causal_mask, attn_weights, -float("inf")) + (attn_mask if attn_mask is not None else 0)
).softmax(-1) @ value
bsz, _, seq_len, _ = xq.shape
out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
return out, present
# **************** ai.onnx.preview.training Ops ****************
# TODO not entirely sure these optimizers are correct
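# These training ops receive their tensors flattened as (X1, X2, ..., G1, G2, ..., H1, H2, ...); the stride slice
# inputs[i::groups] regroups them per parameter, and the updated tensors are returned grouped the same way
# (all new X's first, then each optimizer state).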
def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
groups = len(inputs) // 3
grouped_inputs = [inputs[i::groups] for i in range(groups)]
T, R = safe_numpy(T)[0], safe_numpy(R)[0]
r = R / (1 + T * decay_factor)
ret = []
for input in grouped_inputs:
X, G, H = input
X.grad = norm_coefficient * X + G
X.grad.requires_grad, H.requires_grad = (
False,
False,
) # TODO manually turning off requires_grad, see TODO under (domain == "ai.onnx.preview.training") in onnx.py
H.assign(H.detach() + X.grad * X.grad).realize()
H_adaptive = H.sqrt() + epsilon
X.assign(X.detach() - r * X.grad / H_adaptive)
ret.extend([X, H])
ret = ret[::2] + ret[1::2]
return tuple(ret)
def Momentum(R, T, *inputs, alpha, beta, mode, norm_coefficient):
groups = len(inputs) // 3
grouped_inputs = [inputs[i::groups] for i in range(groups)]
T, R = safe_numpy(T)[0], safe_numpy(R)[0]
beta_adjusted = beta if T > 0 else 1
ret = []
for input in grouped_inputs:
X, G, V = input
X.grad = (norm_coefficient * X + G).realize()
X.grad.requires_grad, V.requires_grad = False, False
V.assign(alpha * V + beta_adjusted * X.grad).realize()
if mode == "standard":
X.assign(X.detach() - R * V).realize()
elif mode == "nesterov":
X.assign(X.detach() - R * (X.grad + alpha * V)).realize()
ret.extend([X, V])
ret = ret[::2] + ret[1::2]
return tuple(ret)
# copied from tinygrad/nn/optim.py: LAMB with some edits
def Adam(
R,
T,
*inputs,
alpha=0.9,
beta=0.999,
epsilon=0.0,
norm_coefficient=0.0,
norm_coefficient_post=0.0,
):
groups = len(inputs) // 4
grouped_inputs = [inputs[i::groups] for i in range(groups)]
T, R = safe_numpy(T)[0], safe_numpy(R)[0]
ret = []
for input in grouped_inputs:
X, G, V, H = input
X.grad = (norm_coefficient * X + G).realize()
V.requires_grad, H.requires_grad, X.grad.requires_grad = False, False, False
V.assign(alpha * V + (1.0 - alpha) * X.grad).realize()
H.assign(beta * H + (1.0 - beta) * (X.grad * X.grad)).realize()
up = (
(V / (1.0 - alpha**T)) / ((H / (1.0 - beta**T)).sqrt() + epsilon)
if T > 0
else V / (H.sqrt() + epsilon)
)
X.assign(X.detach() - R * up).realize()
X = (1 - norm_coefficient_post) * X
ret.extend([X, V, H])
ret = ret[::3] + ret[1::3] + ret[2::3]
return tuple(ret)