1
0
Fork 0
tinygrab/tinygrad/ops.py

596 lines
17 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Dict, Callable, Mapping
import functools
from enum import Enum, auto
from tinygrad.helpers import prod, DType, dedup
from tinygrad.shape.symbolic import Variable
from dataclasses import dataclass
class UnaryOps(Enum):
    """Unary (single-operand) operations an accelerator must implement, along with toCpu.

    Members:
        EXP2: 2 raised to the power of the input.
        LOG2: Base-2 logarithm of the input.
        CAST: Type conversion, e.g. float to int. The flop counter treats it as
            free and takes the target dtype from ``arg[0]``.
        SIN: Sine of the input.
        SQRT: Square root of the input.
        RECIP: Reciprocal of the input (1 / x, e.g. 2 -> 0.5). It only has to be
            implemented for scalars, not vectors.
        NEG: Negation of the input (e.g. 5 -> -5).

    Note:
        rdna3 only has RECIP and not DIV or POW; DIV and POW are on the chopping block.
    """

    EXP2 = auto()
    LOG2 = auto()
    CAST = auto()
    SIN = auto()
    SQRT = auto()
    RECIP = auto()
    NEG = auto()
class BinaryOps(Enum):
    """Binary (two-operand) operations.

    Members:
        ADD: Addition.
        SUB: Subtraction.
        MUL: Multiplication.
        DIV: Division.
        MAX: Maximum of the two operands.
        MOD: Modulo.
        CMPLT: Less-than comparison.
    """

    ADD = auto()
    SUB = auto()
    MUL = auto()
    DIV = auto()
    MAX = auto()
    MOD = auto()
    CMPLT = auto()
class TernaryOps(Enum):
    """Three-operand operations.

    Members:
        MULACC: Multiply-accumulate (presumably ``a * b + c`` -- TODO confirm operand order).
        WHERE: Element-wise select. The flop counter takes the result dtype from the
            second and third operands, so the first is presumably the condition.
    """

    MULACC = auto()
    WHERE = auto()
class ReduceOps(Enum):
    """Reduction operations.

    The op's ``arg`` is the output (reduced) shape; the flop counter charges one
    flop per *input* element.

    Members:
        SUM: Sum reduction.
        MAX: Maximum reduction.
    """

    SUM = auto()
    MAX = auto()
class BufferOps(Enum):
    """Operations that move data between a kernel and memory.

    Members:
        LOAD: Reads from the memory buffer described by its ``arg`` (``idx``,
            ``dtype``, ``st``; presumably a MemBuffer). Costs no flops; the bytes
            read from buffer ``arg.idx`` are recorded.
        CONST: Produces a constant value described by its ``arg`` (presumably a
            ConstBuffer). Costs no flops and touches no buffer memory.
        STORE: Writes the computed source value into the memory buffer
            ``arg.idx``. Adds the bytes written to the source's memory tally.
    """

    LOAD = auto()
    CONST = auto()
    STORE = auto()
# Ops below this line are not allowed in ASTs
class MovementOps(Enum):
    """Operations that change how a tensor's elements are arranged.

    Per the comment above this class, these ops are not allowed in ASTs.
    ``LazyOp.replace_with_movement_ops`` receives them as ``(op, args)`` pairs.

    Members:
        RESHAPE: Reshape.
        PERMUTE: Permute (reorder) axes.
        EXPAND: Expand.
        PAD: Pad.
        SHRINK: Shrink.
        STRIDE: Stride.
        AS_STRIDED: As-strided view.
    """

    RESHAPE = auto()
    PERMUTE = auto()
    EXPAND = auto()
    PAD = auto()
    SHRINK = auto()
    STRIDE = auto()
    AS_STRIDED = auto()
