1
0
Fork 0
tinygrab/tinygrad/tensor.py

3378 lines
125 KiB
Python

# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math
from typing import (
List,
Tuple,
Callable,
Optional,
ClassVar,
Type,
Union,
Sequence,
Any,
Iterable,
Set,
)
from collections import defaultdict
from functools import partialmethod, reduce
from itertools import accumulate
import numpy as np
from tinygrad.helpers import (
ImageDType,
argfix,
make_pair,
getenv,
IMAGE,
DEBUG,
flatten,
DType,
dtypes,
prod,
all_int,
round_up,
)
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import LoadOps
from tinygrad.device import Device, Buffer
from tinygrad.shape.symbolic import sint
from tinygrad.realize import run_schedule
class Function:
    """
    Base class for one differentiable operation in the autograd graph.

    Subclasses override ``forward`` (and ``backward`` if gradients are needed);
    callers go through the ``apply`` classmethod rather than instantiating directly.

    Attributes:
        device (str): Device of the first input tensor.
        needs_input_grad (List[Optional[bool]]): ``requires_grad`` of every input, in order.
        requires_grad (Optional[bool]): True if any input requires grad; otherwise None if
            any input is undecided (None); otherwise False.
        parents (Tuple[Tensor, ...]): The input tensors. Only set when ``requires_grad`` is True.
    """

    def __init__(self, device: str, *tensors: Tensor):
        self.device = device
        flags = [t.requires_grad for t in tensors]
        self.needs_input_grad = flags
        if any(flags):
            self.requires_grad = True
        elif None in flags:
            self.requires_grad = None
        else:
            self.requires_grad = False
        # only hold references to the inputs when a backward pass may need them
        if self.requires_grad:
            self.parents = tensors

    def forward(self, *args, **kwargs):
        """
        Compute the output buffer from the input buffers. Must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(f"forward not implemented for {type(self)}")

    def backward(self, *args, **kwargs):
        """
        Compute the input gradients from the output gradient. Must be overridden.

        Raises:
            RuntimeError: Always, in this base class.
        """
        raise RuntimeError(f"backward not implemented for {type(self)}")

    @classmethod
    def apply(fxn: Type[Function], *x: Tensor, **kwargs) -> Tensor:
        """
        Instantiate this function for ``x``, run ``forward`` and wrap the result.

        Args:
            fxn: The Function subclass (bound automatically by ``classmethod``).
            x: Input tensors; the first one decides the device of the context.
            kwargs: Extra keyword arguments forwarded to ``forward``.
        Returns:
            Tensor: The output tensor. Its ``_ctx`` is set to the function instance only
            when gradients are required and ``Tensor.no_grad`` is False.
        """
        ctx = fxn(x[0].device, *x)
        out_buffer = ctx.forward(*[t.lazydata for t in x], **kwargs)
        ret = Tensor(out_buffer, device=ctx.device, requires_grad=ctx.requires_grad)
        if ctx.requires_grad and not Tensor.no_grad:
            ret._ctx = ctx  # used by autograd engine
        return ret
import tinygrad.mlops as mlops
# **** start with two base classes, Tensor and Function ****
class Tensor:
"""
This class represents a tensor, which is the fundamental unit of data in tinygrad.
It can be used for various mathematical operations and machine learning applications.
Attributes:
__slots__ (str): List of attributes that are slotted for this class.
__deletable__ (tuple): Tuple of attributes that can be deleted.
training (ClassVar[bool]): Class variable to track if the tensor is in training mode or not.
no_grad (ClassVar[bool]): Class variable to track if gradient computation is disabled or not.
default_type (ClassVar[DType]): Default data type for tensors.
"""
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
__deletable__ = ("_ctx",)
training: ClassVar[bool] = False
class train:
    """
    Context manager that sets ``Tensor.training`` to ``val`` while active.

    The previous value is saved on enter and restored on exit. ``__enter__``
    returns None, so ``with Tensor.train() as x`` binds ``x`` to None.
    """

    def __init__(self, val=True):
        self.val = val

    def __enter__(self):
        previous = Tensor.training
        self.prev = previous
        Tensor.training = self.val

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        # restore unconditionally, including when the body raised
        Tensor.training = self.prev
no_grad: ClassVar[bool] = False
default_type: ClassVar[DType] = dtypes.float32
def __init__(
    self,
    data: Union[None, int, float, list, LazyBuffer, np.ndarray, bytes],
    device: Optional[str] = None,
    dtype: Optional[DType] = None,
    requires_grad: Optional[bool] = None,
):
    """
    Build a Tensor around a LazyBuffer created from ``data``.

    Conversion rules, by the type of ``data``:
      * LazyBuffer: used as-is; ``dtype`` must be None or equal to its dtype.
      * int / float: a scalar CONST load of ``dtype`` (default ``Tensor.default_type``).
      * None / list: converted with ``np.array`` using ``dtype`` (default
        ``Tensor.default_type``) and loaded via ``LazyBuffer.fromCPU``; None gives an empty array.
      * bytes: viewed as a flat uint8 array (``dtype`` is not applied).
      * np.ndarray: 0-d arrays become a scalar CONST; other arrays are cast to ``dtype``
        (when given) and loaded via ``LazyBuffer.fromCPU``.
    The resulting buffer is moved with ``copy_to_device`` if it is not already on ``device``.

    Args:
        data: Source data, see above.
        device: Target device, canonicalized with ``Device.canonicalize``.
        dtype: Optional data type.
        requires_grad: True, False, or None (undecided).
    Raises:
        AssertionError: On an invalid dtype, a dtype that differs from a LazyBuffer's,
            or a dtype with no numpy equivalent for list/None/ndarray input.
        RuntimeError: If ``data`` is of an unsupported type.
    """
    assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
    device = Device.canonicalize(device)
    # tensors have gradients, buffers do not
    self.grad: Optional[Tensor] = None
    # NOTE: requires_grad has three states. False and None: no gradient, True: gradient.
    # None (the default) will be updated to True if it's put in an optimizer
    self.requires_grad: Optional[bool] = requires_grad
    # internal variable used for autograd graph construction
    self._ctx: Optional[Function] = None
    if isinstance(data, LazyBuffer):
        assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
    elif isinstance(data, (int, float)):
        data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
    elif data is None or data.__class__ is list:
        assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
        np_dtype = (dtype or Tensor.default_type).np
        data = LazyBuffer.fromCPU(np.array([] if data is None else data, dtype=np_dtype))
    elif isinstance(data, bytes):
        data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))
    elif isinstance(data, np.ndarray):
        assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
        if data.shape == ():
            const_dtype = dtype or dtypes.from_np(data.dtype)
            data = LazyBuffer.loadop(LoadOps.CONST, tuple(), const_dtype, device, data.item())
        else:
            if dtype is not None and dtype.np is not None:
                data = data.astype(dtype.np)
            data = LazyBuffer.fromCPU(data)
    # every supported input is a LazyBuffer by now, though possibly on another device
    if not isinstance(data, LazyBuffer):
        raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
    self.lazydata = data if data.device == device else data.copy_to_device(device)
def __repr__(self):
    """
    Return a debug string with the lazy buffer, the device and the gradient's lazy buffer.

    Returns:
        str: ``<Tensor {lazydata!r} on {device} with grad {grad.lazydata!r or None}>``.
    """
    # Test for presence with `is not None`. Plain truthiness would defer to whatever
    # __bool__/__len__ the gradient object defines and could hide a real gradient.
    grad_data = self.grad.lazydata if self.grad is not None else None
    return f"<Tensor {self.lazydata!r} on {self.device} with grad {grad_data!r}>"
# CPython's GC does not move objects, so id() is stable for the object's lifetime
def __hash__(self):
    """
    Hash a Tensor by identity.

    Returns:
        int: ``id(self)``.
    """
    identity = id(self)
    return identity
@property
def device(self) -> str:
    """
    Device on which the tensor's data lives.

    Returns:
        str: ``self.lazydata.device``.
    """
    lazy = self.lazydata
    return lazy.device
@property
def shape(self) -> Tuple[sint, ...]:
    """
    Dimensions of the tensor.

    Returns:
        tuple: ``self.lazydata.shape`` (entries may be symbolic).
    """
    lazy = self.lazydata
    return lazy.shape
@property
def dtype(self) -> DType:
    """
    Data type of the tensor's elements.

    Returns:
        DType: ``self.lazydata.dtype``.
    """
    lazy = self.lazydata
    return lazy.dtype
# ***** data handlers ****
@staticmethod
def corealize(lst: Iterable[Tensor]):
    """
    Realize several tensors with one combined schedule.

    Each tensor's schedule is built with a single shared ``seen`` set, and the
    schedules are concatenated and executed by one ``run_schedule`` call. The
    shared set presumably lets ``schedule`` skip buffers already scheduled for an
    earlier tensor; this is not verified here.

    Args:
        lst: The tensors to realize.
    """
    seen: Set[LazyBuffer] = set()
    combined = []
    for tensor in lst:
        combined.extend(tensor.lazydata.schedule(seen))
    run_schedule(combined)
def realize(self) -> Tensor:
    """
    Build the schedule for this tensor's lazy buffer and run it.

    Returns:
        Tensor: ``self``, so calls can be chained.
    """
    schedule = self.lazydata.schedule()
    run_schedule(schedule)
    return self
def assign(self, x) -> Tensor:
    """
    Replace this tensor's contents with ``x``.

    Non-Tensor values are first wrapped with ``self.dtype``. On a ``DISK`` device the
    value is copied straight into the realized disk buffer. Otherwise ``x`` must match
    ``self`` in shape and device, and must not require grad. ``self.lazydata`` is then
    replaced by ``x.lazydata``. When the dtypes match, ``self`` is already realized and
    the DISALLOW_ASSIGN env var is unset, ``x`` is first told to write its output into
    ``self``'s existing buffer.

    Args:
        x: A Tensor or any value accepted by the Tensor constructor.
    Returns:
        Tensor: ``self``.
    Raises:
        AssertionError: On a shape/device mismatch or if ``x.requires_grad`` is truthy.
    """
    # TODO: this is a hack for writing to DISK. remove with working assign
    if self.device.startswith("DISK"):
        src = x if x.__class__ is Tensor else Tensor(x, device="CPU", dtype=self.dtype)
        target = self.contiguous().realize().lazydata.realized
        target.copyin(src.numpy().data)
        return self
    if x.__class__ is not Tensor:
        x = Tensor(x, device=self.device, dtype=self.dtype)
    assert (
        self.shape == x.shape and self.device == x.device
    ), f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
    assert not x.requires_grad  # self requires_grad is okay?
    if DEBUG >= 4:
        print(f"assign {self.lazydata} <- {x.lazydata}")
    reuse_buffer = (
        self.dtype == x.dtype
        and self.lazydata.realized is not None
        and not getenv("DISALLOW_ASSIGN")
    )
    if reuse_buffer:
        # must happen before the swap below: x's result is redirected into self's buffer
        x.lazydata.output_buffer = self.lazydata.realized
    self.lazydata = x.lazydata
    return self
def detach(self) -> Tensor:
    """
    Return a new leaf Tensor that shares this tensor's lazy buffer.

    The new tensor lives on the same device, has ``requires_grad=False`` and a
    fresh ``_ctx`` of None, so it carries no autograd history.

    Returns:
        Tensor: The detached tensor.
    """
    return Tensor(self.lazydata, requires_grad=False, device=self.device)
def numpy(self) -> np.ndarray:
    """
    Copy the tensor's data into a new NumPy array of the same shape.

    The data is detached, cast to the dtype matching ``self.dtype.np``, made
    contiguous, moved to CPU and realized before being read back.

    Returns:
        np.ndarray: A fresh array with dtype ``self.dtype.np`` and shape ``self.shape``.
    Raises:
        AssertionError: If the shape is symbolic or the dtype has no numpy equivalent.
    """
    assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
    assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
    np_dtype = self.dtype.np
    if 0 in self.shape:
        # nothing to compute for an empty tensor
        return np.zeros(self.shape, dtype=np_dtype)
    on_cpu = self.detach().cast(dtypes.from_np(np_dtype)).contiguous().to("CPU").realize()
    host = on_cpu.lazydata.realized.toCPU()
    return host.astype(np_dtype, copy=True).reshape(self.shape)
def item(self) -> Union[float, int]:
    """
    Return the single element of the tensor as a Python scalar.

    The realized buffer is read directly via ``toCPU()``; unlike ``numpy()`` there
    is no ``contiguous()`` step. NOTE(review): presumably this relies on a
    one-element tensor realizing to a one-element buffer; confirm for views.

    Returns:
        Union[float, int]: The tensor's value.
    Raises:
        AssertionError: If the tensor does not have exactly one element.
    """
    assert self.numel() == 1, "must have one element for item"
    buffer = self.realize().lazydata.realized
    return buffer.toCPU().item()
def to(self, device: Optional[str]) -> Tensor:
    """
    Return a copy of this tensor on ``device`` (or ``self`` if nothing to do).

    If the tensor has a gradient, the gradient is moved along with it.

    Args:
        device: Target device. None, or the current device, returns ``self`` unchanged.
    Returns:
        Tensor: The tensor on the target device.
    """
    if device is None or device == self.device:
        return self
    ret = Tensor(self.lazydata, device)
    # explicit presence check; Tensor truthiness is not a "has a gradient" test
    if self.grad is not None:
        ret.grad = self.grad.to(device)
    return ret
def to_(self, device: Optional[str]):
    """
    Move this tensor, and its gradient if it has one, to ``device`` in place.

    Args:
        device: Target device. None, or the current device, is a no-op.
    Returns:
        None: The tensor is modified in place.
    """
    if device is None or device == self.device:
        return
    if self.grad is not None:
        # to_ works in place and returns None; assigning its result back to
        # self.grad (as before) silently dropped the gradient
        self.grad.to_(device)
    self.lazydata = Tensor(self.lazydata, device).lazydata
# ***** creation llop entrypoint *****
@staticmethod
def _loadop(
    op,
    sz,
    device: Optional[str] = None,
    dtype: Optional[DType] = None,
    arg=None,
    **kwargs,
):
    """
    Create a 1-D Tensor of ``sz`` elements from a LazyBuffer load op.

    Args:
        op: The LoadOps member to issue (e.g. EMPTY, CUSTOM).
        sz (int): Number of elements; must be a concrete int.
        device: Target device, canonicalized for the buffer.
        dtype: Element type; ``Tensor.default_type`` when None.
        arg: Extra argument handed to the load op.
        **kwargs: Forwarded to the Tensor constructor.
    Returns:
        Tensor: A tensor of shape ``(sz,)``.
    Raises:
        AssertionError: If ``sz`` is not an int (e.g. symbolic).
    """
    assert isinstance(sz, int), f"cannot create with symbolic size {sz}"
    buffer_dtype = Tensor.default_type if dtype is None else dtype
    lazy = LazyBuffer.loadop(op, (sz,), buffer_dtype, Device.canonicalize(device), arg)
    return Tensor(lazy, dtype=dtype, device=device, **kwargs)
@staticmethod
def empty(*shape, **kwargs):
    """
    Create an uninitialized tensor (``LoadOps.EMPTY``).

    Args:
        *shape: Dimensions, either as separate ints or a single tuple/list.
        **kwargs: Forwarded to ``Tensor._loadop``.
    Returns:
        Tensor: A tensor of the requested shape with unspecified contents.
    """
    dims = argfix(*shape)
    flat = Tensor._loadop(LoadOps.EMPTY, prod(dims), **kwargs)
    return flat.reshape(dims)
# class-wide seed, initialised from the wall clock at import time
_seed: int = int(time.time())
@staticmethod
def manual_seed(seed=0):
    """
    Overwrite the class-wide ``Tensor._seed``.

    Args:
        seed (int, optional): New seed value. Defaults to 0.
    """
    Tensor._seed = seed
@staticmethod
def rand(*shape, **kwargs):
    """
    Create a tensor of random values produced by the ``custom_random`` load op.

    The distribution comes from ``custom_random`` (defined elsewhere); the
    ``multinomial`` docstring treats it as uniform on [0, 1). TODO: confirm.

    Args:
        *shape: Dimensions, either as separate ints or a single tuple/list.
        **kwargs: Forwarded to ``Tensor._loadop`` (e.g. ``device``, ``dtype``).
    Returns:
        Tensor: A tensor of the requested shape.
    """
    dims = argfix(*shape)
    flat = Tensor._loadop(LoadOps.CUSTOM, prod(dims), arg=custom_random, **kwargs)
    return flat.reshape(dims)
# ***** creation helper functions *****
@staticmethod
def full(shape: Tuple[sint, ...], fill_value, **kwargs):
    """
    Create a tensor of ``shape`` where every element is ``fill_value``.

    A single scalar tensor is reshaped to all-ones dimensions and then expanded,
    so no buffer of the full size is materialized here.

    Args:
        shape: Target shape (normalized with ``argfix``).
        fill_value: Scalar used for every element.
        **kwargs: Forwarded to the Tensor constructor (``dtype``, ``device``, ...).
    Returns:
        Tensor: The filled tensor.
    """
    target = argfix(shape)
    scalar = Tensor(fill_value, **kwargs)
    return scalar.reshape([1] * len(target)).expand(target)
@staticmethod
def zeros(*shape, **kwargs):
    """
    Create a tensor of ``shape`` filled with 0.

    Args:
        *shape: Dimensions, either as separate ints or a single tuple/list.
        **kwargs: Forwarded to ``Tensor.full``.
    Returns:
        Tensor: The zero-filled tensor.
    """
    dims = argfix(*shape)
    return Tensor.full(dims, 0, **kwargs)
@staticmethod
def ones(*shape, **kwargs):
    """
    Create a tensor of ``shape`` filled with 1.

    Args:
        *shape: Dimensions, either as separate ints or a single tuple/list.
        **kwargs: Forwarded to ``Tensor.full``.
    Returns:
        Tensor: The one-filled tensor.
    """
    dims = argfix(*shape)
    return Tensor.full(dims, 1, **kwargs)
@staticmethod
def arange(start, stop=None, step=1, **kwargs):
    """
    Create a 1-D tensor of evenly spaced values in ``[start, stop)``.

    With a single argument, the range is ``[0, start)``. A range with no values
    (e.g. ``arange(5, 0)``) gives an empty 1-D tensor, like Python's ``range``.

    Args:
        start: Start of the range (or its end when ``stop`` is None).
        stop: End of the range, exclusive.
        step: Spacing between values. Defaults to 1.
        **kwargs: Forwarded to the Tensor constructor.
    Returns:
        Tensor: The values ``start, start + step, ...``.
    """
    if stop is None:
        stop, start = start, 0
    length = math.ceil((stop - start) / step)
    if length <= 0:
        # previously a non-positive dimension was handed to Tensor.full
        return Tensor([], **kwargs)
    # cumsum of `length` copies of `step` gives step, 2*step, ...; shift so it starts at `start`
    return Tensor.full((length,), step, **kwargs).cumsum() + (start - step)
@staticmethod
def eye(dim: int, **kwargs):
    """
    Create a ``dim x dim`` identity matrix.

    Args:
        dim: Number of rows and columns.
        **kwargs: Forwarded to ``Tensor.full``.
    Returns:
        Tensor: The identity matrix.
    """
    # A (dim, 1) column of ones, padded with `dim` zeros per row, places the ones
    # dim+1 apart once flattened. That is the stride of the diagonal of a dim x dim matrix.
    padded = Tensor.full((dim, 1), 1, **kwargs).pad(((0, 0), (0, dim)))
    flat = padded.reshape(dim * (dim + 1)).shrink(((0, dim * dim),))
    return flat.reshape(dim, dim)
def full_like(self, fill_value, **kwargs):
    """
    Create a tensor shaped like ``self`` with every element set to ``fill_value``.

    Args:
        fill_value: Scalar used for every element.
        **kwargs: ``dtype`` and ``device`` default to ``self``'s; anything else
            is forwarded to ``Tensor.full``.
    Returns:
        Tensor: The filled tensor.
    """
    dtype = kwargs.pop("dtype", self.dtype)
    device = kwargs.pop("device", self.device)
    return Tensor.full(self.shape, fill_value=fill_value, dtype=dtype, device=device, **kwargs)
def zeros_like(self, **kwargs):
    """
    Create a zero-filled tensor shaped like ``self``.

    Args:
        **kwargs: Same as ``full_like`` (``dtype``/``device`` default to ``self``'s).
    Returns:
        Tensor: The zero-filled tensor.
    """
    zero = 0
    return self.full_like(zero, **kwargs)
def ones_like(self, **kwargs):
    """
    Create a one-filled tensor shaped like ``self``.

    Args:
        **kwargs: Same as ``full_like`` (``dtype``/``device`` default to ``self``'s).
    Returns:
        Tensor: The one-filled tensor.
    """
    one = 1
    return self.full_like(one, **kwargs)
# ***** rng hlops *****
@staticmethod
def randn(*shape, dtype: Optional[DType] = None, **kwargs) -> Tensor:
    """
    Create a tensor of standard-normal samples (mean 0, std 1) via Box-Muller.

    Args:
        *shape: Dimensions, either as separate ints or a single tuple/list.
        dtype: Output dtype; ``Tensor.default_type`` when None.
        **kwargs: Forwarded to ``Tensor.rand``.
    Returns:
        Tensor: The samples.
    """
    # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
    # Normalize first, so randn((2, 3)) works like randn(2, 3). Splicing a tuple
    # straight into rand(2, *shape) produced a nested shape that rand rejected.
    src = Tensor.rand(2, *argfix(*shape), **kwargs)
    # cos(2*pi*u0) * sqrt(-2*ln(1 - u1)); 1 - u1 presumably avoids log(0) for u1 in [0, 1)
    radius = (1 - src[1]).log().mul(-2).sqrt()
    normal = src[0].mul(2 * math.pi).cos().mul(radius)
    return normal.cast(Tensor.default_type if dtype is None else dtype)
@staticmethod
def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
    """
    Create a tensor of random integers, nominally in ``[low, high)``.

    Uniform samples from ``Tensor.rand`` are scaled to ``[low, high)`` and then cast
    to ``dtypes.int32``. The result is always int32; a ``dtype`` in ``kwargs`` goes to
    ``Tensor.rand`` and does not change the output type. NOTE(review): cast
    presumably truncates toward zero, which skews the distribution when ``low`` is
    negative; confirm.

    Args:
        *shape: Dimensions of the output.
        low: Lower bound (inclusive). Defaults to 0.
        high: Upper bound (exclusive). Defaults to 10.
        **kwargs: Forwarded to ``Tensor.rand``.
    Returns:
        Tensor: An int32 tensor.
    """
    span = high - low
    scaled = Tensor.rand(*shape, **kwargs) * span + low
    return scaled.cast(dtypes.int32)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
    """
    Create a tensor of normal samples with the given mean and standard deviation.

    Args:
        *shape: Dimensions of the output.
        mean: Mean of the distribution. Defaults to 0.0.
        std: Standard deviation. Defaults to 1.0.
        **kwargs: Forwarded to ``Tensor.randn``.
    Returns:
        Tensor: ``std * randn(shape) + mean``.
    """
    samples = Tensor.randn(*shape, **kwargs)
    return (std * samples) + mean
@staticmethod
def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
    """
    Create a tensor of uniform samples in ``[low, high)``.

    Args:
        *shape: Dimensions of the output.
        low: Lower bound. Defaults to 0.0.
        high: Upper bound. Defaults to 1.0.
        **kwargs: ``dtype`` (default ``Tensor.default_type``) is applied via ``cast``
            before ``low`` is added; the rest is forwarded to ``Tensor.rand``.
    Returns:
        Tensor: The samples.
    """
    dtype = kwargs.pop("dtype", Tensor.default_type)
    width = high - low
    return (width * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor:
    """
    Create uniform samples in ``[-1, 1)`` scaled by ``prod(shape) ** -0.5``.

    Args:
        *shape: Dimensions, either as separate ints or a single tuple/list.
        **kwargs: Forwarded to ``Tensor.uniform``.
    Returns:
        Tensor: The scaled samples.
    """
    # normalize so a tuple/list shape works; prod() of a nested tuple used to raise TypeError
    dims = argfix(*shape)
    return Tensor.uniform(*dims, low=-1.0, high=1.0, **kwargs).mul(prod(dims) ** -0.5)
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor:
    """
    Create Glorot/Xavier-uniform initialized weights.

    Samples are uniform in ``[-1, 1)`` scaled by ``sqrt(6 / (shape[0] + prod(shape[1:])))``.

    Args:
        *shape: Dimensions, either as separate ints or a single tuple/list.
        **kwargs: Forwarded to ``Tensor.uniform``.
    Returns:
        Tensor: The initialized weights.
    """
    # normalize so glorot_uniform((out, in)) works; indexing the raw *shape tuple
    # used to add a tuple to an int and raise TypeError
    dims = argfix(*shape)
    scale = (6 / (dims[0] + prod(dims[1:]))) ** 0.5
    return Tensor.uniform(*dims, low=-1.0, high=1.0, **kwargs).mul(scale)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a: float = 0.01, **kwargs) -> Tensor:
    """
    Create Kaiming/He-uniform initialized weights.

    The bound is ``sqrt(3) * sqrt(2 / (1 + a**2)) / sqrt(fan_in)``, where ``fan_in`` is
    the product of all dimensions after the first.

    Args:
        *shape: Dimensions, either as separate ints or a single tuple/list.
        a: Negative slope of the following rectifier. Defaults to 0.01.
        **kwargs: Forwarded to ``Tensor.uniform``.
    Returns:
        Tensor: Samples uniform in ``[-bound, bound)``.
    """
    # Normalize first. With a tuple argument, shape[1:] was empty, so fan_in
    # silently became 1 and the bound was far too large.
    dims = argfix(*shape)
    gain = math.sqrt(2.0 / (1 + a**2))
    bound = math.sqrt(3.0) * gain / math.sqrt(prod(dims[1:]))
    return Tensor.uniform(*dims, low=-bound, high=bound, **kwargs)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
@staticmethod
def kaiming_normal(*shape, a: float = 0.01, **kwargs) -> Tensor:
    """
    Create Kaiming/He-normal initialized weights.

    The standard deviation is ``sqrt(2 / (1 + a**2)) / sqrt(fan_in)``, where ``fan_in``
    is the product of all dimensions after the first.

    Args:
        *shape: Dimensions, either as separate ints or a single tuple/list.
        a: Negative slope of the following rectifier. Defaults to 0.01.
        **kwargs: Forwarded to ``Tensor.normal``.
    Returns:
        Tensor: Normal samples with mean 0 and the std above.
    """
    # normalize first: with a tuple argument, fan_in silently became prod(()) == 1
    dims = argfix(*shape)
    std = math.sqrt(2.0 / (1 + a**2)) / math.sqrt(prod(dims[1:]))
    return Tensor.normal(*dims, mean=0.0, std=std, **kwargs)
def multinomial(
    self: Tensor, num_samples: int = 1, replacement: bool = False
) -> Tensor:
    """
    Draw category indices with probabilities proportional to ``self``'s weights.

    Each row's cumulative sum is normalized to end at 1. For every uniform sample ``u``,
    the index is the number of CDF entries that are ``<= u``.

    Args:
        self: 1-D or 2-D tensor of (unnormalized) weights; rows are independent distributions.
        num_samples: Samples per row; must be positive. Defaults to 1.
        replacement: Must be True for ``num_samples > 1``. Defaults to False.
    Returns:
        Tensor: int32 indices of shape ``(num_samples,)`` for 1-D input, or
        ``(rows, num_samples)`` for 2-D input.
    Raises:
        AssertionError: If ``self`` is not 1-D/2-D, ``num_samples`` is not positive, or
            ``num_samples > 1`` without replacement.
    """
    assert (
        1 <= self.ndim <= 2 and num_samples > 0
    ), f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive"
    assert (
        replacement or num_samples == 1
    ), "no replacement only supports num_samples = 1"
    weight = self.unsqueeze(0) if self.ndim == 1 else self
    cdf = (cw := weight.cumsum(1)) / cw[:, -1].unsqueeze(1)
    # sample on self.device so the comparison below does not mix devices
    unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
    # (num_samples, rows, categories) comparison, counted per row, then transposed to (rows, num_samples)
    indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
    return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
# ***** toposort and backward pass *****
def deepwalk(self):
    """
    Topologically order the autograd graph that ends at this tensor.

    This is a depth-first post-order walk over ``_ctx.parents``. Only nodes that have
    a truthy ``_ctx`` are collected, because ``backward`` calls ``_ctx.backward`` on
    every returned node. The walk uses an explicit stack, so very deep graphs do not
    hit Python's recursion limit. The order matches the straightforward recursive
    formulation.

    Returns:
        List[Tensor]: Non-leaf nodes, each after all of its parents; ``self`` is last
        when it has a ``_ctx``.
    """
    visited = {self}
    nodes = []
    if not getattr(self, "_ctx", None):
        return nodes
    # each frame is (node, iterator over its remaining parents)
    stack = [(self, iter(self._ctx.parents))]
    while stack:
        node, parents = stack[-1]
        for parent in parents:
            if parent in visited:
                continue
            visited.add(parent)
            if getattr(parent, "_ctx", None):
                # descend; this node's iterator resumes where it left off
                stack.append((parent, iter(parent._ctx.parents)))
                break
        else:
            # all parents handled: emit in post-order
            stack.pop()
            nodes.append(node)
    return nodes
def backward(self) -> Tensor:
    """
    Backpropagate from this scalar tensor, accumulating ``.grad`` on its inputs.

    The seed gradient is a constant 1. Nodes are visited in reverse topological order
    (from ``deepwalk``). Gradients are accumulated into every parent with a truthy
    ``requires_grad``, and each node's ``_ctx`` is deleted once it has been processed.

    Returns:
        Tensor: ``self``.
    Raises:
        AssertionError: If ``self`` is not 0-d, a visited node has no gradient, or a
            gradient's shape differs from its tensor's shape.
    """
    assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape}"
    # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
    # this is "implicit gradient creation"
    self.grad = Tensor(1, device=self.device, requires_grad=False)
    for t0 in reversed(self.deepwalk()):
        assert t0.grad is not None
        raw = t0._ctx.backward(t0.grad.lazydata)
        # single-parent functions return one buffer rather than a sequence
        if len(t0._ctx.parents) == 1:
            raw = [raw]
        grads = [None if g is None else Tensor(g, device=self.device, requires_grad=False) for g in raw]
        for t, g in zip(t0._ctx.parents, grads):
            if g is not None and t.requires_grad:
                assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
                # a tensor used by several ops sums the gradient from each use
                t.grad = g if t.grad is None else (t.grad + g)
        del t0._ctx
    return self
# ***** movement mlops *****
def reshape(self, shape, *args) -> Tensor:
    """
    Reshape the tensor.

    Args:
        shape: Target shape, as a tuple/list or as the first of several ints.
            ``-1`` infers that dimension from the element count. ``None`` keeps the
            current size at that position.
        *args: Remaining dimensions when passed as separate ints.
    Returns:
        Tensor: The reshaped tensor.
    """
    requested = argfix(shape, *args)
    resolved = []
    for i, s in enumerate(requested):
        if s == -1:
            # prod(requested) includes the -1, so it is negative; negating recovers the size
            resolved.append(-prod(self.shape) // prod(requested))
        elif s is None:
            resolved.append(self.shape[i])
        else:
            resolved.append(s)
    return mlops.Reshape.apply(self, shape=tuple(resolved))
def expand(self, shape, *args) -> Tensor:
    """
    Broadcast the tensor to a larger shape.

    Args:
        shape: Target shape, as a tuple/list or as the first of several ints.
            ``-1`` keeps the current size of that dimension.
        *args: Remaining dimensions when passed as separate ints.
    Returns:
        Tensor: The expanded tensor.
    """
    target = argfix(shape, *args)
    # NOTE: zip pairs dimensions positionally and stops at the shorter of the two
    new_shape = tuple(x if x != -1 else s for s, x in zip(self.shape, target))
    return mlops.Expand.apply(self, shape=new_shape)
def permute(self, order, *args) -> Tensor:
    """
    Reorder the tensor's dimensions.

    Args:
        order: New dimension order, as a tuple/list or as the first of several ints.
        *args: Remaining axes when passed as separate ints.
    Returns:
        Tensor: The permuted tensor.
    """
    axes = argfix(order, *args)
    return mlops.Permute.apply(self, order=axes)
def flip(self, axis, *args) -> Tensor:
    """
    Reverse the tensor along one or more axes.

    Args:
        axis: Axis or axes to flip, as a tuple/list or as the first of several ints.
            Negative values count from the last dimension.
        *args: Remaining axes when passed as separate ints.
    Returns:
        Tensor: The flipped tensor.
    """
    ndim = len(self.shape)
    axes = [a + ndim if a < 0 else a for a in argfix(axis, *args)]
    return mlops.Flip.apply(self, axis=axes)
def shrink(self, arg: Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
    """
    Restrict each dimension to a ``(start, end)`` window.

    Args:
        arg: One ``(start, end)`` pair per dimension; None keeps the whole dimension.
    Returns:
        Tensor: The shrunk tensor, or ``self`` itself when every window already
        covers its full dimension.
    """
    bounds = tuple(x if x is not None else (0, s) for x, s in zip(arg, self.shape))
    # skip the op when nothing would actually be cut off
    if all(b == (0, s) for b, s in zip(bounds, self.shape)):
        return self
    return mlops.Shrink.apply(self, arg=bounds)
def pad(
    self, arg: Tuple[Optional[Tuple[sint, sint]], ...], value: float = 0.0
) -> Tensor:
    """
    Pad each dimension with ``value``.

    Args:
        arg: One ``(before, after)`` pair per dimension; None or ``(0, 0)`` adds nothing.
        value: Fill value for the padded region. Defaults to 0.0.
    Returns:
        Tensor: The padded tensor, or ``self`` when no padding is requested.
    """
    if all(x is None or x == (0, 0) for x in arg):
        return self
    narg = tuple((0, 0) if x is None else x for x in arg)
    ret = mlops.Pad.apply(self, arg=narg)
    if 0 == value:
        return ret
    # Pad an all-ones tensor the same way. Its padded cells are 0, so where(0, value)
    # yields `value` exactly there and 0 elsewhere, which is then added onto ret.
    return ret + mlops.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
"""
***** movement hlops *****
- Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
- A slice i:j returns the elements with indices in [i, j)
- If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
- Negative values for i and j are taken relative to the end of the sequence
- Both i and j will be clamped to the range (-N, N], where N is the length of the sequence
- Indexing with None on a given axis will add a new dimension of size one before that axis
- Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends).
- For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len).
- Strides > 1 and < 0 are now allowed!:
- This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
- Idea of stride < 0 support:
- Do the slice first, flip the axes where slice.step is negative, then replace slice.step with -slice.step. Continue with the steps below.
- Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
- Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
- So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
is possible.
- Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
- Fancy indexing and combined indexing is supported
- Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t. the Tensors passed in -> fancy indexing
- Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively
- The first iteration will expand the dim of self while consecutive iterations will reduce the dim
- There's a special case where a permute is needed at the end:
- if first Tensor passed in (expand dims) is not at dim 0
- and the following Tensors do not follow consecutively to the end of fancy indexing's dims
"""
def __getitem__(
self, indices
) -> (
Tensor
): # indices: Union[int, slice, Tensor, None, Ellipsis, List, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]]
"""
Get a tensor item based on the given index or indices.
This method supports various types of indices such as int, slice, Tensor, None, Ellipsis, List, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...].
Args:
indices (Union[int, slice, Tensor, None, Ellipsis, List, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]]): The index or indices to get the tensor item.
Returns:
Tensor: The tensor item corresponding to the given index or indices.
"""
def normalize_int(e, i, dim_sz):
    """
    Validate an integer index for one dimension and map -1 to the last position.

    The caller turns the result ``y`` into ``slice(y, y + 1)``. For ``-1`` that slice
    would be ``slice(-1, 0)``, which is empty, so ``-1`` is rewritten to ``dim_sz - 1``.
    Other negative indices are returned unchanged; ``slice(e, e + 1)`` already resolves
    them relative to the end.

    Args:
        e (int): The integer index.
        i (int): Position of the dimension (used in the error message).
        dim_sz (int): Size of that dimension.
    Returns:
        int: The index to slice from.
    Raises:
        IndexError: If ``e`` is outside ``[-dim_sz, dim_sz)``.
    """
    # This replaces a duplicated nested definition whose outer function only held
    # a docstring and therefore returned None.
    if -dim_sz <= e < dim_sz:
        return e if e != -1 else dim_sz - 1
    raise IndexError(
        f"index {e} is out of bounds for dimension {i} with size {dim_sz}"
    )
# TODO: if indices is a tuple of any sequence, or if indices is a list, it's for advanced indexing
orig_slices = list(indices) if isinstance(indices, tuple) else [indices]
count = defaultdict(list)
for i, v in enumerate(orig_slices):
count[type(v)].append(i)
# TODO: boolean indices
if (
num_slices := len(count[int])
+ len(count[slice])
+ len(count[Tensor])
+ len(count[list])
) > len(self.shape):
raise IndexError(
f"too many indices for tensor of dimension {len(self.shape)}"
)
if len(ellipsis_found := count[type(Ellipsis)]) > 1:
raise IndexError("an index can only have a single ellipsis ('...')")
# replace ellipsis with equivalent number of slice(None)
# TODO: move all slice(None) to the end and transpose non-None to the front
ellipsis_idx = ellipsis_found[0] if ellipsis_found else len(orig_slices)
orig_slices[ellipsis_idx : ellipsis_idx + 1] = [slice(None)] * (
len(self.shape) - num_slices
)
valid_slices = [v for v in orig_slices if v is not None]
valid_slices = [
v
if isinstance(v, slice)
else slice(y_ := normalize_int(v, i, dim_sz), y_ + 1)
if isinstance(v, int)
else slice(None)
for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))
]
start, stop, strides = (
zip(*y)
if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)])
else ((), (), ())
)
new_slice = tuple(
((0, 0) if e < s else (s, e))
if st > 0
else ((0, 0) if e > s else (e + 1, s + 1))
for s, e, st in zip(start, stop, strides)
)
sliced_tensor = self.shrink(new_slice).flip(
axis=[i for i, s in enumerate(strides) if s < 0]
)
new_shape = sliced_tensor.shape
if any(abs(s) != 1 for s in strides):
strides = tuple(abs(s) for s in strides)
# Pad: add pad at the end: [dim_sz] -> [dim_sz_padded]
padded_tensor = sliced_tensor.pad(
tuple(
(0, s - (dim_sz % s) if dim_sz % s != 0 else 0)
for s, dim_sz in zip(strides, sliced_tensor.shape)
)
)
# Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
reshaped_tensor = padded_tensor.reshape(
flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides))
)
new_shape = reshaped_tensor.shape[::2]
# Shrink: do [:, 0]
sliced_tensor = reshaped_tensor.shrink(
tuple(flatten(((0, sh), (0, 1)) for sh in new_shape))
)
final_shape, it_shape, dim, tensors, dim_collapsed = (
[],
iter(new_shape),
[],
[],
0,
)
for i, s in enumerate(orig_slices):
if s is None:
final_shape.append(1)
else: # s is int or slice or Tensor
dim_shape = next(it_shape)
if isinstance(s, list):
s = Tensor(s)
if isinstance(s, int):
dim_collapsed += 1
else:
assert isinstance(
dim_shape, int
), f"does not support symbolic shape {dim_shape}"
final_shape.append(dim_shape)
if isinstance(s, Tensor):
tensors.append(s)
dim.append(i - dim_collapsed)
ret = sliced_tensor.reshape(tuple(final_shape))
if tensors: # Fancy/tensor indexing
# normalize idx
# TODO: first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm
idx = [
t.sign().contiguous().__neg__().contiguous().relu() * ret.shape[d] + t
for d, t in zip(dim, tensors)
]
max_dim = max(i.ndim for i in idx)
# compute sum_dim, arange, and idx
sum_dim = [d if n == 0 else d + max_dim - n for n, d in enumerate(dim)]
arange = [
Tensor.arange(
ret.shape[d],
dtype=dtypes.int32,
requires_grad=False,
device=self.device,
).reshape(
*[1] * sd, ret.shape[d], *[1] * (ret.ndim + max_dim - n - sd - 1)
)
for n, (sd, d) in enumerate(zip(sum_dim, dim))
]
first_idx = [
idx[0].reshape(
*[1] * dim[0],
*[1] * (1 + max_dim - idx[0].ndim),
*idx[0].shape,
*[1] * (ret.ndim - dim[0] - 1),
)
]
rest_idx = [
i.reshape(
*[1] * dim[0],
*[1] * (max_dim - i.ndim),
*i.shape,
*[1] * (ret.ndim - dim[0] - n),
)
for n, i in enumerate(idx[1:], 1)
]
idx = first_idx + rest_idx
ret = ret.reshape(
*ret.shape[: sum_dim[0] + 1],
*[1] * max_dim,
*ret.shape[sum_dim[0] + 1 :],
)
# iteratively fancy index
for a, i, sd in zip(arange, idx, sum_dim):
ret = (a == i).mul(ret).sum(sd)
# special permute case
if (
dim[0] != 0
and len(dim) != 1
and dim != list(range(dim[0], dim[-1] + 1))
):
ret_dims = list(range(ret.ndim))
ret = ret.permute(
ret_dims[dim[0] : dim[0] + max_dim]
+ ret_dims[: dim[0]]
+ ret_dims[dim[0] + max_dim :]
)
return ret
def __setitem__(self, indices, v):
    """
    Assign ``v`` into the region of the tensor selected by ``indices``.

    The region is obtained with ``__getitem__(indices)`` and ``assign(v)`` is
    called on it. NOTE(review): whether this writes through to ``self``
    depends on ``assign``'s semantics for indexed views; confirm.

    Parameters:
        indices: Any index accepted by ``__getitem__``.
        v: The value handed to ``assign``.

    Returns:
        Tensor: Whatever ``assign`` returns. Python discards this for the
        ``t[i] = v`` statement form.
    """
    target = self.__getitem__(indices)
    return target.assign(v)
# NOTE: using slice is discouraged and things should migrate to pad and shrink
def slice(
    self, arg: Sequence[Optional[Tuple[int, sint]]], value: float = 0
) -> Tensor:
    """
    Cut a ``(start, end)`` window on each axis, filling with ``value`` where
    the window reaches outside the tensor.

    Parameters:
        arg: One ``(start, end)`` pair per axis. ``None`` keeps the whole axis
            as ``(0, size)``. A negative start or an end past the axis size
            is filled with ``value``.
        value: Fill value for the out-of-range part. Defaults to 0.

    Returns:
        Tensor: Tensor whose axis ``i`` has length ``end - start``.
    """
    bounds = tuple((0, size) if a is None else a for size, a in zip(self.shape, arg))
    pads = []
    window = []
    for size, b in zip(self.shape, bounds):
        before = max(0, -b[0])
        after = max(0, b[1] - size)
        pads.append((before, after))
        # after padding `before` elements in front, the window moves right by it
        window.append((b[0] + before, b[1] + before))
    return self.pad(tuple(pads), value=value).shrink(tuple(window))
def gather(self: Tensor, idx: Tensor, dim: int) -> Tensor:
    """
    Gather values of ``self`` along ``dim`` at the positions stored in ``idx``.

    The gather uses an equality mask against an arange. An index value
    outside ``[0, self.shape[dim])`` matches no position and yields 0.

    Parameters:
        idx (Tensor): Index tensor with the same number of dimensions as
            ``self`` and no axis longer than the matching axis of ``self``.
        dim (int): Axis to gather along; negative values count from the end.

    Returns:
        Tensor: Tensor with the shape of ``idx``.

    Raises:
        AssertionError: If ``idx.ndim != self.ndim`` or an axis of ``idx`` is
            longer than the matching axis of ``self``.
    """
    assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
    assert all(
        s >= i for s, i in zip(self.shape, idx.shape)
    ), "all dim of idx.shape must be smaller than self.shape"
    if dim < 0:
        dim += self.ndim
    # gather axis of idx to the front, plus a trailing axis to compare against
    # every candidate position along `dim`
    idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
    axes = list(range(self.ndim))
    if dim == 0:
        permarg = axes[1:] + [axes[0]]
    else:
        permarg = axes[1:dim] + [axes[0]] + axes[dim + 1 :] + [axes[dim]]
    positions = Tensor.arange(
        self.shape[dim],
        dtype=dtypes.int32,
        requires_grad=False,
        device=self.device,
    )
    onehot = idx == positions
    # gather axis of self moved last, other axes cropped to idx's sizes,
    # leading broadcast axis added to line up with `onehot`
    window = tuple([*[(0, sh) for sh in idx.shape[1:-1]], (0, self.shape[dim])])
    source = self.permute(*permarg).shrink(window).unsqueeze(0)
    picked = (onehot * source).sum(-1)
    return picked.transpose(ax1=0, ax2=dim)
def cat(self, *args: Tensor, dim: int = 0) -> Tensor:
    """
    Concatenate ``self`` with ``args`` along ``dim``.

    Every tensor is zero-padded to the full output length along ``dim``, with
    its data placed at its own offset, and the padded tensors are summed.

    Attributes:
        *args (Tensor): Tensors to append after ``self``. Each must have the
            same number of dimensions as ``self`` and match it on every axis
            except ``dim``.
        dim (int): Concatenation axis; negative values count from the end.
            Default is 0.

    Returns:
        Tensor: The concatenated tensor.

    Raises:
        AssertionError: On shape mismatch, or if any tensor is zero-dimensional.
    """
    if dim < 0:
        dim += len(self.shape)
    for other in args:
        assert len(other.shape) == len(self.shape) and all(
            other.shape[ax] == size for ax, size in enumerate(self.shape) if ax != dim
        )
    catargs = [self, *args]
    assert all(
        t.shape for t in catargs
    ), "zero-dimensional tensor cannot be concatenated"
    sizes = [t.shape[dim] for t in catargs]
    offsets = [0, *accumulate(sizes)]
    total = offsets[-1]
    padded = []
    for t, size, before in zip(catargs, sizes, offsets):
        widths: List[Tuple[sint, sint]] = [(0, 0)] * len(self.shape)
        widths[dim] = (before, total - before - size)
        padded.append(t.pad(tuple(widths)))
    return reduce(Tensor.__add__, padded)
@staticmethod
def stack(tensors: Sequence[Tensor], dim: int = 0) -> Tensor:
    """
    Stack tensors along a new axis inserted at ``dim``.

    Each tensor gets a size-1 axis at ``dim`` via ``unsqueeze``, and the
    results are joined with ``cat`` along that axis. Shape and rank checks are
    left to ``cat``.

    Attributes:
        tensors (Sequence[Tensor]): Non-empty sequence of tensors to stack.
        dim (int, optional): Position of the new axis. Defaults to 0.

    Returns:
        Tensor: The stacked tensor.
    """
    expanded = [t.unsqueeze(dim) for t in tensors]
    head, rest = expanded[0], expanded[1:]
    return head.cat(*rest, dim=dim)
def repeat(self, repeats: Sequence[int]) -> Tensor:
    """
    Tile the tensor ``repeats[i]`` times along axis ``i``.

    If ``repeats`` is longer than the tensor's rank, leading size-1 axes are
    added first. Each axis of size ``b`` is viewed as ``(1, b)``, expanded to
    ``(r, b)`` and merged back to ``r * b``.

    Attributes:
        repeats (Sequence[int]): Repetition count per axis.

    Returns:
        Tensor: The tiled tensor.
    """
    padded_shape = (1,) * (len(repeats) - self.ndim) + self.shape
    interleaved = []
    for size in padded_shape:
        interleaved.extend((1, size))
    tiled = []
    result_shape = []
    for r, size in zip(repeats, padded_shape):
        tiled.extend((r, size))
        result_shape.append(r * size)
    return self.reshape(interleaved).expand(tiled).reshape(result_shape)
def chunk(self, num: int, dim: int = 0) -> List[Tensor]:
    """
    Split the tensor into pieces of ``ceil(size / num)`` along ``dim``.

    Every piece has that length except possibly the last, which holds the
    remainder. Fewer than ``num`` pieces can be returned: for example, size 6
    with ``num=4`` gives 3 pieces of length 2.

    Attributes:
        num (int): Requested number of chunks.
        dim (int, optional): Axis to split; negative values count from the end.
            Defaults to 0.

    Returns:
        List[Tensor]: The chunks, in order.
    """
    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
    step = math.ceil(self.shape[dim] / num)
    if dim < 0:
        dim += self.ndim
    pieces = []
    for begin in range(0, self.shape[dim], step):
        index = [slice(None)] * dim + [slice(begin, begin + step)]
        pieces.append(self[tuple(index)])
    return pieces
def squeeze(self, dim: Optional[int] = None) -> Tensor:
    """
    Drop size-1 axes.

    With ``dim=None`` every size-1 axis is removed. With an explicit ``dim``,
    only that axis is removed, and only if it has size 1; otherwise the tensor
    is returned unchanged. On a 0-d tensor, any ``dim <= 0`` returns the
    tensor as-is (matching PyTorch).

    Attributes:
        dim (Optional[int], optional): Axis to drop. Defaults to None.

    Returns:
        Tensor: The tensor without the selected size-1 axes.

    Raises:
        IndexError: If ``dim`` is outside ``[-ndim, ndim)`` on a non-0-d tensor.
    """
    if dim is None:
        if 1 not in self.shape:
            return self
        return self.reshape(*[size for size in self.shape if size != 1])
    if self.ndim == 0 and dim <= 0:
        return self  # This is to match PyTorch behavior
    if dim < -self.ndim or dim >= self.ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})"
        )
    axis = dim + self.ndim if dim < 0 else dim
    if self.shape[axis] != 1:
        return self
    return self.reshape(*[size for i, size in enumerate(self.shape) if i != axis])
def unsqueeze(self, dim: int) -> Tensor:
    """
    Insert a size-1 axis at position ``dim``.

    A negative ``dim`` counts from the end of the output shape, so ``-1``
    appends a trailing axis.

    Args:
        dim (int): Position of the new axis in the result.

    Returns:
        Tensor: The reshaped tensor with one extra axis.
    """
    insert_at = dim if dim >= 0 else len(self.shape) + dim + 1
    head, tail = self.shape[:insert_at], self.shape[insert_at:]
    return self.reshape(head + (1,) + tail)
# (padding_left, padding_right, padding_top, padding_bottom)
def pad2d(
    self, padding: Union[List[int], Tuple[int, ...]], value: float = 0
) -> Tensor:
    """
    Pad (or, with negative values, crop) the trailing axes.

    ``padding`` lists ``(before, after)`` pairs starting from the LAST axis:
    ``(left, right, top, bottom, ...)``. Axes not covered by ``padding`` are
    kept whole. The work is delegated to ``slice``.

    Args:
        padding: Flat sequence of before/after amounts, last axis first.
        value (float): Fill value for padded elements. Defaults to 0.

    Returns:
        Tensor: The padded tensor.
    """
    trailing = [
        (-before, size + after)
        for before, after, size in zip(padding[::2], padding[1::2], reversed(self.shape))
    ]
    trailing.reverse()  # back to natural axis order
    n_padded = len(padding) // 2
    leading = [(0, size) for size in self.shape[:-n_padded]]
    return self.slice(leading + trailing, value=value)
@property
def T(self) -> Tensor:
    """
    Transpose via ``transpose()`` with its defaults: axes 1 and 0 are swapped.

    Any further axes are left in place.

    Returns:
        Tensor: The transposed tensor.
    """
    return self.transpose(ax1=1, ax2=0)
def transpose(self, ax1=1, ax2=0) -> Tensor:
    """
    Swap two axes.

    Attributes:
        ax1 (int): First axis. Default is 1.
        ax2 (int): Second axis. Default is 0.

    Returns:
        Tensor: A permuted view with ``ax1`` and ``ax2`` exchanged.
    """
    order = list(range(len(self.shape)))
    first, second = order[ax1], order[ax2]
    order[ax1] = second
    order[ax2] = first
    return self.permute(order)
def flatten(self, start_dim=0):
    """
    Merge every axis from ``start_dim`` onward into a single axis.

    Attributes:
        start_dim (int): First axis to merge. Default is 0, which flattens
            everything.

    Returns:
        Tensor: The reshaped tensor.
    """
    kept = self.shape[:start_dim]
    return self.reshape(shape=(*kept, -1))
# ***** reduce ops *****
def _reduce(
    self,
    fxn: Type[Function],
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdim=False,
) -> Tensor:
    """
    Apply the reduction ``fxn`` over ``axis``.

    When the tensor has a zero-size axis and every zero-size axis is reduced
    away, the result is a ``Tensor.full`` of the reduction's identity: 0 for
    ``mlops.Sum`` and ``-inf`` for ``mlops.Max``. NOTE(review): that
    early-return path does not forward ``self.device`` to ``Tensor.full``;
    confirm this is intended.

    Attributes:
        fxn (Function): Reduction to apply (``mlops.Sum`` or ``mlops.Max``).
        axis: Axis or axes to reduce; None reduces all. Negative values count
            from the end.
        keepdim (bool): Keep reduced axes as size 1. Default is False.

    Returns:
        Tensor: The reduced tensor.
    """
    ndim = len(self.shape)
    axes: List[int]
    if axis is None:
        axes = list(range(ndim))
    elif isinstance(axis, int):
        axes = [axis]
    else:
        axes = list(axis)
    axes = [ax + ndim if ax < 0 else ax for ax in axes]
    kept_shape = tuple(size for i, size in enumerate(self.shape) if i not in axes)
    if 0 in self.shape and 0 not in kept_shape:
        fill_shape = (
            tuple(1 if size == 0 else size for size in self.shape)
            if keepdim
            else kept_shape
        )
        return Tensor.full(fill_shape, {mlops.Sum: 0, mlops.Max: -float("inf")}[fxn])
    reduced_shape = tuple([1 if i in axes else size for i, size in enumerate(self.shape)])
    reduced = fxn.apply(self, new_shape=reduced_shape)
    if keepdim:
        return reduced
    return reduced.reshape(shape=kept_shape)
def sum(self, axis=None, keepdim=False):
    """
    Sum over ``axis``; None sums over every axis.

    Reducing away a zero-size axis yields 0.

    Attributes:
        axis (Optional[Union[int, Tuple[int, ...]]]): Axis or axes to sum.
        keepdim (bool): Keep summed axes as size 1. Default is False.

    Returns:
        Tensor: The summed tensor.
    """
    return self._reduce(mlops.Sum, axis=axis, keepdim=keepdim)
def max(self, axis=None, keepdim=False):
    """
    Maximum over ``axis``; None reduces over every axis.

    Reducing away a zero-size axis yields ``-inf``. For ``[[1, 2], [3, 4]]``,
    the results are 4 overall, ``[3, 4]`` for ``axis=0`` and ``[2, 4]`` for
    ``axis=1``.

    Args:
        axis (int or tuple, optional): Axis or axes to reduce.
        keepdim (bool, optional): Keep reduced axes as size 1. Default is False.

    Returns:
        Tensor: The maxima.
    """
    return self._reduce(mlops.Max, axis=axis, keepdim=keepdim)
def min(self, axis=None, keepdim=False):
    """
    Minimum over ``axis``, computed as ``-max(-self)``.

    For ``[[1, 2], [3, 4]]``, the results are 1 overall, ``[1, 2]`` for
    ``axis=0`` and ``[1, 3]`` for ``axis=1``.

    Args:
        axis (int or tuple, optional): Axis or axes to reduce; None reduces all.
        keepdim (bool, optional): Keep reduced axes as size 1. Default is False.

    Returns:
        Tensor: The minima.
    """
    negated = -self
    return -negated.max(axis=axis, keepdim=keepdim)
def mean(self, axis=None, keepdim=False):
    """
    Arithmetic mean over ``axis``; None averages every element.

    The sum is scaled by ``prod(out.shape) / prod(self.shape)``. If the
    tensor has a zero-size axis, the sum is returned without scaling. For
    ``[[1, 2], [3, 4]]``, the results are 2.5 overall, ``[2, 3]`` for
    ``axis=0`` and ``[1.5, 3.5]`` for ``axis=1``.

    Args:
        axis (int or tuple, optional): Axis or axes to average.
        keepdim (bool, optional): Keep reduced axes as size 1. Default is False.

    Returns:
        Tensor: The means.
    """
    assert all_int(self.shape), "does not support symbolic shape"
    total = self.sum(axis=axis, keepdim=keepdim)
    if 0 in self.shape:
        return total
    return total.mul(prod(total.shape) / prod(self.shape))
def std(self, axis=None, keepdim=False, correction=1):
    """
    Standard deviation over ``axis``.

    Computes ``sqrt(sum((X - mean)**2) / (N - correction))``, where ``N`` is
    the number of elements reduced into each output value.

    :param axis: int or None, optional, default=None
        Axis or axes to reduce; None uses every element.
    :param keepdim: bool, optional, default=False
        Keep reduced axes as size 1.
    :param correction: int, optional, default=1
        Subtracted from ``N`` in the divisor. 1 gives Bessel's correction,
        0 gives the population std.
    :return: Tensor
        The standard deviation.
    """
    assert all_int(self.shape), "does not support symbolic shape"
    centered = self - self.mean(axis=axis, keepdim=True)
    square_sum = centered.square().sum(axis=axis, keepdim=keepdim)
    n_reduced = prod(self.shape) / prod(square_sum.shape)
    return square_sum.div(n_reduced - correction).sqrt()
def _softmax(self, axis):
    """
    Shared softmax building blocks along ``axis``.

    The per-axis max is subtracted before ``exp`` so large inputs do not
    overflow.

    :param axis: int
        Axis along which the softmax is taken.
    :return: Tuple ``(m, e, ss)`` where ``m`` is ``self`` minus its max along
        ``axis``, ``e`` is ``exp(m)``, and ``ss`` is ``e`` summed along
        ``axis`` (kept as size 1).
    """
    shifted = self - self.max(axis=axis, keepdim=True)
    exp = shifted.exp()
    total = exp.sum(axis=axis, keepdim=True)
    return shifted, exp, total
def softmax(self, axis=-1):
    """
    Softmax along ``axis``: values in [0, 1] that sum to 1 along that axis.

    :param axis: int, optional, default=-1
        Axis along which the softmax is taken.
    :return: Tensor
        ``exp(x - max) / sum(exp(x - max))`` along ``axis``.
    """
    exp, total = self._softmax(axis)[1:]
    return exp.div(total)
def log_softmax(self, axis=-1):
    """
    Log of the softmax along ``axis``.

    Computed as ``(x - max) - log(sum(exp(x - max)))`` instead of taking the
    log of ``softmax`` directly.

    :param axis: Axis along which to compute it. Default -1, the last axis.
    :type axis: int, optional
    :return: A new tensor containing the log-softmax values.
    """
    shifted, _, total = self._softmax(axis)
    return shifted - total.log()
def argmax(self, axis=None, keepdim=False):
    """
    Index of the maximum along ``axis``.

    Ties resolve to the first (lowest) index. The maxima mask is multiplied
    by a descending arange and the largest surviving value is converted back
    to an ascending index.

    :param axis: Axis to search. None searches the flattened tensor and
        ``keepdim`` is then ignored.
    :type axis: int, optional
    :param keepdim: Keep the searched axis as size 1. Default False.
    :type keepdim: bool, optional
    :return: A new tensor containing the indices of the maximum values.
    """
    if axis is None:
        total = prod(self.shape)
        is_max = self == self.max(axis)
        descending = Tensor.arange(
            total - 1,
            -1,
            -1,
            dtype=dtypes.int32,
            requires_grad=False,
            device=self.device,
        ).reshape(self.shape)
        scores = is_max * descending
        return total - scores.max() - 1
    if axis < 0:
        axis += len(self.shape)
    size = self.shape[axis]
    is_max = self == self.max(axis=axis, keepdim=True)
    descending = Tensor.arange(
        size - 1,
        -1,
        -1,
        dtype=dtypes.int32,
        requires_grad=False,
        device=self.device,
    ).reshape(size, *[1] * (self.ndim - axis - 1))
    scores = is_max * descending
    return size - scores.max(axis=axis, keepdim=keepdim) - 1
def argmin(self, axis=None, keepdim=False):
    """
    Index of the minimum along ``axis``, computed as ``argmax(-self)``.

    Ties resolve to the first (lowest) index.

    :param axis: Axis to search. None searches the flattened tensor.
    :type axis: int, optional
    :param keepdim: Keep the searched axis as size 1. Default False.
    :type keepdim: bool, optional
    :return: A new tensor containing the indices of the minimum values.
    """
    negated = -self
    return negated.argmax(axis=axis, keepdim=keepdim)
# ***** processing ops *****
def _pool(
    self,
    k_: Tuple[sint, ...],
    stride: Union[Tuple[int, ...], int] = 1,
    dilation: Union[Tuple[int, ...], int] = 1,
) -> Tensor:
    """
    Extract sliding windows over the last ``len(k_)`` axes.

    The leading axes (``prefix``) are kept unchanged. The result has shape
    ``(*prefix, *o_, *k_)``: for every output position along each pooled
    axis, the window of kernel values. Per pooled axis,
    ``o = (i - d*(k-1) - 1) // s + 1``. Callers reduce the trailing
    ``len(k_)`` axes (``avg_pool2d``/``max_pool2d``) or multiply them with
    weights (``conv2d``).

    Args:
        k_: Kernel size for each pooled (trailing) axis.
        stride: Stride per pooled axis, or one int for all of them.
        dilation: Dilation per pooled axis, or one int for all of them.

    Returns:
        Tensor: The windowed tensor described above.
    """
    assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
    assert all_int(self.shape) and all_int(
        k_
    ), f"does not support symbolic {self.shape=}, {k_=}"
    # one stride / dilation value per pooled axis
    s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
    assert len(k_) == len(s_) and len(k_) == len(
        d_
    ), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
    # prefix: untouched leading axes; i_: sizes of the pooled axes
    slc_prefix, prefix, i_ = (
        [(0, x) for x in self.shape[0 : -len(k_)]],
        self.shape[0 : -len(k_)],
        self.shape[-len(k_) :],
    )
    # general path: overlapping windows (k > s) or dilation != 1
    if any(k > s for k, s in zip(k_, s_)) or any(d != 1 for d in d_):
        o_ = [(i - d * (k - 1) - 1) // s + 1 for i, d, k, s in zip(i_, d_, k_, s_)]
        e_ = [
            math.ceil(k * (i + d) / i) for k, i, d in zip(k_, i_, d_)
        ]  # expands such that we don't need padding
        # repeat each pooled axis e times back to back: length e * i
        xup = (
            self.reshape(*prefix, *flatten((1, i) for i in i_))
            .expand(*prefix, *flatten((e, i) for e, i in zip(e_, i_)))
            .reshape(*prefix, *[e * i for e, i in zip(e_, i_)])
        )
        # slide by dilation
        # viewing k*(i+d) elements of the repeated signal as (k, i+d) makes row j
        # start at offset j*(i+d), i.e. the input shifted by j*d
        xup = xup.slice(
            slc_prefix + [(0, k * (i + d)) for k, i, d in zip(k_, i_, d_)]
        )
        xup = xup.reshape(
            *prefix, *flatten((k, i + d) for k, i, d in zip(k_, i_, d_))
        )
        xup = xup.slice(
            slc_prefix
            + flatten(((0, k), (0, o * s)) for k, o, s in zip(k_, o_, s_))
        )
        # handle stride, and permute to move reduce to the end
        # (k, o, s) then keep index 0 of s -> element j*d + o*s of the input
        xup = xup.reshape(
            *prefix, *flatten((k, o, s) for k, o, s in zip(k_, o_, s_))
        )
        xup = xup.slice(
            slc_prefix + flatten(((0, k), (0, o), (0, 1)) for k, o in zip(k_, o_))
        )
        xup = xup.reshape(*prefix, *flatten((k, o) for k, o in zip(k_, o_)))
        # (*prefix, k0, o0, k1, o1, ...) -> (*prefix, o0, o1, ..., k0, k1, ...)
        return xup.permute(
            *range(len(prefix)),
            *[len(prefix) + i * 2 + 1 for i in range(len(k_))],
            *[len(prefix) + i * 2 for i in range(len(k_))],
        )
    # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
    # fast path (k <= s, no dilation): windows do not overlap. View each axis as
    # (o, s) and keep the first k of every stride. o*s may exceed i (slice
    # zero-fills), but the kept k elements of each row always lie within the input.
    o_ = [(i + (s - k)) // s for i, s, k in zip(i_, s_, k_)]
    xup = self.slice(slc_prefix + [(0, o * s) for o, s in zip(o_, s_)])
    xup = xup.reshape(*prefix, *flatten(((o, s) for o, s in zip(o_, s_))))
    xup = xup.slice(slc_prefix + flatten(((0, o), (0, k)) for o, k in zip(o_, k_)))
    # (*prefix, o0, k0, o1, k1, ...) -> (*prefix, o0, o1, ..., k0, k1, ...)
    return xup.permute(
        *range(len(prefix)),
        *[len(prefix) + i * 2 for i in range(len(k_))],
        *[len(prefix) + i * 2 + 1 for i in range(len(k_))],
    )
# NOTE: these work for more than 2D
def avg_pool2d(self, kernel_size=(2, 2), stride=None, dilation=1):
    """
    Average pooling over the trailing axes covered by ``kernel_size``.

    An int ``kernel_size`` is expanded with ``make_pair`` (presumably to two
    axes; confirm against ``make_pair``). A longer tuple pools that many
    trailing axes.

    Attributes:
        kernel_size (int or tuple): Window size. Default is (2, 2).
        stride (int, tuple or None): Window stride. None means the same as
            ``kernel_size``.
        dilation (int): Spacing between kernel points. Default is 1.

    Returns:
        Tensor: The average-pooled tensor.
    """
    kernel = make_pair(kernel_size)
    windows = self._pool(
        kernel,
        kernel_size if stride is None else stride,
        dilation,
    )
    return windows.mean(axis=tuple(range(-len(kernel), 0)))
def max_pool2d(self, kernel_size=(2, 2), stride=None, dilation=1):
    """
    Max pooling over the trailing axes covered by ``kernel_size``.

    Attributes:
        kernel_size (int or tuple): Window size. Default is (2, 2).
        stride (int, tuple or None): Window stride. None means the same as
            ``kernel_size``.
        dilation (int): Spacing between kernel points. Default is 1.

    Returns:
        Tensor: The max-pooled tensor.
    """
    kernel = make_pair(kernel_size)
    windows = self._pool(
        kernel,
        kernel_size if stride is None else stride,
        dilation,
    )
    return windows.max(axis=tuple(range(-len(kernel), 0)))
def conv_transpose2d(
    self,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    groups=1,
    stride=1,
    dilation=1,
    padding=0,
    output_padding=0,
) -> Tensor:
    """
    Transposed convolution, implemented as a stride-1 ``conv2d``.

    Three steps:
      1. ``stride - 1`` zeros are inserted between input elements on each
         spatial axis.
      2. The input is padded by ``(k-1)*d - p`` before and
         ``(k-1)*d - p + output_padding`` after.
      3. The result is convolved with the spatially flipped kernel, whose
         in/out channel axes are swapped per group.

    From the reshape below, ``weight`` is laid out as
    ``(in_channels, out_channels // groups, *kernel)``.

    Attributes:
        weight (Tensor): Kernel tensor, laid out as described above.
        bias (Optional[Tensor]): Bias passed through to ``conv2d``. Default is None.
        groups (int): Number of groups. Default is 1.
        stride (int or tuple): Upsampling stride. Default is 1.
        dilation (int or tuple): Kernel dilation. Default is 1.
        padding (int or tuple): Padding to remove from each side of the output.
            Default is 0.
        output_padding (int or tuple): Extra size added to the end of each
            spatial output axis. Default is 0.

    Returns:
        Tensor: Output tensor after transposed convolution.
    """
    HW, trailing = weight.shape[2:], list(range(3, len(weight.shape) + 1))
    # (groups, cin/groups, cout/groups, *HW) -> swap channel axes -> flip kernel
    x, w = self, weight.reshape(
        groups, weight.shape[0] // groups, weight.shape[1], *weight.shape[2:]
    ).permute(0, 2, 1, *trailing).flip(trailing)
    stride = make_pair(stride, len(HW))
    if any(s > 1 for s in stride):
        # view each spatial axis as (k, 1), pad to (k, s), merge to k*s, then
        # drop the trailing s-1 zeros: s-1 zeros between consecutive inputs
        x = x.reshape(*x.shape[:2], *flatten((k, 1) for k in x.shape[2:]))
        x = x.pad(((0, 0), (0, 0), *flatten(((0, 0), (0, s - 1)) for s in stride)))
        x = x.reshape(*x.shape[:2], *[k * s for k, s in zip(x.shape[2::2], stride)])
        x = x.shrink(
            (
                (0, x.shape[0]),
                (0, x.shape[1]),
                *[(0, k - (s - 1)) for k, s in zip(x.shape[2:], stride)],
            )
        )
    # reversed: pad2d/conv2d take (before, after) pairs starting from the last axis
    padding = flatten(
        (
            ((k - 1) * d - p, (k - 1) * d - p + op)
            for k, d, p, op in reversed(
                list(
                    zip(
                        HW,
                        make_pair(dilation, len(HW)),
                        make_pair(padding, len(HW)),
                        make_pair(output_padding, len(HW)),
                    )
                )
            )
        )
    )
    # merge (groups, cout/groups) into conv2d's output-channel axis
    return x.conv2d(
        w.reshape(w.shape[0] * w.shape[1], *w.shape[2:]),
        groups=groups,
        bias=bias,
        dilation=dilation,
        padding=padding,
    )
# Winograd toggle, read from the WINO environment variable (default "0") when
# the class body executes. conv2d takes its Winograd F(4x4,3x3) path only when
# this is nonzero, every kernel dim is 3, and stride == 1 and dilation == 1.
wino = int(getenv("WINO", "0"))
def conv2d(
self,
weight: Tensor,
bias: Optional[Tensor] = None,
groups=1,
stride=1,
dilation=1,
padding=0,
) -> Tensor:
"""
Perform a 2D convolution operation on the input tensor.
This function convolves the input tensor with the given weight tensor and optionally adds the bias term. The convolution operation supports various optional parameters such as groups, stride, dilation, and padding.
Args:
self (Tensor): The input tensor.
weight (Tensor): The weight tensor for the convolution operation.
bias (Optional[Tensor]): An optional bias term to be added after the convolution operation. Default is None.
groups (int): Number of groups in which the input and output channels are divided. Default is 1.
stride (int or tuple): Stride of the 2D convolution operation. Default is 1.
dilation (int or tuple): Dilation factor of the 2D convolution operation. Default is 1.
padding (int or tuple): Padding for the 2D convolution operation. Default is 0.
Returns:
Tensor: The output tensor after performing the convolution operation, optionally adding the bias term if provided.
Raises:
ValueError: If the shape of the input tensor does not match the shape of the weight tensor or if the padding length is incorrect.
Note:
This function assumes that the input tensor has a shape (batch_size, channels_in, height, width) and the weight tensor has a shape (channels_out, channels_in // groups, kernel_height, kernel_width).
"""
(bs, cin_), (cout, cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
assert groups * cin == cin_ and len(self.shape) == len(
weight.shape
), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
if isinstance(padding, (tuple, list)):
assert len(padding) == 2 * len(HW) or len(padding) == len(
HW
), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}"
padding_ = (
[padding] * 2 * len(HW)
if isinstance(padding, int)
else (
padding
if len(padding) == 2 * len(HW)
else [p for p in padding for _ in range(2)][::-1]
)
)
# conv2d is a pooling op (with padding)
x = self.pad2d(padding_)._pool(
HW, stride, dilation
) # (bs, groups*cin, oy, ox, H, W)
rcout, oyx = cout // groups, x.shape[2 : -len(HW)]
if (
not all(x == 3 for x in HW)
or stride != 1
or dilation != 1
or not Tensor.wino
):
# normal conv
x = (
x.reshape(bs, groups, cin, 1, *oyx, *HW)
.expand(bs, groups, cin, rcout, *oyx, *HW)
.permute(
0,
1,
3,
*[4 + i for i in range(len(oyx))],
2,
*[4 + len(oyx) + i for i in range(len(HW))],
)
)
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
ret = (
(x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW))
.sum([-1 - i for i in range(1 + len(oyx))], keepdim=True)
.reshape(bs, cout, *oyx)
)
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
def apply_matrix(mat, t, dim=0):
"""
Apply a 3x3 matrix to a 4x4 matrix in Winograd's F(4x4,3x3) algorithm.
This method is used for applying a 3x3 matrix to a 4x4 matrix as part of the Winograd F(4x4,3x3) convolution algorithm. The function recursively applies the transformation until it reaches the specified dimension.
:param mat: A list of lists representing the 3x3 matrix.
:type mat: List[List[int]]
:param t: A tensor to which the matrix will be applied.
:type t: Tensor
:param dim: The dimension to which the transformation is applied, defaults to 0.
:type dim: int, optional
:return: The transformed tensor after applying the matrix.
:rtype: Tensor
Attributes:
HWI (tuple): A tuple representing the input size of Winograd's F(4x4,3x3) algorithm. Default is (6,).
HWO (tuple): A tuple representing the output size of Winograd's F(4x4,3x3) algorithm. Default is (4,).
"""
return (
t
if dim == len(HW)
else Tensor.stack(
[
apply_matrix(
mat,
sum(mm * t[j] for j, mm in enumerate(m) if mm),
dim=dim + 1,
)
for m in mat
]
)
)
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
winograd_Bt = [
[4, 0, -5, 0, 1, 0],
[0, -4, -4, 1, 1, 0],
[0, 4, -4, -1, 1, 0],
[0, -2, -1, 2, 1, 0],
[0, 2, -1, -2, 1, 0],
[0, 4, 0, -5, 0, 1],
]
winograd_G = [
[1 / 4, 0, 0],
[-1 / 6, -1 / 6, -1 / 6],
[-1 / 6, 1 / 6, -1 / 6],
[1 / 24, 1 / 12, 1 / 6],
[1 / 24, -1 / 12, 1 / 6],
[0, 0, 1],
]
winograd_At = [
[1, 1, 1, 1, 1, 0],
[0, 1, -1, 2, -2, 0],
[0, 1, 1, 4, 4, 0],
[0, 1, -1, 8, -8, 1],
] # applying At in pre-order almost doubles compilation time
# todo: stride == dilation
# use padding to round up to 4x4 output tiles
d = self.pad2d(
sum(
[
[
padding_[i * 2],
padding_[i * 2 + 1]
+ (-(dim + sum(padding_[i * 2 : (i + 1) * 2]) - 2) % 4),
]
for i, dim in enumerate(self.shape[-len(HW) :])
],
[],
)
)._pool(
HWI, HWO
) # (bs, cin_, tyx, HWI)
d = d.permute(
*range(len(d.shape) - len(HW), len(d.shape)), *range(len(d.shape) - len(HW))
).contiguous_backward() # move HW to the front: # (HWI, bs, cin_, tyx)
tyx = d.shape[-len(HWI) :] # dim of tiling
g = weight.permute(
*range(len(weight.shape) - len(HW), len(weight.shape)),
*range(len(weight.shape) - len(HW)),
) # move HW to the front
# compute 6x6 winograd tiles: GgGt, BtdB
gfactors = (
apply_matrix(winograd_G, g)
.contiguous()
.reshape(*HWI, 1, groups, rcout, cin, *([1] * len(tyx)))
) # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
dfactors = (
apply_matrix(winograd_Bt, d)
.contiguous()
.reshape(*HWI, bs, groups, 1, cin, *tyx)
) # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
ret = apply_matrix(
winograd_At, (gfactors * dfactors).sum(axis=-1 - len(HW))
) # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
ret = ret.permute(
[
*range(len(HW), len(ret.shape) - len(HW)),
*[i + o for i in range(len(HW)) for o in [len(ret.shape) - len(HW), 0]],
]
) # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(
tuple((0, s) for s in [bs, cout, *oyx])
) # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
return (
(
ret
if bias is None
else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))
)
.contiguous()
.contiguous_backward()
)
def dot(self, w: Tensor) -> Tensor:
    """
    Matrix/vector product of ``self`` and ``w`` with broadcasting over leading dims.

    The last axis of ``self`` is contracted against the last axis of ``w`` when ``w`` is 1D,
    and against its second-to-last axis otherwise. The product is formed as a broadcasted
    elementwise multiply followed by a sum over the last axis.

    Args:
        w (Tensor): Right-hand operand, at least 1D.
    Returns:
        Tensor: The contracted product.
    Raises:
        AssertionError: If either operand is 0D or the contracted sizes differ.
    """
    ndim_self, ndim_w = len(self.shape), len(w.shape)
    assert (
        ndim_self != 0 and ndim_w != 0
    ), f"both arguments to matmul need to be at least 1D, but they are {ndim_self}D and {ndim_w}D"
    # axis of w that is summed against the last axis of self
    w_axis = -min(ndim_w, 2)
    assert (
        self.shape[-1] == w.shape[w_axis]
    ), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[w_axis]})"
    # when both operands are at least 2D, one extra size-1 axis lets the rows of self
    # broadcast against the columns of w; otherwise no axis is inserted
    spacer = [1] * min(ndim_self - 1, ndim_w - 1, 1)
    lhs = self.reshape(*self.shape[0:-1], *spacer, self.shape[-1])
    rhs = w.reshape(*w.shape[0:-2], *spacer, *w.shape[w_axis:]).transpose(-1, w_axis)
    return (lhs * rhs).sum(-1)
def _cumsum(self, axis: int = 0, _first_zero=False) -> Tensor:
    """
    Single-stage cumulative sum along ``axis``, computed as window sums over a left-padded axis.

    Args:
        axis (int): Axis to accumulate along. Default is 0.
        _first_zero (bool): If True, pad by one extra element so the result is one element
            longer along ``axis`` and starts with 0 (exclusive prefix sums followed by the total).
    Returns:
        Tensor: The cumulative sums.
    """
    length = self.shape[axis]
    # pad by length-1 (or by length when a leading zero is wanted) so that each window of
    # `length` elements ends at one output position
    left_pad = length - int(not _first_zero)
    windows = self.transpose(axis, -1).pad2d((left_pad, 0))._pool((length,))
    return windows.sum(-1).transpose(axis, -1)
def cumsum(self, axis: int = 0) -> Tensor:
    """
    Cumulative sum of this tensor along ``axis``.

    Axes no longer than ``2 * SPLIT`` use the single-stage ``_cumsum``. Longer axes are split
    into ``SPLIT``-sized chunks: each chunk is cumsummed independently, then the exclusive
    prefix sum of the chunk totals is added to every element of the corresponding chunk.

    Args:
        axis (int): Axis to accumulate along. Default is 0.
    Returns:
        Tensor: The cumulative sums, same shape as ``self``.
    """
    # TODO: someday the optimizer will find this on its own
    # for now this is a two stage cumsum
    SPLIT = 256
    length = self.shape[axis]
    if length <= SPLIT * 2:
        return self._cumsum(axis)
    # move the axis last and left-pad it up to a multiple of SPLIT
    padded = self.transpose(axis, -1).pad2d((round_up(length, SPLIT) - length, 0))
    # stage 1: cumsum inside each chunk
    chunks = padded.reshape(
        *padded.shape[0:-1], padded.shape[-1] // SPLIT, SPLIT
    )._cumsum(-1)
    # stage 2: offset of each chunk = sum of all previous chunk totals
    offsets = chunks[..., -1]._cumsum(-1, _first_zero=True)[..., :-1]
    offsets = offsets.unsqueeze(-1).expand(*offsets.shape, chunks.shape[-1])
    flat_len = chunks.shape[-2] * chunks.shape[-1]

    def _unchunk(t: Tensor) -> Tensor:
        # flatten the chunks, drop the left padding, and put the axis back in place
        return t.reshape(*chunks.shape[0:-2], flat_len)[..., -length:].transpose(axis, -1)

    return _unchunk(chunks) + _unchunk(offsets)
@staticmethod
def _tri(r: int, c: int, k: int = 0, **kwargs) -> Tensor:
    """
    Build an ``r x c`` triangular mask that is true where ``col - row >= k``.

    Row indices are broadcast along the columns and compared with the column indices shifted
    by ``-k`` and broadcast along the rows: ``mask[i, j] = (i <= j - k)``.

    Args:
        r (int): Number of rows.
        c (int): Number of columns.
        k (int): Diagonal offset; 0 is the main diagonal, positive values move it up. Default 0.
        **kwargs: Forwarded to both ``Tensor.arange`` calls (e.g. ``dtype``, ``device``).
    Returns:
        Tensor: The ``(r, c)`` comparison result.
    """
    row_idx = Tensor.arange(r, **kwargs).unsqueeze(1).expand(r, c)
    shifted_col_idx = Tensor.arange(-k, c - k, **kwargs).unsqueeze(0).expand(r, c)
    return row_idx <= shifted_col_idx
def triu(self, k: int = 0) -> Tensor:
    """
    Upper-triangular part over the last two axes: keep elements with ``col - row >= k``, zero the rest.

    Args:
        k (int): Diagonal offset; 0 keeps the main diagonal. Default 0.
    Returns:
        Tensor: A tensor shaped like ``self``.
    Raises:
        AssertionError: If ``self`` has a symbolic shape.
    """
    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
    n_rows, n_cols = self.shape[-2], self.shape[-1]
    keep = Tensor._tri(n_rows, n_cols, k=k, dtype=self.dtype, device=self.device)
    return keep.where(self, Tensor.zeros_like(self))
def tril(self, k: int = 0) -> Tensor:
    """
    Lower-triangular part over the last two axes: keep elements with ``col - row <= k``, zero the rest.

    Implemented by zeroing wherever the ``_tri`` mask with offset ``k + 1`` is true.

    Args:
        k (int): Diagonal offset; 0 keeps the main diagonal. Default 0.
    Returns:
        Tensor: A tensor shaped like ``self``.
    Raises:
        AssertionError: If ``self`` has a symbolic shape.
    """
    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
    n_rows, n_cols = self.shape[-2], self.shape[-1]
    strictly_above = Tensor._tri(
        n_rows, n_cols, k=k + 1, dtype=self.dtype, device=self.device
    )
    return strictly_above.where(Tensor.zeros_like(self), self)
# ***** mlops (unary) *****
def neg(self):
    """
    Elementwise negation ``-self``, computed by ``mlops.Neg``.

    Returns:
        Tensor: A new tensor with every element negated.
    """
    return mlops.Neg.apply(self)
def contiguous(self):
    """
    Apply ``mlops.Contiguous`` to this tensor.

    NOTE(review): presumably this materializes the data into a contiguous buffer instead of a
    view; confirm against ``mlops.Contiguous``.

    Returns:
        Tensor: The result of ``mlops.Contiguous``.
    """
    return mlops.Contiguous.apply(self)
def contiguous_backward(self):
    """
    Apply ``mlops.ContiguousBackward`` to this tensor.

    NOTE(review): judging by the name, the forward value passes through unchanged and the op
    makes the gradient contiguous during the backward pass; confirm against
    ``mlops.ContiguousBackward``.

    Returns:
        Tensor: The result of ``mlops.ContiguousBackward``.
    """
    return mlops.ContiguousBackward.apply(self)
def log(self):
    """
    Elementwise natural logarithm, computed by ``mlops.Log``.

    Returns:
        Tensor: ``ln`` of each element.
    """
    return mlops.Log.apply(self)
def log2(self):
    """
    Elementwise base-2 logarithm, computed as ``ln(self) / ln(2)``.

    Returns:
        Tensor: ``log2`` of each element.
    """
    natural_log = mlops.Log.apply(self)
    return natural_log / math.log(2)
def exp(self):
    """
    Elementwise exponential ``e ** self``, computed by ``mlops.Exp``.

    Returns:
        Tensor: The exponential of each element.
    """
    return mlops.Exp.apply(self)
def exp2(self):
    """
    Elementwise ``2 ** self``, computed as ``exp(self * ln(2))``.

    Returns:
        Tensor: 2 raised to each element.
    """
    scaled = self * math.log(2)
    return mlops.Exp.apply(scaled)
def relu(self):
    """
    Rectified Linear Unit ``max(0, x)`` applied elementwise, computed by ``mlops.Relu``.

    Returns:
        Tensor: A tensor with negative elements replaced by 0.
    """
    return mlops.Relu.apply(self)
def sigmoid(self):
    """
    Logistic sigmoid ``1 / (1 + exp(-x))`` applied elementwise, computed by ``mlops.Sigmoid``.

    Returns:
        Tensor: Values in (0, 1).
    """
    return mlops.Sigmoid.apply(self)
def sin(self):
    """
    Elementwise sine (radians), computed by ``mlops.Sin``.

    Returns:
        Tensor: The sine of each element.
    """
    return mlops.Sin.apply(self)
def sqrt(self):
    """
    Elementwise square root, computed by ``mlops.Sqrt``.

    Returns:
        Tensor: The square root of each element.
    """
    return mlops.Sqrt.apply(self)
def rsqrt(self):
    """
    Elementwise reciprocal square root ``1 / sqrt(self)``, computed as ``sqrt(1 / self)``.

    Returns:
        Tensor: The reciprocal square root of each element.
    """
    inverse = 1 / self
    return inverse.sqrt()
def cos(self):
    """
    Elementwise cosine (radians), via the identity ``cos(x) = sin(pi/2 - x)``.

    Returns:
        Tensor: The cosine of each element.
    """
    phase_shifted = math.pi / 2 - self
    return phase_shifted.sin()
def tan(self):
    """
    Elementwise tangent (radians), computed as ``sin(self) / cos(self)``.

    Returns:
        Tensor: The tangent of each element.
    """
    numerator = self.sin()
    denominator = self.cos()
    return numerator / denominator
# ***** math functions (unary) *****
def trunc(self: Tensor) -> Tensor:
    """
    Round each element toward zero.

    The tensor is cast to int32, made contiguous, and cast back to its original dtype.
    NOTE(review): values outside the int32 range presumably do not survive this round trip.

    Returns:
        Tensor: The truncated tensor, in ``self.dtype``.
    """
    as_int = self.cast(dtypes.int32).contiguous()
    return as_int.cast(self.dtype)
def ceil(self: Tensor) -> Tensor:
    """
    Round each element up to the nearest integer.

    Elements strictly greater than their truncation get ``trunc + 1``; all others keep ``trunc``.

    Returns:
        Tensor: The rounded-up tensor.
    """
    truncated = self.trunc()
    return (self > truncated).where(truncated + 1, truncated)
def floor(self: Tensor) -> Tensor:
    """
    Round each element down to the nearest integer.

    Elements strictly smaller than their truncation get ``trunc - 1``; all others keep ``trunc``.

    Returns:
        Tensor: The rounded-down tensor.
    """
    truncated = self.trunc()
    return (self < truncated).where(truncated - 1, truncated)
def square(self):
    """
    Elementwise square, computed as ``self * self``.

    Returns:
        Tensor: Each element squared.
    """
    return self * self
def clip(self, min_, max_):
    """
    Clamp every element into ``[min_, max_]``.

    The lower bound is applied first (``maximum``), then the upper bound (``minimum``).

    Args:
        min_: Lower bound.
        max_: Upper bound.
    Returns:
        Tensor: The clipped tensor.
    """
    lower_bounded = self.maximum(min_)
    return lower_bounded.minimum(max_)
def abs(self):
    """
    Elementwise absolute value, computed as ``relu(x) + relu(-x)``.

    Returns:
        Tensor: The absolute value of each element.
    """
    positive_part = self.relu()
    negative_part = (-self).relu()
    return positive_part + negative_part
def sign(self):
    """
    Elementwise sign, computed in float as ``x / (|x| + 1e-12)`` and cast back to ``self.dtype``.

    Positive elements give ~1, negative elements ~-1, and zeros give 0. Because of the 1e-12
    epsilon, elements whose magnitude is close to 1e-12 give results strictly between -1 and 1.

    Returns:
        Tensor: A tensor shaped like ``self``.
    """
    numerator = self.float()
    magnitude = self.float().abs() + 1e-12
    return (numerator / magnitude).cast(self.dtype)
def reciprocal(self):
    """
    Elementwise reciprocal, computed as ``1.0 / self``.

    Returns:
        Tensor: ``1 / x`` for each element ``x``.
    """
    return 1.0 / self
# ***** activation functions (unary) *****
def elu(self, alpha=1.0):
    """
    Exponential Linear Unit applied elementwise.

    ``f(x) = x`` for ``x > 0`` and ``alpha * (exp(x) - 1)`` for ``x <= 0``, computed as
    ``relu(x) - alpha * relu(1 - exp(x))``.

    Args:
        alpha (float): Scale of the negative branch. Default is 1.0.
    Returns:
        Tensor: ELU of each element.
    """
    negative_branch = (1 - self.exp()).relu()
    return self.relu() - alpha * negative_branch
def celu(self, alpha=1.0):
    """
    Continuously differentiable ELU applied elementwise.

    ``f(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))``.

    Args:
        alpha (float): Scale of the negative branch. Default is 1.0.
    Returns:
        Tensor: CELU of each element.
    """
    scaled_exp = alpha * ((self / alpha).exp() - 1)
    return self.maximum(0) + scaled_exp.minimum(0)
def swish(self):
    """
    Swish activation ``x * sigmoid(x)`` applied elementwise.

    Returns:
        Tensor: Swish of each element.
    """
    gate = self.sigmoid()
    return self * gate
def silu(self):
    """
    Sigmoid Linear Unit ``x * sigmoid(x)``; identical to ``swish`` and delegates to it.

    Returns:
        Tensor: SiLU of each element.
    """
    return self.swish()  # The SiLU function is also known as the swish function.
def relu6(self):
    """
    ReLU capped at 6: ``min(max(0, x), 6)``, computed as ``relu(x) - relu(x - 6)``.

    Returns:
        Tensor: ReLU6 of each element.
    """
    excess_over_six = (self - 6).relu()
    return self.relu() - excess_over_six
def hardswish(self):
    """
    Hard Swish activation ``x * relu6(x + 3) / 6`` applied elementwise.

    Returns:
        Tensor: Hard Swish of each element.
    """
    gate = (self + 3).relu6()
    return self * gate * (1 / 6)
def tanh(self):
    """
    Hyperbolic tangent applied elementwise, via ``tanh(x) = 2 * sigmoid(2x) - 1``.

    Returns:
        Tensor: Values in (-1, 1).
    """
    doubled = 2.0 * self
    return 2.0 * doubled.sigmoid() - 1.0
def sinh(self):
    """
    Hyperbolic sine ``(exp(x) - exp(-x)) / 2`` applied elementwise.

    Returns:
        Tensor: sinh of each element.
    """
    exp_pos = self.exp()
    exp_neg = self.neg().exp()
    return (exp_pos - exp_neg) / 2
def cosh(self):
    """
    Hyperbolic cosine ``(exp(x) + exp(-x)) / 2`` applied elementwise.

    Returns:
        Tensor: cosh of each element.
    """
    exp_pos = self.exp()
    exp_neg = self.neg().exp()
    return (exp_pos + exp_neg) / 2
def atanh(self):
    """
    Inverse hyperbolic tangent ``log((1 + x) / (1 - x)) / 2`` applied elementwise.

    Returns:
        Tensor: atanh of each element.
    """
    ratio = (1 + self) / (1 - self)
    return ratio.log() / 2
def asinh(self):
    """
    Inverse hyperbolic sine ``log(x + sqrt(x^2 + 1))`` applied elementwise.

    Returns:
        Tensor: asinh of each element.
    """
    hypot = (self.square() + 1).sqrt()
    return (self + hypot).log()
def acosh(self):
    """
    Inverse hyperbolic cosine ``log(x + sqrt(x^2 - 1))`` applied elementwise.

    Returns:
        Tensor: acosh of each element.
    """
    root = (self.square() - 1).sqrt()
    return (self + root).log()
def hardtanh(self, min_val=-1, max_val=1):
    """
    HardTanh activation: clamp every element into ``[min_val, max_val]`` via ``clip``.

    Args:
        min_val (float): Lower bound of the output range. Default is -1.
        max_val (float): Upper bound of the output range. Default is 1.
    Returns:
        Tensor: The clipped tensor.
    """
    return self.clip(min_val, max_val)
def gelu(self):
    """
    GELU (tanh approximation) applied elementwise.

    ``f(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``, where
    0.7978845608 is ``sqrt(2/pi)``.

    Returns:
        Tensor: GELU of each element.
    """
    tanh_arg = self * 0.7978845608 * (1 + 0.044715 * self * self)
    return 0.5 * self * (1 + tanh_arg.tanh())
def quick_gelu(self):
    """
    Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)`` applied elementwise.

    Returns:
        Tensor: Approximate GELU of each element.
    """
    gate = (self * 1.702).sigmoid()
    return self * gate
def leakyrelu(self, neg_slope=0.01):
    """
    Leaky ReLU: ``x`` for ``x > 0`` and ``neg_slope * x`` otherwise.

    Computed as ``relu(x) - relu(-neg_slope * x)``, which equals the definition above for
    ``neg_slope >= 0``.

    Args:
        neg_slope (float): Slope for negative inputs. Default is 0.01.
    Returns:
        Tensor: Leaky ReLU of each element.
    """
    negative_part = (-neg_slope * self).relu()
    return self.relu() - negative_part
def mish(self):
    """
    Mish activation ``x * tanh(softplus(x))`` applied elementwise.

    Returns:
        Tensor: Mish of each element.
    """
    gate = self.softplus().tanh()
    return self * gate
def softplus(self, beta=1):
    """
    Softplus ``(1 / beta) * log(1 + exp(beta * x))`` applied elementwise.

    ``exp(beta * x)`` is computed directly, so very large ``beta * x`` can overflow to inf.

    Args:
        beta (float): Sharpness parameter. Default is 1.
    Returns:
        Tensor: Softplus of each element.
    """
    inv_beta = 1 / beta
    return inv_beta * (1 + (self * beta).exp()).log()
def softsign(self):
    """
    Softsign ``x / (1 + |x|)`` applied elementwise.

    Returns:
        Tensor: Values in (-1, 1).
    """
    denominator = 1 + self.abs()
    return self / denominator
# ***** broadcasted binary mlops *****
def _broadcasted(
    self, y: Union[Tensor, float], reverse: bool = False
) -> Tuple[Tensor, Tensor]:
    """
    Turn ``y`` into a Tensor if needed and broadcast ``(self, y)`` to a common shape.

    A Python scalar becomes a constant Tensor on ``self.device`` without grad. It uses
    ``self.dtype``, except that bool and ImageDType fall back to float32. If ``self`` has a
    zero-sized dimension, the scalar is simply filled to ``self``'s shape and returned as
    ``(self, filled)``; ``reverse`` is not applied in that case. Otherwise:

    * the operands are swapped when ``reverse`` is True;
    * the lower-rank shape is left-padded with 1s;
    * both operands are expanded to the elementwise max of their shapes.

    Args:
        y: The other operand.
        reverse: If True, return ``(y, self)`` ordering.
    Returns:
        Tuple[Tensor, Tensor]: The two broadcast operands.
    """
    lhs: Tensor = self
    if not isinstance(y, Tensor):
        if 0 in lhs.shape:
            return lhs, lhs.full_like(y)
        keeps_dtype = (
            self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType
        )
        y = Tensor(
            y,
            device=self.device,
            requires_grad=False,
            dtype=self.dtype if keeps_dtype else dtypes.float32,
        )
    rhs: Tensor = y
    if reverse:
        lhs, rhs = rhs, lhs
    if lhs.shape == rhs.shape:
        return (lhs, rhs)
    rank_gap = len(lhs.shape) - len(rhs.shape)
    if rank_gap > 0:
        rhs = rhs.reshape((1,) * rank_gap + rhs.shape)
    elif rank_gap < 0:
        lhs = lhs.reshape((1,) * -rank_gap + lhs.shape)
    if lhs.shape == rhs.shape:
        return (lhs, rhs)
    target = tuple(max(a, b) for a, b in zip(lhs.shape, rhs.shape))
    if lhs.shape != target:
        lhs = lhs.expand(target)
    if rhs.shape != target:
        rhs = rhs.expand(target)
    return (lhs, rhs)
def _to_float(self, x: Union[Tensor, float]):
    """
    Unwrap a constant Tensor operand into its Python value when that is safe.

    The constant is ``x.lazydata.base.op.arg``. It is returned only when all of these hold:

    * ``x`` is a Tensor whose lazydata is an unrealized contiguous constant;
    * ``x`` does not require grad;
    * broadcasting it with ``self`` leaves ``self``'s shape unchanged.

    In every other case ``x`` is returned as-is. This lets the binary ops apply their
    scalar shortcuts (e.g. ``x + 0``, ``x * 1``) to constant Tensors as well.
    """
    if not isinstance(x, Tensor):
        return x
    if not x.lazydata.is_unrealized_contiguous_const():
        return x
    if x.requires_grad:
        return x
    if self._broadcasted(x)[0].shape != self.shape:
        return x
    return x.lazydata.base.op.arg
def add(self, x: Union[Tensor, float], reverse=False) -> Tensor:
    """
    Broadcasted addition ``self + x`` (``x + self`` when ``reverse``).

    A falsy scalar (e.g. 0), including a constant Tensor unwrapped by ``_to_float``, returns
    ``self`` unchanged.
    """
    operand = self._to_float(x)
    if operand.__class__ is not Tensor and not operand:
        return self
    return mlops.Add.apply(*self._broadcasted(operand, reverse))
def sub(self, x: Union[Tensor, float], reverse=False) -> Tensor:
    """
    Broadcasted subtraction ``self - x`` (``x - self`` when ``reverse``).

    A falsy scalar short-circuits: ``self - 0`` returns ``self`` and ``0 - self`` returns ``-self``.
    """
    operand = self._to_float(x)
    if operand.__class__ is Tensor or operand:
        return mlops.Sub.apply(*self._broadcasted(operand, reverse))
    if reverse:
        return -self
    return self
def mul(self, x: Union[Tensor, float], reverse=False) -> Tensor:
    """
    Broadcasted multiplication ``self * x``.

    Scalar shortcuts (after ``_to_float``):

    * ``0`` returns ``mlops.Zero`` applied to ``self``;
    * ``-1`` returns ``-self``;
    * ``1`` returns ``self``.
    """
    operand = self._to_float(x)
    is_tensor = operand.__class__ is Tensor
    if not is_tensor and operand == 0.0:
        return mlops.Zero.apply(self)
    if not is_tensor and operand == -1.0:
        return -self
    if is_tensor or operand != 1.0:
        return mlops.Mul.apply(*self._broadcasted(operand, reverse))
    return self
def div(self, x: Union[Tensor, float], reverse=False) -> Tensor:
    """
    Broadcasted division ``self / x`` (``x / self`` when ``reverse``).

    For a float ``self`` divided by a nonzero Python scalar (not reversed), this multiplies
    by ``1 / x`` instead. All other cases use ``mlops.Div``, including division by a zero
    scalar, which therefore does not raise ZeroDivisionError.
    """
    operand = self._to_float(x)
    needs_div = (
        operand.__class__ is Tensor
        or reverse
        or not operand
        or not dtypes.is_float(self.dtype)
    )
    if needs_div:
        return mlops.Div.apply(*self._broadcasted(operand, reverse))
    return self.mul(1 / operand)
def pow(self, x: Union[Tensor, float], reverse=False) -> Tensor:
    """
    Elementwise power: ``self ** x``, or ``x ** self`` when ``reverse`` is True.

    Computation paths, in order:

    * ``reverse=True`` with a Tensor ``x``: computed as ``x.pow(self)``.
    * Python-scalar exponent, not reversed: negative exponents, 3, 2, 1 and 0.5 take
      shortcuts (reciprocal, repeated multiply, identity, sqrt).
    * Positive scalar base, reversed: computed as ``exp(self * ln(x))``.
    * Otherwise the magnitude is ``exp(log|base| * exponent)``. It is then multiplied by a
      sign correction for negative bases, and by NaN when a negative base meets a
      non-integer exponent.

    Args:
        x: The exponent, or the base when ``reverse`` is True.
        reverse: If True, compute ``x ** self``.
    Returns:
        Tensor: The elementwise power.
    """
    x = self._to_float(x)
    if reverse and isinstance(x, Tensor):
        # The general path below treats a Tensor `x` as the exponent even when `reverse` is set
        # (`ar` would be |self| ** x while the sign corrections use `x` as the base), so
        # compute x ** self directly instead.
        return x.pow(self)
    if x.__class__ is not Tensor and not reverse:
        # simple pow identities
        if x < 0:
            return self.reciprocal().pow(-x)
        if x == 3.0:
            return self * self * self
        if x == 2.0:
            return self * self
        if x == 1.0:
            return self
        if x == 0.5:
            return self.sqrt()
    # positive scalar base: x ** self == exp(self * ln(x))
    if not isinstance(x, Tensor) and reverse and x > 0:
        return self.mul(math.log(x)).exp()
    # magnitude |base| ** exponent, computed in log space
    ar = (
        self.abs().log().mul(x).exp()
        if not reverse or isinstance(x, Tensor)
        else self.mul(math.log(abs(x))).exp()
    )
    # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
    sign = (
        (x * math.pi).cos()
        if isinstance(x, Tensor)
        else math.cos(x * math.pi)
        if not reverse
        else (self * math.pi).cos()
    )
    # we only need to correct the sign if the base is negative
    # (base_sign is 1 for a negative base and 0 for a positive one)
    base_sign = (
        (
            self.sign()
            if not reverse
            else x.sign()
            if isinstance(x, Tensor)
            else math.copysign(1, x)
        )
        - 1
    ) / -2
    # we need 0 to be positive so we need to correct base_sign when the base is 0
    base_sign = base_sign - (
        1.5
        * (
            1
            - (
                self.sign().abs()
                if not reverse
                else x.sign().abs()
                if isinstance(x, Tensor)
                else abs(int(bool(x)))
            )
        )
    )
    # inject nan if the base is negative and the power is not an integer
    # (for Tensor powers, "not an integer" means a fractional part >= ~1e-10 after the clip)
    to_nan = (
        ((x - x.trunc()) * 1e10).abs().clip(0, 1)
        if isinstance(x, Tensor)
        else int(bool(x - int(x)))
        if not reverse
        else ((self - self.trunc()) * 1e10).abs().clip(0, 1)
    ) * base_sign
    # log(1 - 2*to_nan) + 1 is 1 when to_nan == 0 and NaN when to_nan == 1
    inject_nan = (
        ((((-to_nan) * 2) + 1)).log().add(1)
        if isinstance(to_nan, Tensor)
        else 1
        if not to_nan
        else float("nan")
    )
    return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
def matmul(self, x: Tensor, reverse=False) -> Tensor:
    """Matrix product ``self @ x``, or ``x @ self`` when ``reverse`` (delegates to ``dot``)."""
    if reverse:
        return x.dot(self)
    return self.dot(x)
def maximum(self, x: Union[Tensor, float]) -> Tensor:
    """
    Elementwise maximum of ``self`` and ``x``.

    The comparison masks are detached. Where the operands are equal the result is
    ``(self + x) / 2``, which splits the gradient evenly between both inputs.
    """
    x_wins = (self < x).detach()
    self_wins = (self > x).detach()
    tie_value = (self + x) / 2
    return x_wins.where(x, self_wins.where(self, tie_value))
def minimum(self, x: Union[Tensor, float]) -> Tensor:
    """Elementwise minimum, computed as ``-maximum(-self, -x)``."""
    negated_self = -self
    return -(negated_self.maximum(-x))
def where(self: Tensor, input_: Union[Tensor, float], other: Union[Tensor, float]):
    """
    Select ``input_`` where ``self`` (the condition) is set, else ``other``.

    All three operands are broadcast to a common shape before ``mlops.Where`` is applied.
    """
    cond, chosen = self._broadcasted(input_)
    cond, fallback = cond._broadcasted(other)
    return mlops.Where.apply(cond, *chosen._broadcasted(fallback))
# ***** op wrappers (wasted lines to make the typechecker happy) *****
def __neg__(self) -> Tensor:
    """``-self``; delegates to ``neg``."""
    return self.neg()
def __add__(self, x) -> Tensor:
    """``self + x``; delegates to ``add``."""
    return self.add(x)
def __sub__(self, x) -> Tensor:
    """``self - x``; delegates to ``sub``."""
    return self.sub(x)
def __mul__(self, x) -> Tensor:
    """``self * x``; delegates to ``mul``."""
    return self.mul(x)
def __pow__(self, x) -> Tensor:
    """``self ** x``; delegates to ``pow``."""
    return self.pow(x)
def __truediv__(self, x) -> Tensor:
    """``self / x``; delegates to ``div``."""
    return self.div(x)
def __matmul__(self, x) -> Tensor:
    """``self @ x``; delegates to ``matmul``."""
    return self.matmul(x)
def __radd__(self, x) -> Tensor:
    """Reflected addition ``x + self``."""
    return self.add(x, reverse=True)
def __rsub__(self, x) -> Tensor:
    """Reflected subtraction ``x - self``."""
    return self.sub(x, reverse=True)
def __rmul__(self, x) -> Tensor:
    """Reflected multiplication ``x * self``."""
    return self.mul(x, reverse=True)
def __rpow__(self, x) -> Tensor:
    """Reflected power ``x ** self``."""
    return self.pow(x, reverse=True)
def __rtruediv__(self, x) -> Tensor:
    """Reflected division ``x / self``."""
    return self.div(x, reverse=True)
def __rmatmul__(self, x) -> Tensor:
    """Reflected matrix product ``x @ self``."""
    return self.matmul(x, reverse=True)
def __iadd__(self, x) -> Tensor:
    """``self += x``: computes ``self.add(x)`` and returns ``self.assign`` of it."""
    return self.assign(self.add(x))
def __isub__(self, x) -> Tensor:
    """``self -= x``: computes ``self.sub(x)`` and returns ``self.assign`` of it."""
    return self.assign(self.sub(x))
def __imul__(self, x) -> Tensor:
    """``self *= x``: computes ``self.mul(x)`` and returns ``self.assign`` of it."""
    return self.assign(self.mul(x))
def __ipow__(self, x) -> Tensor:
    """``self **= x``: computes ``self.pow(x)`` and returns ``self.assign`` of it."""
    return self.assign(self.pow(x))
def __itruediv__(self, x) -> Tensor:
    """``self /= x``: computes ``self.div(x)`` and returns ``self.assign`` of it."""
    return self.assign(self.div(x))
def __imatmul__(self, x) -> Tensor:
    """``self @= x``: computes ``self.matmul(x)`` and returns ``self.assign`` of it."""
    return self.assign(self.matmul(x))
def __lt__(self, x) -> Tensor:
    """Elementwise ``self < x`` via ``mlops.Less`` on the broadcast operands."""
    lhs, rhs = self._broadcasted(x, reverse=False)
    return mlops.Less.apply(lhs, rhs)
def __gt__(self, x) -> Tensor:
    """Elementwise ``self > x``, evaluated as ``x < self`` by swapping the broadcast operands."""
    lhs, rhs = self._broadcasted(x, reverse=True)
    return mlops.Less.apply(lhs, rhs)
def __ge__(self, x) -> Tensor:
    """Elementwise ``self >= x``, computed as ``1 - (self < x)``."""
    below = self < x
    return 1.0 - below
def __le__(self, x) -> Tensor:
    """Elementwise ``self <= x``, computed as ``1 - (self > x)``."""
    above = self > x
    return 1.0 - above
def __ne__(self, x) -> Tensor:
    """Elementwise ``self != x``, computed as ``(self < x) + (self > x)``."""
    below, above = self < x, self > x
    return below + above  # type: ignore[override]
def __eq__(self, x) -> Tensor:
    """Elementwise ``self == x``, computed as ``1 - (self != x)``."""
    differs = self != x
    return 1.0 - differs  # type: ignore[override]
# ***** functional nn ops *****
def linear(self, weight: Tensor, bias: Optional[Tensor] = None):
    """
    Affine map ``self @ weight + bias``.

    A 1D ``weight`` is applied as an elementwise (broadcast) multiply instead of a matrix
    product. ``bias`` is added only when it is not None.
    """
    if len(weight.shape) == 1:
        out = self.mul(weight)
    else:
        out = self.dot(weight)
    if bias is None:
        return out
    return out.add(bias)
def sequential(self, ll: List[Callable[[Tensor], Tensor]]):
  """Pipe ``self`` through every callable in ``ll``, in order, and return the final output."""
  out = self
  for layer in ll:
    out = layer(out)
  return out
def layernorm(self, axis=-1, eps: float = 1e-5) -> Tensor:
  """Normalize ``self`` to zero mean and unit variance along ``axis``.

  No learnable scale/shift is applied. ``eps`` is added to the (biased) variance
  before taking the reciprocal square root.
  """
  centered = self - self.mean(axis, keepdim=True)
  variance = (centered * centered).mean(axis, keepdim=True)
  return centered.mul(variance.add(eps).rsqrt())
def batchnorm(
  self,
  weight: Optional[Tensor],
  bias: Optional[Tensor],
  mean: Tensor,
  invstd: Tensor,
) -> Tensor:
  """Batch-normalize ``self`` with precomputed statistics.

  Computes ``(self - mean) * weight * invstd + bias``. ``mean``, ``weight`` and
  ``bias`` are reshaped to ``[1, -1, 1, 1]`` so they broadcast along axis 1
  (assumes a 4-D, channels-at-axis-1 input -- confirm for the caller).
  ``invstd`` gets the same reshape only when it is 1-D; otherwise it is used as-is.

  Args:
    weight: optional per-channel scale; skipped when None.
    bias: optional per-channel shift; skipped when None.
    mean: per-channel mean to subtract.
    invstd: inverse standard deviation (1-D per-channel, or already broadcastable).
  """
  x = self - mean.reshape(shape=[1, -1, 1, 1])
  # weight/bias are Optional[Tensor]: test for None explicitly rather than relying
  # on Tensor truthiness, which is not a "was this argument given" check.
  if weight is not None:
    x = x * weight.reshape(shape=[1, -1, 1, 1])
  ret = x.mul(
    invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd
  )
  return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias is not None else ret
def dropout(self, p=0.5) -> Tensor:
  """Inverted dropout: zero each element with probability ``p`` while training.

  Surviving elements are scaled by ``1 / (1 - p)`` so the expected value is
  unchanged. Returns ``self`` untouched when ``Tensor.training`` is false or
  ``p == 0``; with ``p == 1`` every element is dropped.

  Raises:
    ValueError: if ``p`` is not within ``[0, 1]``.
  """
  if not 0 <= p <= 1:
    raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
  if not Tensor.training or p == 0:
    return self
  if p == 1:
    # Everything is dropped; the 1 / (1 - p) rescale below would divide by zero.
    return self * 0
  mask = (
    Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p
  ).cast(dtypes.bool)
  return self * mask * (1 / (1.0 - p))
def scaled_dot_product_attention(
  self,
  key: Tensor,
  value: Tensor,
  attn_mask: Optional[Tensor] = None,
  dropout_p: float = 0.0,
  is_causal: bool = False,
) -> Tensor:
  """Attention with ``self`` as the query.

  Computes ``softmax(self @ key^T / sqrt(d) + mask) @ value`` where ``d`` is
  ``self.shape[-1]`` and the last two axes of ``key`` are transposed.

  Args:
    key: key tensor.
    value: value tensor weighted by the attention probabilities.
    attn_mask: optional mask added to the scaled scores. A bool mask is first
      converted to an additive one (False -> ``-inf``, True -> ``0``). When no
      mask is given and ``is_causal`` is false, the scores are used unmasked.
    dropout_p: dropout probability applied to the attention weights
      (``dropout`` only acts while ``Tensor.training`` is set).
    is_causal: build a lower-triangular (``tril(0)``) bool mask of shape
      ``(self.shape[-2], key.shape[-2])``. NOTE: this replaces any ``attn_mask``
      that was passed in.
  """
  # NOTE: it works if key, value have symbolic shape
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  if is_causal:
    attn_mask = (
      Tensor.ones(
        self.shape[-2],
        key.shape[-2],
        requires_grad=False,
        device=self.device,
      )
      .tril(0)
      .cast(dtypes.bool)
    )
  if attn_mask is not None and attn_mask.dtype == dtypes.bool:
    attn_mask = (attn_mask == 0).where(-float("inf"), 0)
  scores = self @ key.transpose(-2, -1) / math.sqrt(self.shape[-1])
  # The unmasked case must skip the addition instead of adding None to a Tensor.
  if attn_mask is not None:
    scores = scores + attn_mask
  return scores.softmax(-1).dropout(dropout_p) @ value
def binary_crossentropy(self, y: Tensor) -> Tensor:
  """Mean binary cross-entropy between probabilities ``self`` and targets ``y``."""
  log_p = self.log()
  log_not_p = (1 - self).log()
  return (-y * log_p - (1 - y) * log_not_p).mean()
def binary_crossentropy_logits(self, y: Tensor) -> Tensor:
  """Mean binary cross-entropy computed directly from logits ``self``.

  Uses the numerically stable form ``max(x, 0) - x * y + log(1 + exp(-|x|))``.
  """
  log_term = (1 + (-self.abs()).exp()).log()
  return (self.maximum(0) - y * self + log_term).mean()
def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
  """Mean cross-entropy of logits ``self`` against integer class labels ``Y``.

  The class count is ``self.shape[-1]``. Positions whose label equals
  ``ignore_index`` contribute zero loss and are excluded from the denominator.
  """
  # NOTE: self is a logits input
  num_classes = self.shape[-1]
  loss_mask = Y != ignore_index
  class_ids = Tensor.arange(
    num_classes,
    dtype=dtypes.int32,
    requires_grad=False,
    device=self.device,
  )
  class_ids = class_ids.unsqueeze(0).expand(Y.numel(), num_classes)
  labels = Y.flatten().reshape(-1, 1)
  # Negated one-hot rows (zeroed where ignored): multiplying by log-probs and
  # summing yields the total negative log-likelihood of the labels.
  neg_one_hot = (class_ids == labels).where(-1.0, 0) * loss_mask.reshape(-1, 1)
  y = neg_one_hot.reshape(*Y.shape, num_classes)
  return self.log_softmax().mul(y).sum() / loss_mask.sum()
# ***** cast ops *****
def cast(self, dtype: DType) -> Tensor:
  """Return ``self`` converted to ``dtype``; ``self`` itself when already that dtype."""
  if self.dtype != dtype:
    return mlops.Cast.apply(self, dtype=dtype)
  return self
def bitcast(self, dtype: DType) -> Tensor:
  """Reinterpret the bits of ``self`` as ``dtype`` (no value conversion).

  Both dtypes must have the same itemsize; ``self`` is returned unchanged when
  it already has ``dtype``.
  """
  assert (
    self.dtype.itemsize == dtype.itemsize
  ), "can't bitcast mismatched dtype itemsizes"
  if self.dtype != dtype:
    return mlops.Cast.apply(self, dtype=dtype, bitcast=True)
  return self
def float(self) -> Tensor:
  """Cast to ``dtypes.float32`` via ``cast`` (returns ``self`` if already float32)."""
  return self.cast(dtypes.float32)
def half(self) -> Tensor:
  """Cast to ``dtypes.float16`` via ``cast`` (returns ``self`` if already float16)."""
  return self.cast(dtypes.float16)
# ***** convenience stuff *****
@property
def ndim(self) -> int:
  """Number of dimensions, i.e. ``len(self.shape)``."""
  return len(self.shape)
def numel(self) -> sint:
  """Total element count: the product of ``self.shape`` (may be symbolic, hence ``sint``)."""
  return prod(self.shape)
def element_size(self) -> int:
  """Size of a single element, taken from ``self.dtype.itemsize``."""
  return self.dtype.itemsize
def nbytes(self) -> int:
  """Total storage size: ``numel() * element_size()``."""
  # NOTE(review): numel() is typed sint, so for symbolic shapes this may not be a plain int.
  return self.numel() * self.element_size()
def is_floating_point(self) -> bool:
  """True when ``self.dtype`` is a floating-point dtype according to ``dtypes.is_float``."""
  return dtypes.is_float(self.dtype)
# Attach one convenience method per backend listed in Device._buffers, named after
# the lower-cased device name; each is a partialmethod forwarding to Tensor.to(device).
for device in Device._buffers:
  setattr(Tensor, device.lower(), partialmethod(Tensor.to, device))
if IMAGE:
  # With IMAGE>0, monkeypatch the image-backend conv2d/dot onto Tensor (hack!).
  from tinygrad.features.image import image_conv2d, image_dot
  setattr(Tensor, "conv2d", image_conv2d)
  setattr(Tensor, "dot", image_dot)
# TODO: remove the custom op and replace with threefry
def custom_random(out: Buffer):
  """Fill ``out`` with uniform samples in [0, 1) generated by NumPy.

  Every call increments ``Tensor._seed`` and seeds a fresh
  ``np.random.default_rng`` with the new value, so the output is determined by
  the seed. Samples are drawn as float32, converted to ``out.dtype.np`` and
  copied into the buffer.
  """
  Tensor._seed += 1
  if DEBUG >= 2:
    print(
      f"*** rand {out.device} seed {Tensor._seed} size {out.size:<16d} dtype {out.dtype}"
    )
  generator = np.random.default_rng(Tensor._seed)
  samples = generator.random(size=out.size, dtype=np.float32)
  # NOTE(review): assumes out.dtype.np is a concrete numpy dtype; if it were None,
  # astype would produce float64 and the byte size would not match the buffer.
  out.copyin(samples.astype(dtype=out.dtype.np, copy=False).data)