from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, dtypes, ImageDType, flatten
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
from onnx import TensorProto
import io
import os
import numpy as np
import functools
from typing import Union, Tuple, Optional, List, Any
import math

tensor_methods = {
    "Neg",
    "Reciprocal",
    "Sqrt",
    "Sign",
    "Abs",
    "Exp",
    "Log",
    "Mish",
    "Sin",
    "Cos",
    "Tan",
    "Relu",
    "Sigmoid",
    "Tanh",
    "MatMul",
    "Floor",
    "Ceil",
    "Softplus",
    "HardSwish",
    "Where",
    "Mul",
    "Sinh",
    "Cosh",
    "Softsign",
    "Asinh",
    "Acosh",
    "Atanh",
}
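
# NOTE: the ops listed above are assumed to be dispatched by the ONNX runner (extra/onnx.py)
# straight to the Tensor method with the same lower-cased name, e.g. "Relu" -> Tensor.relu,
# "MatMul" -> Tensor.matmul, so they need no wrapper functions in this file.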


# **************** Free Ops ****************


def Identity(input: Tensor):
    return input


def Add(input: Tensor, other: Tensor, broadcast=None):
    return (
        input + other
        if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType)
        else (input + other).cast(input.dtype)
    )


def Sub(input: Union[Tensor, Any], other: Tensor):
    return input - other  # some test has input as int


def Div(input: Tensor, other: Tensor):
    return (
        input / other
        if input.dtype == dtypes.float or isinstance(input.dtype, ImageDType)
        else input.div(other).floor()
    )  # TODO: this has dtype issues


def Pow(input: Tensor, other: Tensor):
    return (input.float() ** other.float()).cast(
        input.dtype
    )  # TODO: this has dtype issues


def Less(x: Tensor, y: Tensor):
    return (x < y).cast(dtypes.bool)


def LessOrEqual(x: Tensor, y: Tensor):
    return (x <= y).cast(dtypes.bool)


def Greater(x: Tensor, y: Tensor):
    return (x > y).cast(dtypes.bool)


def GreaterOrEqual(x: Tensor, y: Tensor):
    return (x >= y).cast(dtypes.bool)


def Equal(x: Tensor, y: Tensor):
    return (x == y).cast(dtypes.bool)


def Max(*data_0):
    return functools.reduce(Tensor.maximum, data_0)


def Min(*data_0):
    return functools.reduce(Tensor.minimum, data_0)


def Sum(*data_0):
    return functools.reduce(Tensor.__add__, data_0)


def Mean(*data_0):
    return functools.reduce(Tensor.__add__, data_0) / len(data_0)


def Cast(input: Tensor, to):
    return input.cast(dtypes.from_np(tensor_dtype_to_np_dtype(to)))


# **************** Simple Ops ****************


def Constant(
    value: Tensor = None,
    value_float=None,
    value_floats=None,
    value_int=None,
    value_ints=None,
    value_string=None,
    value_strings=None,
):
    if value:
        return value
    elif value_float:
        return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
    elif value_floats:
        return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
    elif value_int:
        return Tensor(value_int, dtype=dtypes.int64, requires_grad=False)
    elif value_ints:
        return Tensor(list(value_ints), dtype=dtypes.int64, requires_grad=False)
    elif value_string or value_strings:
        raise NotImplementedError(
            "value_string or value_strings not implemented for Constant op"
        )


def HardSigmoid(input: Tensor, alpha=0.2, beta=0.5):
    return (alpha * input + beta).clip(0, 1)


def Gelu(x: Tensor, approximate=None):
    return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + Erf(x / math.sqrt(2)))


def Celu(x: Tensor, alpha=1.0):
    return x.celu(alpha)


def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875):
    return gamma * (X.relu() - (-alpha * X.exp() + alpha).relu())


def PRelu(X: Tensor, slope: Tensor):
    slope = (
        slope[0] if slope.shape[-1] != X.shape[-1] else slope
    )  # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
    return X.clip(0, float("inf")) + X.clip(float("-inf"), 0) * slope


def LeakyRelu(X: Tensor, alpha=0.01):
    return X.leakyrelu(alpha)


def ThresholdedRelu(X: Tensor, alpha=1.0):
    return (X - alpha).relu() + (X - alpha).relu().sign() * alpha


def Softmax_1(input: Tensor, axis=1):
    return input.softmax(axis)


def Softmax_13(input: Tensor, axis=-1):
    return input.softmax(axis)


Softmax = {1: Softmax_1, 13: Softmax_13}  # Softmax default axis changed


def LogSoftmax(input: Tensor, axis=-1):
    return input.log_softmax(axis)


def Clip(input: Tensor, min=None, max=None):
    return input.clip(
        float("-inf") if min is None else min, float("inf") if max is None else max
    )


# NOTE ReduceProd would require a new llop
def _axes(axes, noop_with_empty_axes):
    return (
        [int(x) for x in safe_numpy(axes)]
        if axes is not None and not (isinstance(axes, Tensor) and axes.shape == (0,))
        else ([] if noop_with_empty_axes else None)
    )


def ReduceMax(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
    return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)


def ReduceMin(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
    return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)


def ReduceSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
    return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)


def ReduceMean(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
    return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)


def ReduceSumSquare(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
    return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)


def ReduceL1(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
    return data.abs().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)


def ReduceL2(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
    return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt()


def ReduceLogSum(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
    return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()


def ReduceLogSumExp(data: Tensor, axes=None, keepdims=1, noop_with_empty_axes=0):
    return data.exp().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()


def GlobalAveragePool(X: Tensor):
    return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)


def GlobalMaxPool(X: Tensor):
    return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)


def OptionalHasElement(x: Tensor = None):
    return Tensor(x is not None and x.numel() > 0, dtype=dtypes.bool)


def OptionalGetElement(x: Tensor = None):
    return x if x is not None else Tensor([], dtype=dtypes.float32)


def Tile(input: Tensor, repeats):
    return input.repeat([int(x) for x in safe_numpy(repeats)])


def Range(start: Tensor, limit, delta):
    return Tensor.arange(
        start=int(safe_numpy(start).item()),
        stop=int(safe_numpy(limit).item()),
        step=int(safe_numpy(delta).item()),
    ).cast(dtype=start.dtype)


def Shape(data: Tensor, end=None, start=0):
    return Tensor(
        list(data.shape)[start:end],
        dtype=dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64,
    )  # TODO: really?


def Size(data: Tensor):
    return prod(data if isinstance(data, list) else data.shape)


def Flatten(input: Tensor, axis=1):
    return input.reshape(prod((1,) + input.shape[0:axis]), -1)


def Reshape(data: Tensor, shape: Tensor, allowzero=None):
    return data.reshape(
        [int(x) if x != 0 else data.shape[i] for i, x in enumerate(safe_numpy(shape))]
    )


def Shrink(input: Tensor, bias=0.0, lambd=0.5):
    return (input < -lambd) * (input + bias) + (input > lambd) * (input - bias)


def And(x: Tensor, y: Tensor):
    return (x == y).where(x, Tensor.zeros(*x.shape)).cast(dtypes.bool)


def Or(x: Tensor, y: Tensor):
    return (x == y).where(x, Tensor.ones(*x.shape)).cast(dtypes.bool)


def Xor(x: Tensor, y: Tensor):
    return (
        (x == y).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
    )


def Not(x: Tensor):
    return (
        (x == 1).where(Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
    )


def Asin(x):
    return Atan(x / Tensor.sqrt(1 - x * x))


def Acos(x: Tensor):
    negate = x < 0
    x = x.abs()
    ret = (
        (((-0.0187293 * x) + 0.0742610) * x - 0.2121144) * x + 1.5707288
    ) * Tensor.sqrt(1.0 - x)
    ret = ret - 2 * negate * ret
    return negate * 3.14159265358979 + ret


def Atan(y: Tensor):
    x = Tensor.ones(y.shape)
    t3 = x
    t1 = y.abs()
    t0 = (t3 > t1).where(t3, t1)
    t1 = (t3 < t1).where(t3, t1)
    t3 = t1 / t0
    t4 = t3 * t3
    t0 = (
        (((-0.013480470 * t4 + 0.057477314) * t4 - 0.121239071) * t4 + 0.195635925) * t4
        - 0.332994597
    ) * t4 + 0.999995630
    t3 = t0 * t3
    t3 = (y.abs() > x.abs()).where(1.570796327 - t3, t3)
    return (y < 0).where(-t3, t3)
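
# NOTE: Acos and Atan above use fixed polynomial approximations (Atan approximates on the
# reduced argument in [0, 1] and folds the result back with atan(y) = pi/2 - atan(1/y) plus
# odd symmetry); Asin is then derived via asin(x) = atan(x / sqrt(1 - x^2)). These are
# approximations, not exact transcendentals.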


def Trilu(x: Tensor, k: Union[Tensor, int] = 0, upper=1):
    k = (
        int(k.numpy().item()) if k != 0 else 0
    )  # onnx passes k as a tensor int64 with one element, default is 0
    return x.triu(k) if upper else x.tril(k)


def Squeeze(input: Tensor, axes):
    if isinstance(axes, Tensor):
        axes = safe_numpy(axes)
    axes = [int(x) if x >= 0 else int(x + input.ndim) for x in axes]
    return input.reshape([s for i, s in enumerate(input.shape) if i not in axes])


def Unsqueeze(data: Tensor, axes):
    axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
    new_shape = [1] * (len(data.shape) + len(axes))
    ptr = iter(data.shape)
    for i in range(len(new_shape)):
        if i not in axes:
            new_shape[i] = next(ptr)
    return data.reshape(new_shape)
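
# Example (illustrative): for data.shape == (3, 4) and axes == [0, 3], new_shape starts as
# [1, 1, 1, 1]; positions 1 and 2 are filled from data.shape, giving data.reshape(1, 3, 4, 1).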


def Binarizer(input, threshold=0.0):
    return input > threshold


def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
    axis = axis + x.ndim if axis < 0 else axis
    m = x == (
        x.max(axis=axis, keepdim=keepdims)
        if keepdims
        else x.max(axis=axis, keepdim=keepdims).unsqueeze(axis)
    )
    c = (
        Tensor.arange(x.shape[axis]).reshape(
            *[1] * (axis), x.shape[axis], *[1] * (x.ndim - axis - 1)
        )
        * m
    )
    return c.max(axis=axis, keepdim=keepdims).cast(dtypes.int64)


def ArgMin(x, axis=0, keepdims=1, select_last_index=0):
    return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)


def Elu(input: Tensor, alpha=1.0):
    return input.elu(alpha=alpha)


def Concat(*inputs: List[Tensor], axis):
    return inputs[0].cat(*inputs[1:], dim=axis)


def Transpose(input: Tensor, perm=None):
    return input.permute(
        order=list(range(len(input.shape))[::-1]) if perm is None else perm
    )


# NOTE: since we only have one type, this is valid!
def CastLike(input, target_type):
    assert isinstance(target_type, Tensor), "can only CastLike Tensor"
    return input


def ConstantOfShape(input, value: Tensor = None):
    if value is None:
        value = Tensor([0.0])
    shape = [int(x) for x in safe_numpy(input)]
    return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0] != 0 else 1)


# TODO: abstract out the broadcast logic in tensor
def Expand(input: Tensor, shape):
    x_shape, y_shape = input.shape, [int(x) for x in safe_numpy(shape)]
    # copied from _broadcasted
    x_shape, y_shape = [
        ([1] * (max(len(x_shape), len(y_shape)) - len(t_shape)) + list(t_shape))
        for t_shape in [x_shape, y_shape]
    ]
    shape_ret = tuple(max(sx, sy) for sx, sy in zip(x_shape, y_shape))
    return input.reshape(x_shape).expand(shape_ret)


# **************** Complex Ops ****************


def Gemm(
    A: Tensor,
    B: Tensor,
    C: Tensor = None,
    alpha=1.0,
    beta=1.0,
    transA=0,
    transB=0,
    broadcast=0,
):
    ret = alpha * (A.transpose(transA) @ B.transpose(transB))
    if C is not None:
        ret += beta * (
            C
            if broadcast == 0
            else C.reshape(
                [-1 if i < len(C.shape) else 1 for i in range(len(ret.shape))][::-1]
            )
        )
    return ret
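
# Gemm computes alpha * A' @ B' + beta * C, where A'/B' are optionally transposed by
# transA/transB; `broadcast` is the legacy (opset <= 6) flag for unidirectional broadcast of C.
# Illustrative check: Gemm(Tensor([[1., 2.]]), Tensor([[3.], [4.]]), C=Tensor([1.]), alpha=2.0)
# == 2.0 * [[11.]] + 1.0 * [1.] -> [[23.]]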


# works with Tensors.ndim != 4
def _batchnorm(
    self: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    mean: Tensor,
    invstd: Tensor,
):
    shape = [1, -1] + [1] * (self.ndim - 2)
    x = self - mean.reshape(shape=shape)
    if weight:
        x = x * weight.reshape(shape=shape)
    ret = x.mul(invstd.reshape(shape=shape) if len(invstd.shape) == 1 else invstd)
    return (ret + bias.reshape(shape=shape)) if bias else ret


# TODO: this is copied from tinygrad/nn/__init__.py
# spatial is from opset 7 and has since been removed
def BatchNormalization(
    X: Tensor,
    scale,
    B,
    input_mean,
    input_var,
    epsilon=1e-05,
    momentum=0.9,
    training_mode=0,
    spatial=1,
    is_test=0,
):
    if training_mode:
        x_detached = X.detach()
        current_mean = x_detached.mean(axis=(0, 2, 3))
        y = x_detached - current_mean.reshape(shape=[1, -1, 1, 1])
        current_var = (y * y).mean(axis=(0, 2, 3))
        current_invstd = current_var.add(epsilon).pow(-0.5)

        running_mean = input_mean * momentum + current_mean * (1 - momentum)
        running_var = input_var * momentum + current_var * (1 - momentum)

        return (
            _batchnorm(X, scale, B, current_mean, current_invstd),
            running_mean,
            running_var,
        )
    else:
        invstd = (input_var + epsilon) ** -0.5
        return _batchnorm(X, scale, B, input_mean, invstd)


def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
    axis = tuple(range(2, len(x.shape)))
    mean = x.mean(axis=axis, keepdim=True)
    invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5)
    return (
        x.sub(mean)
        .mul(scale.reshape(shape=[-1, 1, 1]))
        .mul(invstd)
        .add(bias.reshape(shape=[-1, 1, 1]))
    )


def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
    assert stash_type == 1, "only float32 is supported"
    axis = tuple(
        i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape))
    )
    mean = x.mean(axis=axis, keepdim=True)
    return (
        x.layernorm(axis, epsilon).mul(scale).add(bias),
        mean,
        (x.sub(mean))
        .pow(2)
        .mean(axis=axis, keepdim=True)
        .add(epsilon)
        .sqrt()
        .reciprocal(),
    )


def GroupNormalization(
    x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05
):
    return (
        x.reshape(x.shape[0], num_groups, -1)
        .layernorm(axis=-1, eps=epsilon)
        .mul(scale.unsqueeze(-1))
        .add(bias.unsqueeze(-1))
        .reshape(x.shape)
    )


# onnx: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
# numpy.pad: ((x1_begin, x1_end), (x2_begin, x2_end), ...)
def _format_padding(onnx_pads, ndims=None, axes=None):
    if ndims and len(onnx_pads) // 2 != ndims:
        onnx_pads = (
            onnx_pads * ndims
        )  # for OnnxBackendPyTorchConvertedModelTest the len(onnx_pads) == 2
    if ndims is None:
        ndims = len(onnx_pads) // 2
    if axes is None:
        axes = list(range(ndims))
    num_axes = len(axes)
    np_pads = [(0, 0)] * ndims
    for i in range(num_axes):
        np_pads[axes[i]] = (onnx_pads[i], onnx_pads[i + num_axes])
    return np_pads
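
# Example (illustrative): _format_padding([1, 2, 3, 4], ndims=2) -> [(1, 3), (2, 4)],
# i.e. ONNX's [x1_begin, x2_begin, x1_end, x2_end] regrouped into per-axis (begin, end) pairs.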


def _padding(
    X: Tensor,
    pads=None,
    auto_pad="NOTSET",
    axes=None,
    constant_value=0.0,
    strides=None,
    kernel_shape=None,
    dilations=None,
    ceil_mode=0,
):
    if auto_pad != "NOTSET":
        pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
    elif ceil_mode and auto_pad == "NOTSET":  # stupid ceil_mode case
        if strides is not None:
            strides = (
                [strides] * len(kernel_shape)
                if isinstance(strides, int)
                else strides
                if strides
                else [1] * len(kernel_shape)
            )
        if dilations is not None:
            dilations = [1] * len(kernel_shape) if dilations == 1 else dilations
        out_spatial_shape = [
            math.ceil((sh - dil * (ker - 1) - 1) / st + 1)
            if ceil_mode
            else math.floor((sh - dil * (ker - 1) - 1) / st + 1)
            for sh, st, ker, dil in zip(
                X.shape[-len(kernel_shape) :], strides, kernel_shape, dilations
            )
        ]
        pad_shape = [
            (osh - 1) * st + ((ks - 1) * dil + 1) - ish
            for osh, st, ks, dil, ish in zip(
                out_spatial_shape,
                strides,
                kernel_shape,
                dilations,
                X.shape[-len(kernel_shape) :],
            )
        ]
        pad_shape = flatten([[sh // 2, sh - sh // 2] for sh in pad_shape])
        pads = pad_shape[::2] + pad_shape[1::2]
    if pads is None:
        return X
    pads = _format_padding(pads, ndims=len(X.shape), axes=axes)
    return X.pad(tuple(pads), value=constant_value)


def _auto_pad(X: Tensor, auto_pad, strides, kernel_shape, dilations):
    strides = (
        [strides] * len(kernel_shape)
        if isinstance(strides, int)
        else strides
        if strides
        else [1] * len(kernel_shape)
    )
    dilations = [1] * len(kernel_shape) if dilations == 1 else dilations
    if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
        pad_shape = [
            (math.ceil(sh / st) - 1) * st + ((ks - 1) * di + 1) - sh
            for sh, st, ks, di in zip(
                X.shape[-len(kernel_shape) :], strides, kernel_shape, dilations
            )
        ]
        pad_shape = flatten([[sh // 2, sh - sh // 2] for sh in pad_shape])
        return (
            pad_shape[::2] + pad_shape[1::2]
            if auto_pad == "SAME_UPPER"
            else pad_shape[1::2] + pad_shape[::2]
        )
    else:
        raise NotImplementedError(f"auto_pad={auto_pad} not implemented")


def Pad(
    x: Tensor,
    pads: Union[Tensor, Tuple[int, ...]],
    constant_value: Tensor = None,
    axes: Tensor = None,
    mode="constant",
    value: float = 0.0,
):
    constant_value = (
        value if constant_value is None else float(safe_numpy(constant_value)[0])
    )
    seq_pads = list(pads) if isinstance(pads, tuple) else safe_numpy(pads)
    seq_pads = [math.ceil(i) for i in seq_pads]
    seq_axes = safe_numpy(axes).astype(np.int32).tolist() if axes is not None else None
    base_shape = x.shape
    pads = _format_padding(seq_pads, ndims=len(x.shape), axes=seq_axes)
    if mode == "wrap":
        repeat_args = [
            math.ceil(dim[0] / sh) + math.ceil(dim[1] / sh) + 1
            for dim, sh in zip(pads, base_shape)
        ]
        new_shape = [s * r for s, r in zip(base_shape, repeat_args)]
        shrink_args = [
            (
                sh - dim[0] % sh if dim[0] % sh != 0 else 0,
                nsh - (sh - dim[1] % sh) if dim[1] % sh != 0 else nsh,
            )
            for dim, sh, nsh in zip(pads, base_shape, new_shape)
        ]
        return x.repeat(tuple(repeat_args)).shrink(tuple(shrink_args))
    elif mode == "reflect":
        for i, s in enumerate(x.shape):
            if pads[i] == (0, 0):
                continue
            elif pads[i][0] and not pads[i][1]:
                x = x.flip(i).shrink(
                    tuple(
                        [
                            (0, s_) if i_ != i else (s - pads[i][0] - 1, s_ - 1)
                            for i_, s_ in enumerate(x.shape)
                        ]
                    )
                ).pad(
                    tuple([(0, 0) if i_ != i else (0, s) for i_ in range(x.ndim)])
                ) + x.pad(
                    tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)])
                )
            elif not pads[i][0] and pads[i][1]:
                x = x.flip(i).shrink(
                    tuple(
                        [
                            (0, s_) if i_ != i else (1, pads[i][1] + 1)
                            for i_, s_ in enumerate(x.shape)
                        ]
                    )
                ).pad(
                    tuple([(0, 0) if i_ != i else (s, 0) for i_ in range(x.ndim)])
                ) + x.pad(
                    tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)])
                )
            else:
                x = (
                    x.flip(i)
                    .shrink(
                        tuple(
                            [
                                (0, s_) if i_ != i else (s - pads[i][0] - 1, s_ - 1)
                                for i_, s_ in enumerate(x.shape)
                            ]
                        )
                    )
                    .pad(
                        tuple(
                            [
                                (0, 0) if i_ != i else (0, s + pads[i][1])
                                for i_ in range(x.ndim)
                            ]
                        )
                    )
                    + x.flip(i)
                    .shrink(
                        tuple(
                            [
                                (0, s_) if i_ != i else (1, pads[i][1] + 1)
                                for i_, s_ in enumerate(x.shape)
                            ]
                        )
                    )
                    .pad(
                        tuple(
                            [
                                (0, 0) if i_ != i else (s + pads[i][0], 0)
                                for i_ in range(x.ndim)
                            ]
                        )
                    )
                    + x.pad(
                        tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)])
                    )
                )
        return x
    elif mode == "edge":
        for i, s in enumerate(x.shape):
            if pads[i] == (0, 0):
                continue
            elif pads[i][0] and not pads[i][1]:
                x = x.shrink(
                    tuple(
                        [
                            (0, s_) if i_ != i else (0, 1)
                            for i_, s_ in enumerate(x.shape)
                        ]
                    )
                ).expand(
                    [pads[i][0] if i_ == i else s_ for i_, s_ in enumerate(x.shape)]
                ).pad(
                    tuple([(0, 0) if i_ != i else (0, s) for i_ in range(x.ndim)])
                ) + x.pad(
                    tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)])
                )
            elif not pads[i][0] and pads[i][1]:
                # expand the replicated edge to the size of the right padding (pads[i][1])
                x = x.shrink(
                    tuple(
                        [
                            (0, s_) if i_ != i else (s_ - 1, s_)
                            for i_, s_ in enumerate(x.shape)
                        ]
                    )
                ).expand(
                    [pads[i][1] if i_ == i else s_ for i_, s_ in enumerate(x.shape)]
                ).pad(
                    tuple(
                        [
                            (0, 0) if i_ != i else (s + pads[i][0], 0)
                            for i_ in range(x.ndim)
                        ]
                    )
                ) + x.pad(
                    tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)])
                )
            else:
                x = (
                    x.shrink(
                        tuple(
                            [
                                (0, s_) if i_ != i else (0, 1)
                                for i_, s_ in enumerate(x.shape)
                            ]
                        )
                    )
                    .expand(
                        [pads[i][0] if i_ == i else s_ for i_, s_ in enumerate(x.shape)]
                    )
                    .pad(
                        tuple(
                            [
                                (0, 0) if i_ != i else (0, s + pads[i][1])
                                for i_ in range(x.ndim)
                            ]
                        )
                    )
                    + x.shrink(
                        tuple(
                            [
                                (0, s_) if i_ != i else (s_ - 1, s_)
                                for i_, s_ in enumerate(x.shape)
                            ]
                        )
                    )
                    .expand(
                        [pads[i][1] if i_ == i else s_ for i_, s_ in enumerate(x.shape)]
                    )
                    .pad(
                        tuple(
                            [
                                (0, 0) if i_ != i else (s + pads[i][0], 0)
                                for i_ in range(x.ndim)
                            ]
                        )
                    )
                    + x.pad(
                        tuple([(0, 0) if i_ != i else pads[i] for i_ in range(x.ndim)])
                    )
                )
        return x
    elif mode == "constant":
        return _padding(x, seq_pads, axes=seq_axes, constant_value=constant_value)


def AveragePool(
    X: Tensor,
    kernel_shape,
    auto_pad="NOTSET",
    ceil_mode=0,
    count_include_pad=0,
    dilations=1,
    pads=None,
    strides=1,
):
    pixel_axes = tuple(range(len(X.shape)))[2:]
    ret = _padding(
        X,
        pads,
        auto_pad,
        axes=pixel_axes,
        strides=strides,
        kernel_shape=kernel_shape,
        dilations=dilations,
        ceil_mode=ceil_mode,
    ).avg_pool2d(kernel_shape, stride=strides, dilation=dilations)
    if count_include_pad:
        return ret
    else:
        div = _padding(
            Tensor.ones(*X.shape),
            pads,
            auto_pad,
            axes=pixel_axes,
            strides=strides,
            kernel_shape=kernel_shape,
            dilations=dilations,
            ceil_mode=ceil_mode,
        ).avg_pool2d(kernel_shape, stride=strides, dilation=dilations)
        return ret / div


def MaxPool(
    X: Tensor,
    kernel_shape,
    auto_pad="NOTSET",
    ceil_mode=0,
    dilations=1,
    pads=None,
    storage_order=0,
    strides=1,
):
    ret = _padding(
        X,
        pads,
        auto_pad,
        constant_value=float("-inf"),
        axes=tuple(range(len(X.shape)))[2:],
        strides=strides,
        kernel_shape=kernel_shape,
        dilations=dilations,
        ceil_mode=ceil_mode,
    )
    ret = ret.max_pool2d(kernel_shape, stride=strides, dilation=dilations)
    ret_len, X_len = ret.numel(), X.numel()
    indices = (
        (
            (
                ret.flatten().unsqueeze(1).expand(ret_len, X_len)
                == X.flatten().reshape(1, X_len).expand(ret_len, X_len)
            )
            * Tensor.arange(X_len).reshape(1, X_len).expand(ret_len, X_len)
        )
        .sum(1)
        .reshape(ret.shape)
        .cast(dtypes.int64)
    )
    if storage_order:
        indices = indices.transpose(indices.ndim - 2, indices.ndim - 1)
    return ret, indices
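
# NOTE: the MaxPool indices output is recovered after the fact by comparing each pooled value
# against every flattened input position (equality mask * arange, then sum over positions).
# This assumes the max of each window occurs at a unique position; duplicated values would
# sum their indices together.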


def MaxUnpool(
    xT: Tensor,
    xI: Tensor,
    outshape: Tensor = None,
    kernel_shape=None,
    pads=None,
    strides=None,
):
    out_sh = [
        (ks // 2) * 2 + st * inps
        for inps, st, ks in zip(xI.shape, strides, kernel_shape)
    ]
    outlength = prod(out_sh)
    xI = xI.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
    arange = (
        Tensor.arange(outlength, requires_grad=False)
        .reshape(1, outlength)
        .expand(xI.shape)
    )
    xT = xT.flatten().unsqueeze(1).expand(prod(xT.shape), outlength)
    ret = ((xI == arange) * xT).sum(0).reshape([1, 1] + out_sh)
    if outshape is not None:
        outshape = safe_numpy(outshape).tolist()
        if outshape != ret.shape:
            diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]]
            pad_args = [
                diff[0] // 2,
                diff[1] // 2,
                diff[0] - diff[0] // 2,
                diff[1] - diff[1] // 2,
            ]
            ret = ret.pad2d((pad_args[1], pad_args[3], pad_args[0], pad_args[2]))
    return ret


def Conv(
    X: Tensor,
    W: Tensor,
    B=None,
    auto_pad="NOTSET",
    dilations=1,
    group=1,
    kernel_shape=None,
    pads=None,
    strides=1,
):
    if auto_pad != "NOTSET":
        padding = _auto_pad(X, auto_pad, strides, kernel_shape, dilations)
    else:
        padding = (
            [
                p
                for ps in zip(
                    pads[: len(pads) // 2][::-1], pads[len(pads) // 2 :][::-1]
                )
                for p in ps
            ]
            if pads is not None
            else 0
        )  # reorder padding
    return X.conv2d(
        W, B, stride=strides, groups=group, dilation=dilations, padding=padding
    )
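
# Padding reorder example (illustrative): ONNX 2D pads [top, left, bottom, right] = [1, 2, 3, 4]
# become [2, 4, 1, 3], i.e. a flat (left, right, top, bottom) list with the last spatial axis
# first, which is the order the conv2d padding argument is assumed to expect here.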


def ConvTranspose(
    X: Tensor,
    W: Tensor,
    B=None,
    auto_pad="NOTSET",
    dilations=1,
    group=1,
    kernel_shape=None,
    pads=None,
    output_shape=None,
    output_padding=0,
    strides=1,
):
    if kernel_shape is None:
        kernel_shape = W.shape[2:]
    if isinstance(strides, int):
        strides = [strides] * (W.ndim - 2)
    if isinstance(dilations, int):
        dilations = [dilations] * (W.ndim - 2)
    if isinstance(output_padding, int):
        output_padding = [output_padding] * (W.ndim - 2)
    out_sh = (
        [
            st * (xs - 1) + (ks - 1) * di + 1
            if n < 2
            else st * (xs - 1) + (ks - 1) * di + 1 - pads[n - 2] - pads[n - 1]
            for n, (st, xs, ks, di) in enumerate(
                zip(strides, X.shape[2:], kernel_shape, dilations)
            )
        ]
        if output_shape is not None or auto_pad != "NOTSET"
        else []
    )
    if pads is None:
        if output_shape is None:
            output_shape = [xs * st for xs, st in zip(X.shape[2:], strides)]
        if auto_pad == "NOTSET":
            pads = [0, 0] * (X.ndim - 2)
        else:
            total_padding = [
                st * (ish - 1) + pad + ((ks - 1) * dil + 1) - osh
                for st, ish, pad, ks, dil, osh in zip(
                    strides,
                    X.shape[2:],
                    output_padding,
                    kernel_shape,
                    dilations,
                    output_shape,
                )
            ]
            pad_shape = flatten([[sh // 2, sh - sh // 2] for sh in total_padding])
            pads = (
                pad_shape[::2] + pad_shape[1::2]
                if auto_pad == "SAME_UPPER"
                else pad_shape[1::2] + pad_shape[::2]
            )
    else:
        if output_shape is None:
            output_shape = [
                st * (xs - 1) + (ks - 1) * di + 1
                if n < 2
                else st * (xs - 1) + (ks - 1) * di + 1 - pads[n - 2] - pads[n - 1]
                for n, (st, xs, ks, di) in enumerate(
                    zip(strides, X.shape[2:], kernel_shape, dilations)
                )
            ]
    if out_sh:
        output_padding = [os - rs for os, rs in zip(output_shape, out_sh)]
    return X.conv_transpose2d(
        W,
        B,
        stride=strides,
        groups=group,
        dilation=dilations,
        padding=pads if pads is not None else 0,
        output_padding=output_padding,
    )


# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
    if isinstance(ratio, Tensor) and not ratio.shape:
        ratio = safe_numpy(ratio)  # ratio and training_mode may be passed in as Tensors with shape ()
    if isinstance(training_mode, Tensor) and not training_mode.shape:
        training_mode = safe_numpy(training_mode)
    if not training_mode:
        return data, Tensor.ones(
            *data.shape, dtype=dtypes.bool
        )  # if mask is requested as output it will contain all True's.
    rng = np.random.RandomState(seed)
    ratio = ratio.lazydata.realize().toCPU()[0] if isinstance(ratio, Tensor) else ratio
    mask = Tensor(
        (rng.random(data.shape) >= ratio), requires_grad=False, device=data.device
    )
    return data * mask * (1 / (1.0 - ratio)), mask


def LRN(input: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0):
    bs, c, iy, ix = input.shape
    return input / input.mul(input).reshape(bs, 1, c, iy * ix).pad2d(
        (0, 0, (size - 1) // 2, size // 2)
    ).avg_pool2d((size, 1), 1).reshape(bs, c, iy, ix).mul(alpha).add(bias).pow(beta)
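
# NOTE: this matches the ONNX LRN formula out = x / (bias + (alpha / size) * sum(x^2 over the
# local channel window)) ** beta; the avg_pool2d over (size, 1) supplies the 1/size factor
# before alpha is applied.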


def MeanVarianceNormalization(input: Tensor, axis=(0, 2, 3)):
    data_mean = input.mean(axis=axis, keepdim=True)
    std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt()
    return (input - data_mean) / (std + 1e-9)


def NegativeLogLikelihoodLoss(
    input: Tensor, target: Tensor, weight=None, ignore_index=None, reduction="mean"
):
    target = target.cast(dtypes.float32)
    N, C, i_shape = input.shape[0], input.shape[1], input.shape
    t_shape = target.shape
    if len(input.shape) != 3:
        input = input.reshape((N, C, -1))
        target = target.reshape((N, -1))
    if weight is not None:
        mask = target.unsqueeze(-1) == Tensor.arange(C).repeat((N, 1, 1))
        weight = (mask * weight).sum(axis=-1)
    if ignore_index is not None:
        cond = target == ignore_index
        weight = (
            cond.where(0, weight)
            if weight is not None
            else cond.where(Tensor.zeros(*target.shape), 1)
        )
    mask = target[:, None, :] == Tensor.arange(C).reshape(
        [1, C] + [1] * (len(input.shape) - 2)
    )
    loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight)
    if reduction == "mean":
        return loss.mean() if weight is None else loss.sum() / weight.sum()
    elif reduction == "sum":
        return loss.sum()
    return loss.reshape(t_shape) if len(i_shape) != 3 else loss


def SoftmaxCrossEntropyLoss(
    scores: Tensor, labels: Tensor, weights=None, ignore_index=None, reduction="mean"
):
    N, C, *s_dimensions = scores.shape
    if ignore_index is not None:
        labels = (labels == ignore_index).where(C + 1, labels)
    mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(
        1, C, *[1] * len(s_dimensions)
    )
    y = scores.log_softmax(axis=1)
    if weights is not None:
        weights = weights.__getitem__(
            tuple([labels, *[slice(None)] * (weights.ndim - 1)])
        )
    loss = (mask * -y).sum(1) if weights is None else (mask * -y).sum(1) * weights
    if reduction == "mean":
        loss = (
            loss.sum() / (loss == 0).where(0, 1).sum()
            if weights is None
            else loss.sum() / weights.sum()
        )
    elif reduction == "sum":
        loss = loss.sum()
    return loss, y


def ArrayFeatureExtractor(input: Tensor, indices: Tensor):
    return input.__getitem__(
        tuple(
            [
                slice(None) if i != (input.ndim - 1) else indices
                for i in range(input.ndim)
            ]
        )
    )


def Gather(input: Tensor, indices: Tensor, axis=0):
    if (
        indices.numel() < 9
    ):  # NOTE: fewer kernels for smaller indices, but the kernel count grows with the size of indices
        input_sh = list(input.shape)
        ret_shape = input_sh[:axis] + list(indices.shape) + input_sh[axis + 1 :]
        if indices.ndim > 1:
            indices = indices.flatten()
        indices = (
            [int(safe_numpy(indices))]
            if indices.shape == ()
            else [
                input_sh[axis] + int(x) if x < 0 else int(x)
                for x in safe_numpy(indices)
            ]
        )
        args = [
            [(0, x) if j != axis else (i, i + 1) for j, x in enumerate(input_sh)]
            for i in indices
        ]
        return (
            input.shrink(arg=tuple(args[0]))
            .cat(*[input.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis)
            .reshape(ret_shape)
        )
    else:  # NOTE: faster gather with a fixed number of kernels, but exceeds the kernel limit for openpilot
        return input.__getitem__(
            tuple([slice(None) if i != axis else indices for i in range(input.ndim)])
        )
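
# Shrink/cat path example (illustrative): for input.shape == (4, 3), axis=0, indices == [2, 0],
# the shrink args are ((2, 3), (0, 3)) and ((0, 1), (0, 3)); the two row slices are concatenated
# along axis 0 and reshaped to (2, 3).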


def GatherElements(input: Tensor, indices: Tensor, axis):
    indices = (
        indices.sign().contiguous().__neg__().contiguous().relu() * input.shape[axis]
        + indices
    )
    return input.gather(indices, axis)


def _round(x: Tensor, n: float, equidistant_case="round_down") -> Tensor:
    def _and(cond1, cond2):
        return ((cond1 + cond2) == 2).where(1, 0)

    assert n <= 1, f"n:{n} shouldn't be larger than 1"
    b = x.cast(dtypes.int32).contiguous().cast(x.dtype)
    b = (b >= 0).where(b + n, b - n)
    if equidistant_case == "round_down":
        return (x > b).where(b + 1 - n, b - n)
    elif equidistant_case == "round_up":
        return (x >= b).where(b + 1 - n, b - n)
    elif equidistant_case == "round_to_even":
        x_ceil_fraction = x.ceil() / 2
        cond_ceil_even = x_ceil_fraction.ceil() == x_ceil_fraction
        x = (_and(x == b, cond_ceil_even)).where(x + 1 - n, x)
        x = (x > b).where(b + 1 - n, b - n)
        return x
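
# NOTE: _round rounds to the nearest integer with the tie at distance n (0.5 for ordinary
# rounding); equidistant_case picks what happens to an exact tie, e.g. 2.5 rounds to 2 under
# "round_down", 3 under "round_up", and 2 under "round_to_even" (illustrative values).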


def Round(X: Tensor):
    return _round(X, 0.5, "round_to_even")


# TODO clean this up, it's taking the longest in CI
def Resize(
    X: Tensor,
    roi=None,
    scales=None,
    sizes=None,
    antialias=0,
    axes=None,
    coordinate_transformation_mode="half_pixel",
    cubic_coeff_a=-0.75,
    exclude_outside=0,
    extrapolation_value=0.0,
    keep_aspect_ratio_policy="stretch",
    mode="nearest",
    nearest_mode="round_prefer_floor",
):
    def _nearest_gather(X: Tensor, x_out, y_out):
        return X[:, :, y_out, :][:, :, :, x_out]

    def _nearest_mode(x_resized: Tensor, nearest_mode: str, x_len):
        if nearest_mode == "round_prefer_floor":
            ret = _round(x_resized, 0.5, "round_down")
        elif nearest_mode == "round_prefer_ceil":
            ret = _round(x_resized, 0.5, "round_up")
        elif nearest_mode == "floor":
            ret = x_resized.floor()
        elif nearest_mode == "ceil":
            ret = x_resized.ceil()
        return ret.clip(0, x_len - 1)

    def _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi=None):
        if coordinate_transformation_mode == "half_pixel":
            x_out = (x_out + 0.5) / Tensor(
                scales_lol[-1]
            ) - 0.5  # TODO Tensor() because try (((Tensor([0,1,2,3,4,5])+0.5)/3.5 - 0.5)) with LLVM or METAL, inaccuracy.
            y_out = (y_out + 0.5) / Tensor(scales_lol[-2]) - 0.5
        elif coordinate_transformation_mode == "align_corners":
            x_out = x_out * (X.shape[-1] - 1) / (output_shape[-1] - 1)
            y_out = y_out * (X.shape[-2] - 1) / (output_shape[-2] - 1)
        elif coordinate_transformation_mode == "asymmetric":
            x_out = x_out / scales_lol[-1]
            y_out = y_out / scales_lol[-2]
        elif coordinate_transformation_mode == "half_pixel_symmetric":
            x_out = (
                X.shape[-1] / 2 * (1 - int(output_shape[-1]) / output_shape[-1])
                + (x_out + 0.5) / scales_lol[-1]
                - 0.5
            )
            y_out = (
                X.shape[-2] / 2 * (1 - int(output_shape[-2]) / output_shape[-2])
                + (y_out + 0.5) / scales_lol[-2]
                - 0.5
            )
        elif coordinate_transformation_mode == "pytorch_half_pixel":
            x_out = (
                (x_out + 0.5) / scales_lol[-1] - 0.5
                if output_shape[-1] > 1
                else Tensor([0])
            )
            y_out = (
                (y_out + 0.5) / scales_lol[-2] - 0.5
                if output_shape[-2] > 1
                else Tensor([0])
            )
        elif coordinate_transformation_mode == "tf_crop_and_resize":
            x_out = (
                roi[-1][0] * (X.shape[-1] - 1)
                + x_out
                * (
                    (roi[-1][1] - roi[-1][0])
                    * (X.shape[-1] - 1)
                    / (output_shape[-1] - 1)
                )
                if output_shape[-1] > 1
                else Tensor([0.5 * (roi[-1][0] + roi[-1][1]) * (X.shape[-1] - 1)])
            )
            y_out = (
                roi[-2][0] * (X.shape[-2] - 1)
                + y_out
                * (
                    (roi[-2][1] - roi[-2][0])
                    * (X.shape[-2] - 1)
                    / (output_shape[-2] - 1)
                )
                if output_shape[-2] > 1
                else Tensor([0.5 * (roi[-2][0] + roi[-2][1]) * (X.shape[-2] - 1)])
            )
        return x_out.clip(0, X.shape[-1] - 1), y_out.clip(0, X.shape[-2] - 1)

    if roi is not None:
        roi = safe_numpy(roi)
        roi = [(st, ed) for st, ed in zip(roi[: len(roi) // 2], roi[len(roi) // 2 :])]
        roi_ = [(1, 1)] * 4
        if axes is not None:
            for a, r in zip(axes, roi):
                roi_[a] = r
            roi = roi_
    if scales is not None:
        scales = safe_numpy(scales).tolist()
        if axes is not None:
            scales_ = [1] * X.ndim
            for a, s in zip(axes, scales):
                scales_[a] = s
            scales = scales_
    elif sizes is not None:
        sizes = [int(i) for i in safe_numpy(sizes)]
        scales = []
        if axes is not None:
            sizes_ = [1] * X.ndim
            for a, s in zip(axes, sizes):
                sizes_[a] = s
                scales.append(s / X.shape[a])
            sizes = sizes_
        else:
            scales = [si / xs for xs, si in zip(X.shape, sizes)]
        if keep_aspect_ratio_policy == "not_larger":
            scale = min(scales)
            sizes = _round(Tensor(list(X.shape[-2:])) * scale, 0.5, "round_up")
            sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
        elif keep_aspect_ratio_policy == "not_smaller":
            scale = max(scales)
            sizes = _round(Tensor(list(X.shape[-2:])) * scale, 0.5, "round_up")
            sizes = list(X.shape[:-2]) + [int(i) for i in safe_numpy(sizes)]
    output_shape = (
        sizes if sizes else [math.floor(x * s) for x, s in zip(X.shape, scales)]
    )
    output_shape_ = sizes if sizes else [x * s for x, s in zip(X.shape, scales)]
    scales_lol = [os / xs for xs, os in zip(X.shape, output_shape)]
    x_out = Tensor.arange(output_shape[-1])
    y_out = Tensor.arange(output_shape[-2])
    if mode == "nearest":
        x_out, y_out = _coordinate_transformation(
            x_out, y_out, output_shape, scales_lol, roi
        )
        x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
        y_out = _nearest_mode(y_out, nearest_mode, X.shape[-2])
        return _nearest_gather(X, x_out, y_out)
    elif mode == "linear":
        x_out, y_out = _coordinate_transformation(
            x_out, y_out, output_shape_, scales, roi
        )
        ret = []
        for y in safe_numpy(y_out):
            for x in safe_numpy(x_out):
                x_floor, y_floor = int(x), int(y)
                y_shrink = (
                    (0, X.shape[2])
                    if X.shape[2] == 1
                    else (y_floor, y_floor + 2)
                    if y != y_floor
                    else (y_floor, y_floor + 1)
                )
                x_shrink = (
                    (x_floor, x_floor + 2) if x != x_floor else (x_floor, x_floor + 1)
                )
                shrink_args = ((0, X.shape[0]), (0, X.shape[1]), y_shrink, x_shrink)
                corners = safe_numpy(X.shrink(shrink_args))
                x1, x2, y1, y2 = x_floor, x_floor + 1, y_floor, y_floor + 1
                if (
                    x == x_floor and y == y_floor
                ):  # TODO https://en.wikipedia.org/wiki/Bilinear_interpolation#Weighted_mean maybe do weighted mean?
                    ret.append(corners[0, 0, 0, 0])
                elif x == x_floor:
                    ret.append(
                        (
                            corners[0, 0, 0, 0] * (y2 - y)
                            + corners[0, 0, 1, 0] * (y - y1)
                        )
                        / (y2 - y1)
                    )
                elif y == y_floor:
                    ret.append(
                        (
                            corners[0, 0, 0, 0] * (x2 - x)
                            + corners[0, 0, 0, 1] * (x - x1)
                        )
                        / (x2 - x1)
                    )
                else:
                    ret.append(
                        (
                            corners[0, 0, 0, 0] * (x2 - x) * (y2 - y)
                            + corners[0, 0, 0, 1] * (x - x1) * (y2 - y)
                            + corners[0, 0, 1, 0] * (x2 - x) * (y - y1)
                            + corners[0, 0, 1, 1] * (x - x1) * (y - y1)
                        )
                        / ((x2 - x1) * (y2 - y1))
                    )
        return Tensor(ret).reshape(output_shape)
    elif mode == "cubic":
        raise Exception("cubic interpolation is not implemented")


def CenterCropPad(input: Tensor, shape: Tensor, axes=None):
    if not axes:
        axes = list(range(input.ndim))
    shrink_arg = [(0, i) for i in input.shape]
    pad_arg = [(0, 0) for _ in range(input.ndim)]
    shape = safe_numpy(shape).tolist()
    for s, x in zip(shape, axes):
        if s < input.shape[x]:
            shrink_arg[x] = (
                (input.shape[x] // 2 - s // 2, input.shape[x] // 2 + s // 2)
                if s % 2 == 0
                else (input.shape[x] // 2 - s // 2 - 1, input.shape[x] // 2 + s // 2)
            )
        elif s > input.shape[x]:
            pad_arg[x] = (
                ((s - input.shape[x]) // 2, (s - input.shape[x]) // 2)
                if (s - input.shape[x]) % 2 == 0
                else ((s - input.shape[x]) // 2, (s - input.shape[x]) // 2 + 1)
            )
    return input.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))


def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
    depth = int(safe_numpy(depth).item())
    indices, rank = (indices < 0).where(indices + depth, indices), len(indices.shape)
    if axis < 0:
        axis += rank + 1
    ls, rs = indices.shape[0:axis], indices.shape[axis:rank]
    cond = indices[:, None] == Tensor.arange(depth).reshape(
        (1,) * len(ls) + (depth,) + (1,) * len(rs)
    )
    return cond.where(values[1], values[0]).cast(values.dtype)
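
# Example (illustrative): OneHot(Tensor([1, 3]), depth=Tensor(4), values=Tensor([0.0, 1.0]))
# compares each index against arange(4) and yields shape (2, 4):
# [[0, 1, 0, 0], [0, 0, 0, 1]].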


def Erf(x: Tensor):
    sign = x.sign()
    x = x.abs()
    t = 1.0 / (1.0 + 0.3275911 * x)
    term1 = 0.254829592 * t
    term2 = -0.284496736 * t**2
    term3 = 1.421413741 * t**3
    term4 = -1.453152027 * t**4
    term5 = 1.061405429 * t**5
    y = term1 + term2 + term3 + term4 + term5
    return sign * (1.0 - y * Tensor.exp(-x * x))
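
# NOTE: this is the classic Abramowitz & Stegun 7.1.26 rational approximation of erf
# (p = 0.3275911, a1..a5 as above), with absolute error on the order of 1e-7; it is an
# approximation, not an exact erf.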


def Compress(inp: Tensor, condition: Tensor, axis=None):
    if axis is None:
        inp = inp.flatten()
        axis = 0

    axis = axis + inp.ndim if axis < 0 else axis

    con_np = safe_numpy(condition)
    con = Tensor(np.arange(condition.shape[0])[con_np])  # no boolean indexing in Tensor
    return inp.__getitem__(
        tuple([slice(None) if i != axis else con for i in range(inp.ndim)])
    )


type_map = {TensorProto.DOUBLE: dtypes.double, TensorProto.FLOAT: dtypes.float32}


def EyeLike(x: Tensor, dtype=None, k=0):
    if dtype is None:
        dtype = x.dtype
    else:
        dtype = type_map[dtype]
    shape = x.shape
    dim = min(x.shape)
    if shape[0] == shape[1]:
        return Tensor.eye(dim=dim, dtype=dtype)
    else:
        diff = (shape[0] - dim, shape[1] - dim)
        padarg = tuple([(d, d) if d == 0 else (k, d - k) for d in diff])
        return Tensor.eye(dim=dim, dtype=dtype).pad(padarg)


def Upsample(X, scales, mode):
    return Resize(X=X, scales=scales, mode=mode)


# Needs work
def IsInf(x: Tensor, detect_negative=1, detect_positive=1):
    ret = (
        (x == float("inf")) * detect_positive
        + (x == float("-inf")) * detect_negative
        + Tensor.zeros(*x.shape)
    )
    return ret.cast(dtypes.bool)


def DequantizeLinear(
    x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int] = 0, axis=1
):
    axis = axis + x.ndim if axis < 0 else axis
    x = x.cast(dtypes.float)
    if x_zero_point.__class__ is Tensor:
        x_zero_point.cast(dtypes.float)
    x_sc = x_scale.reshape(
        *[1] * axis, *x_scale.shape, *[1] * (x.ndim - axis - x_scale.ndim)
    )
    x_zer = (
        x_zero_point.reshape(
            *[1] * axis, *x_scale.shape, *[1] * (x.ndim - axis - x_scale.ndim)
        )
        if isinstance(x_zero_point, Tensor)
        else x_zero_point
    )
    return ((x - x_zer) * x_sc).cast(x_scale.dtype)
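
# NOTE: DequantizeLinear implements y = (x - x_zero_point) * x_scale, with a 1-D scale and
# zero_point broadcast along `axis`; the result is cast back to x_scale's dtype.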


# Needs work
def IsNaN(x: Tensor):
    return (x < float("-inf")).cast(dtypes.bool)


# copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py
# without importing PIL we'll have to manually decode a bunch of image formats like PNG, JPEG, WebP, etc
def ImageDecoder(encoded_stream: Tensor, pixel_format="RGB"):
    try:
        import PIL.Image
    except ImportError as e:
        raise ImportError(
            "Pillow must be installed to use the reference implementation of the ImageDecoder operator"
        ) from e
    img = PIL.Image.open(io.BytesIO(safe_numpy(encoded_stream).tobytes()))
    if pixel_format == "BGR":
        return Tensor(np.array(img))[:, :, ::-1]
    elif pixel_format == "RGB":
        return Tensor(np.array(img))
    elif pixel_format == "Grayscale":
        img = img.convert("L")
        decoded = Tensor(np.array(img))
        return decoded.unsqueeze(-1)  # (H, W) to (H, W, 1)
    else:
        raise ValueError(f"pixel_format={pixel_format!r} is not supported.")


def AffineGrid(theta: Tensor, size: Tensor, align_corners=0):
    _, _, *data_sz = safe_numpy(size).tolist()
    size_zeros, original_grid = Tensor.zeros(data_sz), Tensor.ones(data_sz)
    stackable = [original_grid]
    for dim, dim_sz in enumerate(data_sz):
        a = (
            Tensor.arange(-1, 1.0001, 2 / (dim_sz - 1))
            if align_corners == 1
            else Tensor.arange(-1 + 1 / dim_sz, 1, 2 / dim_sz)
        )
        if dim == 0:
            stackable = [
                a.reshape(dim_sz, *[1] * (len(data_sz) - 1)) + size_zeros,
                *stackable,
            ]
        elif dim == 1:
            stackable = [
                a.reshape(1, dim_sz, *[1] * (len(data_sz) - 2)) + size_zeros,
                *stackable,
            ]
        else:
            stackable = [a.reshape(1, dim_sz) + size_zeros, *stackable]
    original_grid = Tensor.stack(stackable, dim=len(data_sz))
    if original_grid.ndim == 3:
        N, dim_2d, dim_homo = theta.shape
        assert dim_2d == 2 and dim_homo == 3
        H, W, dim_homo = original_grid.shape
        assert dim_homo == 3
        original_grid = original_grid.reshape(H * W, dim_homo).transpose()
        return theta.matmul(original_grid).permute(0, 2, 1).reshape(N, H, W, dim_2d)
    else:
        assert original_grid.ndim == 4
        N, dim_3d, dim_homo = theta.shape
        assert dim_3d == 3 and dim_homo == 4
        D, H, W, dim_homo = original_grid.shape
        assert dim_homo == 4
        original_grid = original_grid.reshape(D * H * W, dim_homo).transpose()
        return theta.matmul(original_grid).permute(0, 2, 1).reshape(N, D, H, W, dim_3d)


# **************** com.microsoft Ops ****************


def SkipLayerNormalization(
    input: Tensor,
    skip: Tensor,
    gamma,
    beta: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    epsilon=None,
):
    if epsilon is None:
        epsilon = 1e-12
    x = input + skip + bias
    return x.layernorm(eps=epsilon) * gamma + beta, None, None, x


def FastGelu(x: Tensor, bias: Optional[Tensor] = None):
    x = x + bias
    return 0.5 * x * (1 + (x * 0.797885 + 0.035677 * x**3).tanh())


def EmbedLayerNormalization(
    input_ids: Tensor,
    segment_ids: Optional[Tensor] = None,
    word_embedding: Tensor = None,
    position_embedding: Tensor = None,
    segment_embedding: Optional[Tensor] = None,
    gamma=None,
    beta=None,
    mask: Optional[Tensor] = None,
    position_ids: Optional[Tensor] = None,
    epsilon=None,
    mask_index_type=None,
):
    # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.EmbedLayerNormalization
    assert (segment_ids is None) is (segment_embedding is None)
    assert (mask is None) is (mask_index_type is None)
    assert mask is None, "functionality not supported yet"  # TODO
    input_shape = input_ids.shape
    bsz, seq_length = input_shape[0], input_shape[1]
    compute_seg_emb = segment_embedding is not None and segment_ids is not None
    vocab_size, max_position_embeddings, type_vocab_size = (
        word_embedding.shape[0],
        position_embedding.shape[0],
        (segment_embedding.shape[0] if compute_seg_emb else None),
    )

    def embedding(
        x: Tensor, vocab_size, weight: Tensor
    ) -> Tensor:  # TODO from nn.Embedding. Could probably upstream this to Tensor
        vocab_counter = (
            Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False)
            .reshape(1, 1, vocab_size)
            .expand(*x.shape, vocab_size)
        )
        return (vocab_counter == x.unsqueeze(2).expand(*x.shape, vocab_size)) @ weight

    # bert embedding layer
    if epsilon is None:
        epsilon = 1e-12
    if position_ids is None:
        position_ids = (
            Tensor.arange(seq_length, requires_grad=False)
            .unsqueeze(0)
            .expand(*input_shape)
        )
    wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding)
    pos_embedding_res = embedding(
        position_ids, max_position_embeddings, position_embedding
    )
    seg_embedding_res = (
        embedding(segment_ids, type_vocab_size, segment_embedding)
        if compute_seg_emb
        else None
    )

    embedding_sum = wrd_embedding_res + pos_embedding_res + seg_embedding_res
    out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
    return out, None, embedding_sum


def Attention(
    input: Tensor,
    weights,
    bias: Optional[Tensor] = None,
    mask_index: Optional[Tensor] = None,
    past: Optional[Tensor] = None,
    relative_position_bias: Optional[Tensor] = None,
    past_sequence_length: Optional[Tensor] = None,
    do_rotary=None,
    mask_filter_value=None,
    num_heads=None,
    past_present_share_buffer=None,
    qkv_hidden_sizes=None,
    scale=None,
    unidirectional=None,
):
    # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
    assert num_heads is not None  # required
    assert (qkv_hidden_sizes is None and past is not None) or (
        qkv_hidden_sizes is not None
    )
    assert (
        relative_position_bias
        == do_rotary
        == past_sequence_length
        == mask_filter_value
        == past_present_share_buffer
        == scale
        == None
    ), "functionality not supported yet"  # TODO strange params
    hidden_size, v_hidden_size = (
        qkv_hidden_sizes[1:]
        if qkv_hidden_sizes is not None
        else 2 * (weights.shape[1] // 3,)
    )

    if unidirectional:  # gpt-style
        assert hidden_size == v_hidden_size
        xqkv = input.linear(weights, bias)
        xq, xk, xv = [
            xqkv.slice([None, None, (i * hidden_size, (i + 1) * hidden_size)])
            for i in range(3)
        ]
    else:  # bert-style
        wq, wk, wv = (
            weights[:, :hidden_size],
            weights[:, hidden_size : hidden_size + v_hidden_size],
            weights[:, hidden_size + v_hidden_size :],
        )
        bq, bk, bv = (
            (
                bias[:hidden_size],
                bias[hidden_size : hidden_size + v_hidden_size],
                bias[hidden_size + v_hidden_size :],
            )
            if bias is not None
            else None
        )
        xq, xk, xv = [input.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
    xq, xk, xv = [
        x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2)
        for x in (xq, xk, xv)
    ]

    if past is not None:
        xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
    present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))

    def attn(query, key, value, attn_mask):
        query_length, key_length = query.shape[-2], key.shape[-2]
        cdim = max(query_length, key_length) + 1
        attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
        # This is where Tensor.scaled_dot_product_attention differs:
        causal_mask = (
            Tensor.ones((cdim, cdim), requires_grad=False)
            .cast(dtypes.bool)
            .tril(0)[key_length - query_length : key_length, :key_length]
            .cast(dtypes.bool)
        )
        return (
            Tensor.where(causal_mask, attn_weights, -float("inf")) + attn_mask
        ).softmax(-1) @ value

    bsz, _, seq_len, _ = xq.shape
    out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
    return out, present


# **************** ai.onnx.preview.training Ops ****************


# TODO not entirely sure these optimizers are correct
def Adagrad(R, T, *inputs, decay_factor=0.0, epsilon=0.0, norm_coefficient=0.0):
    groups = len(inputs) // 3
    grouped_inputs = [inputs[i::groups] for i in range(groups)]
    T, R = safe_numpy(T)[0], safe_numpy(R)[0]
    r = R / (1 + T * decay_factor)
    ret = []
    for input in grouped_inputs:
        X, G, H = input
        X.grad = norm_coefficient * X + G
        X.grad.requires_grad, H.requires_grad = (
            False,
            False,
        )  # TODO manually turning off requires_grad, see TODO under (domain == "ai.onnx.preview.training") in onnx.py
        H.assign(H.detach() + X.grad * X.grad).realize()
        H_adaptive = H.sqrt() + epsilon
        X.assign(X.detach() - r * X.grad / H_adaptive)
        ret.extend([X, H])
    ret = ret[::2] + ret[1::2]
    return tuple(ret)


def Momentum(R, T, *inputs, alpha, beta, mode, norm_coefficient):
    groups = len(inputs) // 3
    grouped_inputs = [inputs[i::groups] for i in range(groups)]
    T, R = safe_numpy(T)[0], safe_numpy(R)[0]
    beta_adjusted = beta if T > 0 else 1
    ret = []
    for input in grouped_inputs:
        X, G, V = input
        X.grad = (norm_coefficient * X + G).realize()
        X.grad.requires_grad, V.requires_grad = False, False
        V.assign(alpha * V + beta_adjusted * X.grad).realize()
        if mode == "standard":
            X.assign(X.detach() - R * V).realize()
        elif mode == "nesterov":
            # ONNX Momentum spec: X_new = X - R * (G_regularized + alpha * V_new)
            X.assign(X.detach() - R * (X.grad + alpha * V)).realize()
        ret.extend([X, V])
    ret = ret[::2] + ret[1::2]
    return tuple(ret)


# copied from tinygrad/nn/optim.py: LAMB with some edits
def Adam(
    R,
    T,
    *inputs,
    alpha=0.9,
    beta=0.999,
    epsilon=0.0,
    norm_coefficient=0.0,
    norm_coefficient_post=0.0,
):
    groups = len(inputs) // 4
    grouped_inputs = [inputs[i::groups] for i in range(groups)]
    T, R = safe_numpy(T)[0], safe_numpy(R)[0]
    ret = []
    for input in grouped_inputs:
        X, G, V, H = input
        X.grad = (norm_coefficient * X + G).realize()
        V.requires_grad, H.requires_grad, X.grad.requires_grad = False, False, False
        V.assign(alpha * V + (1.0 - alpha) * X.grad).realize()
        H.assign(beta * H + (1.0 - beta) * (X.grad * X.grad)).realize()
        up = (
            (V / (1.0 - alpha**T)) / ((H / (1.0 - beta**T)).sqrt() + epsilon)
            if T > 0
            else V / (H.sqrt() + epsilon)
        )
        X.assign(X.detach() - R * up).realize()
        X = (1 - norm_coefficient_post) * X
        ret.extend([X, V, H])
    ret = ret[::3] + ret[1::3] + ret[2::3]
    return tuple(ret)