class LoadOps(Enum):
    """Operations that create or materialize buffers.

    Per the comment above MovementOps, these ops are not allowed in ASTs.
    The per-member descriptions below are inferred from the names -- TODO confirm.

    Members:
        EMPTY: An empty load (presumably a buffer allocated without initial contents).
        CONST: A load of a known, fixed constant value.
        FROM: A load from another source (presumably another buffer or device).
        CONTIGUOUS: A load that makes the data contiguous in memory.
        CUSTOM: A load whose behavior is user-defined.
    """

    EMPTY = auto()
    CONST = auto()
    FROM = auto()
    CONTIGUOUS = auto()
    CUSTOM = auto()
# Op: any single operation, i.e. a member of one of the operation enums.
# A LazyOp stores one of these in its `op` field, and InterpretedFlopCounter is keyed by it.
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
# OpType: the operation enum classes themselves (e.g. `BinaryOps`), as opposed to their members.
OpType = Union[
    Type[UnaryOps],
    Type[BinaryOps],
    Type[ReduceOps],
    Type[MovementOps],
    Type[LoadOps],
    Type[TernaryOps],
    Type[BufferOps],
]
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.lazy import LazyBuffer
@dataclass(frozen=True)
class MemBuffer:
    """Immutable description of a memory buffer used by a kernel.

    Presumably the ``arg`` of ``BufferOps.LOAD`` / ``BufferOps.STORE``. The flop
    counter reads ``idx``, ``dtype`` and ``st`` from such args and estimates the
    buffer's memory as ``dtype.itemsize * st.size()``.

    Attributes:
        idx: Index identifying the buffer.
        dtype: Element datatype of the buffer.
        st: ShapeTracker describing the buffer's shape/view.
    """

    idx: int
    dtype: DType
    st: ShapeTracker
@dataclass(frozen=True)
class ConstBuffer:
    """Immutable description of a constant value broadcast over a shape.

    Presumably the ``arg`` of ``BufferOps.CONST``. The flop counter reads only
    ``st`` and ``dtype`` from such args and records no memory for them.

    Attributes:
        val: The constant value.
        dtype: Datatype of the constant.
        st: ShapeTracker describing the shape/view the constant is presented with.
    """

    val: Union[int, float]
    dtype: DType
    st: ShapeTracker
@dataclass(frozen=True)
class ScheduleItem:
    """One unit of scheduled work: an AST together with its buffers.

    Attributes:
        ast: The LazyOp tree to execute.
        out: The buffer the result is written to.
        inputs: The buffers the AST reads from.
        var_vals: Values bound to each symbolic Variable for this item.

    NOTE(review): ``var_vals`` is a dict, so although the dataclass is frozen,
    hashing an instance raises TypeError. Do not use it as a set member or dict key.
    """

    ast: LazyOp
    out: LazyBuffer
    inputs: Tuple[LazyBuffer, ...]
    var_vals: Dict[Variable, int]
@dataclass(frozen=True)
class LazyOp:
    """A node in a lazily-evaluated operation tree.

    Each node applies ``op`` to the operands in ``src``. An operand is either
    another ``LazyOp`` or a ``LazyBuffer`` (presumably the leaves of the tree).
    Nodes are immutable (frozen) and hash structurally on ``(op, src, arg)``.

    Attributes:
        op: The operation performed by this node.
        src: The operands, in order.
        arg: Optional extra argument for ``op``; ``None`` when unused.
    """

    op: Op
    src: Tuple[Union[LazyOp, LazyBuffer], ...]
    arg: Any = None

    def __repr__(self):
        return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"

    @functools.cached_property
    def buffers(self) -> Tuple[LazyBuffer, ...]:
        """Return the buffers reachable from this node, de-duplicated via ``dedup``.

        Each operand contributes its own ``buffers``. The result is computed once
        per instance and cached.
        """
        # A flat comprehension instead of ``sum(..., ())``: summing tuples re-copies
        # the growing tuple on every addition, which is quadratic in the buffer count.
        return tuple(dedup([buf for x in self.src for buf in x.buffers]))

    @functools.cached_property
    def hash(self):
        """Structural hash of ``(op, src, arg)``, computed once per instance."""
        # Hashing ``src`` hashes every child node; caching avoids re-walking the
        # subtree each time this node is used as a dict/set key.
        return hash((self.op, self.src, self.arg))

    def __hash__(self):
        """Return the cached structural hash from the ``hash`` property."""
        return self.hash

    def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]) -> LazyOp:
        """Return a copy of this tree with operands substituted via ``real_srcs``.

        An operand that is a key of ``real_srcs`` is replaced by its mapped value.
        Any other operand has ``map_buffers(real_srcs)`` called on it recursively.
        ``op`` and ``arg`` are carried over unchanged.

        Args:
            real_srcs: Mapping from operands to their replacements.

        Returns:
            A new LazyOp with the substituted operands.
        """
        return LazyOp(
            self.op,
            tuple(real_srcs[y] if y in real_srcs else y.map_buffers(real_srcs) for y in self.src),
            self.arg,
        )

    def get_lazyops(self) -> List[LazyOp]:
        """Return this node followed by every operand's ``get_lazyops()``, depth-first.

        Shared subtrees appear once per occurrence; no de-duplication is done.
        """
        return [self] + [item for x in self.src for item in x.get_lazyops()]

    def replace_with_movement_ops(self, ops: List[Tuple[MovementOps, Tuple[Any, ...]]]) -> "LazyBuffer":
        """Rebuild this elementwise tree on top of movement-transformed operands.

        Each operand is asked to ``replace_with_movement_ops(ops)``. The results are
        then recombined by calling ``e(self.op, *rest, arg=self.arg)`` on the first one.

        Args:
            ops: ``(movement_op, args)`` pairs to apply.

        Returns:
            The buffer produced by the first operand's ``e`` call.

        Raises:
            AssertionError: if ``op`` is not a Unary, Binary or Ternary op.
        """
        assert isinstance(self.op, (UnaryOps, BinaryOps, TernaryOps))
        srcs = [z.replace_with_movement_ops(ops) for z in self.src]
        return srcs[0].e(self.op, *srcs[1:], arg=self.arg)

    # ---- LazyBuffer-style members that LazyOp does not support ----
    # NOTE(review): these presumably exist so LazyOp exposes the same attribute
    # names as LazyBuffer (operands are typed Union[LazyOp, LazyBuffer]).
    # Every one of them raises NotImplementedError unconditionally.

    @property
    def st(self):
        """Unsupported on LazyOp; always raises NotImplementedError."""
        raise NotImplementedError

    @property
    def realized(self):
        """Unsupported on LazyOp; always raises NotImplementedError."""
        raise NotImplementedError

    @property
    def children(self):
        """Unsupported on LazyOp; always raises NotImplementedError."""
        raise NotImplementedError

    # movement ops
    def reshape(self, _):
        """Unsupported on LazyOp; always raises NotImplementedError."""
        raise NotImplementedError

    def pad(self, _):
        """Unsupported on LazyOp; always raises NotImplementedError."""
        raise NotImplementedError

    def expand(self, _):
        """Unsupported on LazyOp; always raises NotImplementedError."""
        raise NotImplementedError

    def permute(self, _):
        """Unsupported on LazyOp; always raises NotImplementedError."""
        raise NotImplementedError

    def shrink(self, _):
        """Unsupported on LazyOp; always raises NotImplementedError."""
        raise NotImplementedError

    def stride(self, _):
        """Unsupported on LazyOp; always raises NotImplementedError."""
        raise NotImplementedError
# **************** independent FlopCounter ****************
@dataclass
class FlopCounter:
    """Running tally of the work performed by (a subtree of) an AST.

    Attributes:
        shape: Shape of the value produced so far.
        dtype: Datatype of the value produced so far.
        flops: Operations counted here that no parent has consumed yet.
        mem: Memory footprint per buffer, keyed by buffer index
            (recorded as ``dtype.itemsize * st.size()`` by the LOAD/STORE counters).
    """

    shape: Tuple[int, ...]
    dtype: DType
    flops: int
    mem: Dict[int, int]

    @property
    def mem_estimate(self):
        """Total memory footprint across every buffer recorded in ``mem``."""
        per_buffer = self.mem.values()
        return sum(per_buffer)

    def consume_flops(self):
        """Hand the pending flop count to the caller and reset it to zero.

        A parent node calls this on each operand, so an operand's flops are
        added into only one parent's total.

        Returns:
            int: The flop count pending before the reset.
        """
        pending = self.flops
        self.flops = 0
        return pending
# Handlers are called positionally by get_lazyop_info: first the FlopCounters of the
# op's sources (the first one is named `self`), then the op's `arg` when it is not None.


def _flops_load(arg):
    # Reading a buffer costs no flops; record the bytes of buffer `arg.idx`.
    return FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize * arg.st.size()})


def _flops_const(arg):
    # A constant costs no flops and touches no buffer memory.
    return FlopCounter(arg.st.shape, arg.dtype, 0, {})


def _flops_store(self, arg):
    # Storing adds no flops of its own: it forwards the source's pending flops
    # and adds the bytes written to buffer `arg.idx`.
    pending = self.consume_flops()
    return FlopCounter(arg.st.shape, arg.dtype, pending, {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.size()})


def _flops_cast(self, arg):
    # Casting is counted as free; arg[0] is the new dtype.
    return FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem)


def _flops_unary(self):
    # One flop per output element.
    return FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem)


def _flops_binary(self, y):
    # One flop per output element; the result takes the larger operand dtype.
    total = self.consume_flops() + y.consume_flops() + prod(self.shape)
    return FlopCounter(self.shape, max(self.dtype, y.dtype), total, {**self.mem, **y.mem})


def _flops_reduce(self, new_shape):
    # One flop per *input* element; the result takes the reduced shape.
    return FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem)


def _flops_where(self, y, z):
    # One flop per element; the dtype comes from the two value operands, not `self`.
    total = self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape)
    return FlopCounter(self.shape, max(y.dtype, z.dtype), total, {**self.mem, **y.mem, **z.mem})


InterpretedFlopCounter: Dict[Op, Callable] = {
    BufferOps.LOAD: _flops_load,
    BufferOps.CONST: _flops_const,
    BufferOps.STORE: _flops_store,
    UnaryOps.CAST: _flops_cast,
    **dict.fromkeys([op for op in UnaryOps if op != UnaryOps.CAST], _flops_unary),
    **dict.fromkeys(BinaryOps, _flops_binary),
    **dict.fromkeys(ReduceOps, _flops_reduce),
    TernaryOps.WHERE: _flops_where,
}
@functools.lru_cache(None)
def get_lazyop_info(ast: LazyOp) -> FlopCounter:
    """Interpret ``ast`` with InterpretedFlopCounter to estimate its flops and memory.

    Args:
        ast: Root of the operation tree to analyse.

    Returns:
        FlopCounter: The root's shape and dtype, its total flop count, and the
        per-buffer memory it touches.

    NOTE(review): results are memoized per ``ast`` (unbounded cache), and the same
    mutable FlopCounter is returned on every call. A caller that invokes
    ``consume_flops()`` on it changes what later callers see -- confirm none do.
    """

    # A fresh cache on every top-level call. When a subtree is shared inside this AST,
    # the cached FlopCounter is handed out again. Its flops were already consumed
    # (reset to 0) by the first parent, so they are not double-counted.
    @functools.lru_cache(None)
    def evaluate(node):
        operands = [evaluate(child) for child in node.src]
        if node.arg is not None:
            operands.append(node.arg)
        return InterpretedFlopCounter[node.op](*operands)

    return evaluate(ast)