1
0
Fork 0

docstrings for tinygrad lazy

deepcrayon
Jeff Moe 2023-12-05 14:53:21 -07:00
parent 6dcc07a557
commit d7736d19ea
3 changed files with 527 additions and 2 deletions

View File

@ -0,0 +1,8 @@
tinygrad lazy
-------------
.. automodule:: tinygrad.lazy
:members:
:undoc-members:
:show-inheritance:

View File

@ -9,6 +9,7 @@ tinygrad
tinygrad-graph
tinygrad-helpers
tinygrad-jit
tinygrad-lazy
tinygrad-mlops
tinygrad-ops
tinygrad-realize

View File

@ -1,3 +1,16 @@
"""
**sys**: The sys module provides access to some variables used or maintained by the interpreter and to functions that interact strongly with the interpreter.
**math**: The math module supplies some mathematical functions and constants.
**typing**: This module supplies several helper classes and functions for type hints.
**weakref**: The weakref module is useful for creating weak references to objects, which can be useful for creating caches where the garbage collector is allowed to delete the cached values if memory is needed elsewhere.
**numpy**: NumPy is a Python library that stands for 'Numerical Python'. It is a library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays.
**tinygrad.helpers**: This module contains helper functions used throughout TinyGrad.
"""
from __future__ import annotations
import sys, math
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping, Set
@ -35,7 +48,21 @@ from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.device import Buffer
# lazy can recurse a lot
"""
**sys.setrecursionlimit(10000)**: Increases the recursion limit in Python to 10,000. This allows for deeper recursion in the program.
**OPT = getenv("OPT", 2)**: Gets the value of the environment variable "OPT". If it doesn't exist, sets OPT to 2.
**LAZYCACHE = getenv("LAZYCACHE", 1)**: Gets the value of the environment variable "LAZYCACHE". If it doesn't exist, sets LAZYCACHE to 1.
**(REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS) = (OPT >= 1, OPT >= 1, OPT >= 1, OPT >= 1)**: Sets boolean values for various optimizations based on the value of OPT. If OPT is 1 or greater, sets these to True; otherwise, False.
**MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT >= 2, OPT >= 2**: Sets boolean values for more specific optimizations based on the value of OPT. If OPT is 2 or greater, sets these to True; otherwise, False.
**PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT >= 3, OPT >= 3**: Sets boolean values for further optimizations based on the value of OPT. If OPT is 3 or greater, sets these to True; otherwise, False.
**PUSH_RESHAPES = OPT >= 4**: Sets a boolean value for an additional optimization based on the value of OPT. If OPT is 4 or greater, sets this to True; otherwise, False.
"""
sys.setrecursionlimit(10000)
OPT = getenv("OPT", 2)
@ -56,7 +83,17 @@ PUSH_RESHAPES = OPT >= 4
def _ast_reduceops(op: LazyOp) -> LazyOp:
# TODO: this can also corealize a binary op after the reduce, not just before
"""
Reduce operations in Abstract Syntax Trees (AST).
This function is designed to optimize the AST by reducing operations. It can also realize a binary op after the reduce, not just before. The function takes a single argument `op` of type `LazyOp`.
Attributes:
src (LazyOp): The source from which to retrieve the operation.
Returns:
LazyOp: The optimized `LazyOp` object after performing the reduction operation.
"""
src = op.src[0]
if not src.realized:
assert isinstance(
@ -73,6 +110,16 @@ def _ast_reduceops(op: LazyOp) -> LazyOp:
# this supports late merging an upstream Reduce op and even an Elementwise op above that
def _ast_binaryops(op: LazyOp, shape: Tuple[sint, ...]) -> LazyOp:
"""
This function supports late merging an upstream Reduce op and even an Elementwise op above that.
Attributes:
op (LazyOp): The input lazy operation.
shape (Tuple[sint, ...]): The output shape of the operation.
Returns:
LazyOp: The transformed lazy operation with the late merging applied.
"""
real_srcs: Dict[LazyBuffer, Optional[Union[LazyOp, LazyBuffer]]] = {
x: None for x in op.buffers
}
@ -117,6 +164,22 @@ def _ast_binaryops(op: LazyOp, shape: Tuple[sint, ...]) -> LazyOp:
def _replace_bufferops(op: LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
"""
Replace buffer operations in a lazy op with new ones.
This function takes a lazy operation (op) as input and replaces its buffer operations
with new ones based on certain conditions. It returns the updated lazy operation and a list of base buffers.
Args:
op (LazyOp): The input lazy operation.
Returns:
Tuple[LazyOp, List[LazyBuffer]]: A tuple containing the updated lazy operation and
a list of base buffers.
Raises:
NotImplementedError: If a certain buffer is not handled by the function.
"""
replacements: Dict[LazyBuffer, LazyOp] = {}
base_bufs = dedup([x.base for x in op.buffers if not x.is_unrealized_const()])
for x in op.buffers:
@ -140,6 +203,23 @@ def _replace_bufferops(op: LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
def get_movementroot(root: LazyBuffer, allow_contiguous=False) -> LazyBuffer:
"""
Recursively retrieve the root of a movement operation or contiguous data.
This function is used to locate the origin of a series of operations by
traversing back through the chain of operations until it reaches the original
source buffer. It will continue this process as long as the current root is
not realized and its operation type matches specific criteria.
:param root: The current root node in the operation tree.
:type root: LazyBuffer
:param allow_contiguous: A flag indicating whether to include operations
that result in contiguous data, defaults to False.
:type allow_contiguous: bool, optional
:return: The original root node of the operation tree or the current root if
it is realized or does not meet the specified criteria.
:rtype: LazyBuffer
"""
return (
get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous)
if not root.realized
@ -156,6 +236,19 @@ def get_movementroot(root: LazyBuffer, allow_contiguous=False) -> LazyBuffer:
def get_movementroot_contiguous(x: LazyBuffer) -> LazyBuffer:
"""
Recursively obtain the root of movement for a contiguous operation.
This function is used to identify and return the root of a series of operations
that lead up to a contiguous operation on a lazy buffer. It does this by checking
if the current operation is a contiguous operation, and if it is not, recursively
calling itself on the source operation until it finds the root of the contiguous operation.
:param x: The input lazy buffer to check for the root of its movement.
:type x: LazyBuffer
:return: The root of the movement for the contiguous operation.
:rtype: LazyBuffer
"""
return (
get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0]))
if not x.realized and x.op.op == LoadOps.CONTIGUOUS
@ -169,6 +262,19 @@ def get_movementroot_contiguous(x: LazyBuffer) -> LazyBuffer:
# NOTE: this is the canonical order
def vars_from_ast(ast: LazyOp) -> List[Variable]:
"""
Retrieve variables from abstract syntax tree (AST).
This function extracts all unique variables from the AST by leveraging
the `LazyOp` operations that belong to `BufferOps`. The resulting set of
variables is then sorted based on their expression string representation.
Attributes:
ast (LazyOp): Abstract syntax tree object to extract variables from.
Returns:
List[Variable]: Sorted list of unique variables extracted from AST.
"""
return sorted(
set.union(
*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()
@ -188,6 +294,21 @@ def create_lazybuffer(
dtype: DType,
base: Optional[LazyBuffer] = None,
):
"""
Create a lazy buffer for the given device, shape tracker, operation type, operation, data type, and optional base.
Args:
device (str): The device to create the lazy buffer on.
st (ShapeTracker): The shape tracker to use.
optype (OpType): The operation type.
op (LazyOp): The lazy operation.
dtype (DType): The data type of the lazy buffer.
base (Optional[LazyBuffer]): The optional base for the lazy buffer. Default is None.
Returns:
LazyBuffer: The created lazy buffer.
"""
# rewrite 0 size into a CONST
if 0 in st.shape:
return LazyBuffer(
@ -216,6 +337,23 @@ def create_lazybuffer(
class LazyBuffer:
"""
LazyBuffer class for lazy operations on buffers.
Attributes:
__deletable__ (tuple): Tuple containing the attribute 'op' which can be deleted.
device (str): Device where this buffer is located.
st (ShapeTracker): Shape tracker object.
optype (OpType): Operation type.
op (Optional[LazyOp]): Lazy operation object.
dtype (DType): Data type of the buffer.
src (Optional[Buffer]): Source buffer, if any. Default is None.
base (Optional[LazyBuffer]): Base lazy buffer, if any. Default is None.
output_buffer (Optional[Buffer]): Output buffer. Default is None.
children (WeakSet[LazyBuffer]): Weak set of child lazy buffers.
views (WeakSet[LazyBuffer]): Weak set of view lazy buffers.
"""
__deletable__ = ("op",)
def __init__(
@ -228,6 +366,18 @@ class LazyBuffer:
src: Optional[Buffer] = None,
base: Optional[LazyBuffer] = None,
):
"""
Initializes a new instance of the LazyBuffer class.
Args:
device (str): Device where this buffer is located.
st (ShapeTracker): Shape tracker object.
optype (OpType): Operation type.
op (Optional[LazyOp]): Lazy operation object.
dtype (DType): Data type of the buffer.
src (Optional[Buffer]): Source buffer, if any. Default is None.
base (Optional[LazyBuffer]): Base lazy buffer, if any. Default is None.
"""
self.device, self.st, self.shape, self.optype, self._dtype, self._realized = (
device,
st,
@ -258,51 +408,151 @@ class LazyBuffer:
@property
def base(self):
"""
Return the base of this LazyBuffer.
Returns:
The base of this LazyBuffer if it is not None, else self.
"""
return self._base if self._base is not None else self
def is_unrealized_const(self):
"""
Check whether this LazyBuffer is an unrealized constant.
Returns:
True if the buffer is unrealized and its base operation is a CONST LoadOp, False otherwise.
"""
return not self.realized and self.base.op.op == LoadOps.CONST
def is_unrealized_contiguous_const(self):
"""
Check whether this LazyBuffer is an unrealized contiguous constant.
Returns:
True if the buffer is both unrealized and contiguous, False otherwise.
"""
return self.is_unrealized_const() and self.st.contiguous
@property
def realized(self):
"""
Return whether this LazyBuffer is realized.
Returns:
True if the buffer is realized, False otherwise.
"""
return self.base._realized
@realized.setter
def realized(self, val: Buffer):
"""
Set the realization of this LazyBuffer.
Args:
val (Buffer): The buffer to set as the realization of this LazyBuffer.
Raises:
AssertionError: If _base is not None when trying to set the realized value.
"""
assert self._base is None, "no setting realized of based LazyBuffers"
self._realized = val
@property
def dtype(self):
"""
Get the data type of this LazyBuffer.
Returns:
The data type of this LazyBuffer.
"""
return self.base._dtype
@dtype.setter
def dtype(self, val: DType):
"""
Set the data type for this LazyBuffer.
Args:
val (DType): The data type to set for this LazyBuffer.
Raises:
AssertionError: If attempting to set the dtype of a based LazyBuffer.
"""
assert self._base is None, "no setting dtype of based LazyBuffers"
self._dtype = val
def __repr__(self):
"""
Get a string representation of this LazyBuffer.
Returns:
A string containing the shape, data type, operation, and storage type of this LazyBuffer.
"""
return f"<LB {self.shape} {self.dtype} op={self.op.op if hasattr(self, 'op') else self._realized} st={self.st}>"
def _device_extra_args(self) -> Dict[str, str]:
"""
Get extra arguments for the device based on its representation.
Returns:
A dictionary containing any extra arguments necessary for the device.
"""
return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}
@property
def buffers(self) -> Tuple[LazyBuffer, ...]:
"""
Return a tuple containing the instance of `LazyBuffer`.
Returns:
Tuple[LazyBuffer]: A tuple containing the instance of `LazyBuffer`.
"""
return (self,)
def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]):
"""
Retrieve the corresponding `LazyBuffer` or `LazyOp` object from the mapping of sources.
Args:
real_srcs (Mapping[Any, Union[LazyBuffer, LazyOp]]): A mapping of objects, where keys are any hashable objects and values are either `LazyBuffer` or `LazyOp`.
Returns:
Union[LazyBuffer, LazyOp]: The corresponding `LazyBuffer` or `LazyOp` object. If the instance is not found in the mapping, it returns itself.
"""
return real_srcs.get(self, self)
def get_lazyops(self) -> List[LazyOp]:
"""
Return an empty list of `LazyOp` objects.
This method is a placeholder and always returns an empty list. Subclasses may override this method to provide specific behavior.
Returns:
List[LazyOp]: An empty list.
"""
return []
# *** scheduling ***
def schedule(self, seen: Optional[Set[LazyBuffer]] = None) -> List[ScheduleItem]:
"""
Schedules the computation of this lazy buffer.
Args:
seen (Optional[Set[LazyBuffer]]): Set of already scheduled buffers. Defaults to None.
Returns:
List[ScheduleItem]: A list of schedule items for this buffer's computations.
Attributes:
seen (Optional[Set[LazyBuffer]]): Set of already scheduled buffers. Defaults to an empty set if not provided.
ret (List[ScheduleItem]): List of schedule items for this buffer's computations.
var_vals (Dict): Merged dictionary of variable values from this buffer and its operand buffers.
op (ASTNode): The abstract syntax tree node representing the operation to be performed on this buffer.
base_bufs (List[Buffer]): List of base buffers for this buffer's computation.
"""
if seen is None:
seen = set()
if self in seen or self.realized or self.is_unrealized_const():
@ -397,6 +647,22 @@ class LazyBuffer:
arg=None,
src: Optional[LazyBuffer] = None,
) -> LazyBuffer:
"""
Load operation factory method.
Creates and returns a new `LazyBuffer` object based on the given parameters. This is a static method and does not require an instance of the class.
Attributes:
op: The operation to be performed.
shape (Tuple[sint, ...]): The shape of the data to be loaded.
dtype: The data type of the data to be loaded.
device (str): The device where the data will be loaded onto.
arg: An optional argument for the operation. Default is None.
src (Optional[LazyBuffer]): An optional source `LazyBuffer` object. Default is None.
Returns:
LazyBuffer: A new `LazyBuffer` object created with the given parameters.
"""
return create_lazybuffer(
device,
ShapeTracker.from_shape(shape),
@ -407,6 +673,12 @@ class LazyBuffer:
# create a constant with the shape and dtype of self
def const(self, val: Union[float, int]) -> LazyBuffer:
"""
Creates a new constant `LazyBuffer` object based on the shape and data type of this instance.
Returns:
LazyBuffer: A new constant `LazyBuffer` object with the same shape and data type as this instance.
"""
# NOTE: dtypes.from_np(self.dtype.np) to deal with image types
return (
LazyBuffer.loadop(
@ -421,6 +693,16 @@ class LazyBuffer:
)
def copy_to_device(self, device: str) -> LazyBuffer:
"""
This method is used to copy a lazy buffer to a specified device. It will check if the buffer is already on the target device and return it directly if true. Otherwise, it will create a new lazy buffer on the target device.
Attributes:
self (LazyBuffer): The source lazy buffer.
device (str): The target device to copy the buffer to.
Returns:
LazyBuffer: The copied lazy buffer on the target device.
"""
# back off a FROM if it's a double FROM
if (
not self.realized
@ -433,6 +715,15 @@ class LazyBuffer:
)
def contiguous(self: LazyBuffer) -> LazyBuffer:
"""
This method is used to ensure a lazy buffer is stored in a contiguous manner. If the source buffer is already contiguous, it will return itself directly. Otherwise, it will create a new contiguous lazy buffer.
Attributes:
self (LazyBuffer): The source lazy buffer.
Returns:
LazyBuffer: A contiguous version of the source lazy buffer.
"""
if not self.realized and self.op.op in LoadOps and self.op.op != LoadOps.CONST:
return self # all LoadOps are already contiguous (except CONST)
if (
@ -456,6 +747,15 @@ class LazyBuffer:
@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
"""
Create a new `LazyBuffer` object from a numpy array.
Attributes:
x (np.ndarray): The numpy array to be used for creating the `LazyBuffer`.
Returns:
LazyBuffer: A new `LazyBuffer` object created from the input numpy array.
"""
return LazyBuffer(
"CPU",
ShapeTracker.from_shape(x.shape),
@ -466,6 +766,17 @@ class LazyBuffer:
)
def cast(self, dtype: DType, bitcast: bool = False):
"""
Cast the elements of this buffer to a new data type.
Attributes:
dtype (DType): The desired data type for the elements of this buffer.
bitcast (bool): Whether to allow a bit-level cast, which is faster but may result in unspecified behavior
if used incorrectly. Defaults to `False`.
Returns:
self: This method modifies the `LazyBuffer` object in-place and returns it for chaining purposes.
"""
return self.e(UnaryOps.CAST, arg=(dtype, bitcast))
# *** elementwise ops ***
@ -476,6 +787,26 @@ class LazyBuffer:
*srcs: LazyBuffer,
arg: Optional[Any] = None,
) -> LazyBuffer:
"""
This method performs an operation on the input buffers and returns a new LazyBuffer.
:param self: The instance of the LazyBuffer class.
:type self: LazyBuffer
:param op: The operation to be performed (UnaryOps, BinaryOps, TernaryOps).
:type op: Union[UnaryOps, BinaryOps, TernaryOps]
:param srcs: The input buffers for the operation.
:type srcs: LazyBuffer
:param arg: An optional argument for certain operations, defaults to None.
:type arg: Optional[Any], optional
:return: A new LazyBuffer with the result of the operation.
:rtype: LazyBuffer
Attributes:
srcs (LazyBuffer): The input buffers for the operation. Includes self.
out_device (str): The output device.
out_shape (Tuple[int, ...]): The output shape.
out_dtype (DType): The output data type.
"""
# srcs includes self
srcs = (self,) + srcs
@ -551,6 +882,19 @@ class LazyBuffer:
def _reduce_op(
self: LazyBuffer, op: ReduceOps, new_shape: Tuple[sint, ...]
) -> LazyBuffer:
"""
Create a new LazyBuffer with reduced dimensions.
This method is used to reduce the dimensions of the current `LazyBuffer` object by applying a reduction operation specified by `op`. The reduced shape is given by `new_shape`. If the current shape and `new_shape` are equal, this method returns the original LazyBuffer.
Attributes:
self (LazyBuffer): The `LazyBuffer` object on which the method is called.
op (ReduceOps): The reduction operation to be applied.
new_shape (Tuple[sint, ...]): The desired shape of the reduced LazyBuffer.
Returns:
LazyBuffer: A new `LazyBuffer` object with reduced dimensions.
"""
if self.shape == tuple(new_shape):
return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
@ -566,6 +910,19 @@ class LazyBuffer:
)
def r(self: LazyBuffer, op: ReduceOps, new_shape: Tuple[sint, ...]) -> LazyBuffer:
"""
Alias for `_reduce_op`.
This method is an alias for `_reduce_op` and provides another way to call it. It takes the same arguments as `_reduce_op` and returns a new `LazyBuffer` with reduced dimensions.
Attributes:
self (LazyBuffer): The `LazyBuffer` object on which the method is called.
op (ReduceOps): The reduction operation to be applied.
new_shape (Tuple[sint, ...]): The desired shape of the reduced LazyBuffer.
Returns:
LazyBuffer: A new `LazyBuffer` object with reduced dimensions.
"""
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
if (
not all_int(self.shape)
@ -597,6 +954,16 @@ class LazyBuffer:
# *** movement ops ***
def reshape(self: LazyBuffer, arg: Tuple[sint, ...]) -> LazyBuffer:
"""
Reshapes the buffer.
Attributes:
self (LazyBuffer): The current lazy buffer.
arg (Tuple[sint, ...]): The new shape for the buffer.
Returns:
LazyBuffer: The reshaped lazy buffer.
"""
if self.shape == arg:
return self
if not self.realized and self.op.op == MovementOps.RESHAPE:
@ -608,6 +975,16 @@ class LazyBuffer:
return self._movement_op(self.st.reshape(arg), MovementOps.RESHAPE, arg)
def pad(self: LazyBuffer, arg: Tuple[Tuple[int, int], ...]) -> LazyBuffer:
"""
Pad the buffer object.
This method pads the current buffer with specified start and end indices. If all padding values are 0, it simply returns the original buffer. If the buffer is not realized and the last operation was a pad operation, the pad operation is combined with the new one. Otherwise, a new movement operation is created with the pad argument.
:param arg: A tuple of tuples containing start and end padding values.
:type arg: Tuple[Tuple[int, int], ...]
:return: The padded buffer object.
:rtype: LazyBuffer
"""
if all(b == 0 and e == 0 for b, e in arg):
return self
if not self.realized and self.op.op == MovementOps.PAD:
@ -619,6 +996,23 @@ class LazyBuffer:
return self._movement_op(self.st.pad(arg), MovementOps.PAD, arg)
def expand(self: LazyBuffer, arg: Tuple[sint, ...]) -> LazyBuffer:
"""
Expand the current LazyBuffer based on the given argument.
This function checks if the shape of the current LazyBuffer is equal to the provided argument.
If so, it returns the LazyBuffer itself. If not, and if the LazyBuffer hasn't been realized yet
and its operation is an expansion operation, it returns the source of the operation with index 0,
also expanded using the provided argument. Otherwise, it creates a new LazyBuffer by calling
the movement_op function with the result of expanding the current state, the MovementOps.EXPAND
operation, and the provided argument.
:param self: The current LazyBuffer object.
:type self: LazyBuffer
:param arg: A tuple containing sint objects representing the new shape of the LazyBuffer.
:type arg: Tuple[sint, ...]
:return: A new LazyBuffer with the expanded shape.
:rtype: LazyBuffer
"""
if self.shape == arg:
return self
if not self.realized and self.op.op == MovementOps.EXPAND:
@ -626,6 +1020,18 @@ class LazyBuffer:
return self._movement_op(self.st.expand(arg), MovementOps.EXPAND, arg)
def permute(self: LazyBuffer, arg: Tuple[int, ...]) -> LazyBuffer:
"""
Permute the current LazyBuffer based on the given argument.
This function is not yet implemented.
:param self: The current LazyBuffer object.
:type self: LazyBuffer
:param arg: A tuple containing int objects representing the permutation of the dimensions.
:type arg: Tuple[int, ...]
:return: A new LazyBuffer with the permuted shape.
:rtype: LazyBuffer
"""
if arg == tuple(range(len(self.shape))):
return self
if not self.realized and self.op.op == MovementOps.PERMUTE:
@ -667,6 +1073,16 @@ class LazyBuffer:
return self._movement_op(self.st.permute(arg), MovementOps.PERMUTE, arg)
def shrink(self: LazyBuffer, arg: Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
"""
Shrinks the buffer based on the provided arguments.
Attributes:
self (LazyBuffer): The lazy buffer to be shrunk.
arg (Tuple[Tuple[sint, sint], ...]): A tuple of tuples containing the start and end indices for each dimension of the buffer.
Returns:
LazyBuffer: The shrunken lazy buffer.
"""
if all(b - a == s for s, (a, b) in zip(self.shape, arg)):
return self
if not self.realized and self.op.op == MovementOps.SHRINK:
@ -678,6 +1094,16 @@ class LazyBuffer:
return self._movement_op(self.st.shrink(arg), MovementOps.SHRINK, arg)
def stride(self: LazyBuffer, arg: Tuple[int, ...]) -> LazyBuffer:
"""
Applies a stride operation to the buffer based on the provided arguments.
Attributes:
self (LazyBuffer): The lazy buffer to have the stride operation applied to.
arg (Tuple[int, ...]): A tuple of integers representing the strides for each dimension of the buffer.
Returns:
LazyBuffer: The lazy buffer after the stride operation has been applied.
"""
if all(a == 1 for a in arg):
return self
if not self.realized and self.op.op == MovementOps.STRIDE:
@ -692,6 +1118,44 @@ class LazyBuffer:
op: MovementOps,
arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]],
) -> LazyBuffer:
"""
Perform movement operation on the current instance.
This function checks certain conditions and based on them either replaces the current instance with a new one
created using movement operations or creates a new `LazyBuffer` object.
Parameters:
st (ShapeTracker): Shape tracker for the operation.
op (MovementOps): Movement operation to be performed.
arg (Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]): Arguments for the movement operation.
Returns:
LazyBuffer: The result of the movement operation.
Attributes:
SHUFFLE_MOVEMENT_OPS (bool): If True, shuffle movement operations are performed.
self.realized (bool): Indicates if the object is realized or not.
self.optype (BinaryOps): The type of operation to be performed.
self.children (list): List of children for the current instance.
MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE (MovementOps): Enumeration values for shrink, stride and permute operations.
self.op.op (UnaryOps): The type of unary operation to be performed.
PUSH_RESHAPES (bool): If True, reshapes are pushed.
Attributes:
REMOVE_MOVEMENT_NOPS (bool): If True, no-operation movement are removed.
get_movementroot (function): Function to get the root of movement operations.
self.st.contiguous (bool): Indicates if the shape tracker is contiguous or not.
prod (function): Function to calculate the product of a tuple of integers.
Attributes:
self.device (str): Device on which the operation will be performed.
st (ShapeTracker): Shape tracker for the operation.
MovementOps (Enum): Enum class for movement operations.
LazyOp (class): Class representing a lazy operation.
self.dtype (data-type): Data type of the elements in the buffer.
self.base (object): Base object for the current instance.
"""
if (
SHUFFLE_MOVEMENT_OPS
and not self.realized
@ -722,12 +1186,33 @@ class LazyBuffer:
def replace_with_movement_ops(
self: LazyBuffer, ops: List[Tuple[MovementOps, Any]]
) -> LazyBuffer:
"""
This method takes a list of tuples as an argument where each tuple contains a movement operation and its corresponding argument.
Attributes:
self (LazyBuffer): The lazy buffer instance on which the operations are to be performed.
ops (List[Tuple[MovementOps, Any]]): A list of tuples where each tuple contains a MovementOps enum member and its corresponding argument.
Returns:
LazyBuffer: The updated lazy buffer instance after performing all the movement operations in sequence.
This method iterates over the list of tuples and for each tuple, it retrieves the corresponding function from the MOVEMENT_OPS_DISPATCHER dictionary using the MovementOps enum member as the key.
The function is then called with the lazy buffer instance and its argument to perform the movement operation.
The result of this operation is then stored back in the lazy buffer instance for subsequent operations.
Once all the operations have been performed, the final updated lazy buffer instance is returned.
"""
y = self
for op, arg in ops:
y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
return y
"""
Constants and configurations.
Attributes:
UNSAFE_PAD_OPS (set): A set of unsafe padding operations. These include division, comparison less than, base-2 logarithm, base-2 exponentiation, and reciprocal (1/x).
"""
UNSAFE_PAD_OPS = {
BinaryOps.DIV,
BinaryOps.CMPLT,
@ -738,6 +1223,15 @@ UNSAFE_PAD_OPS = {
def _push_movement_ops(srcs: Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
"""
This function pushes movement operations to the sources of a lazy buffer.
Attributes:
srcs (Tuple[LazyBuffer, ...]): A tuple of LazyBuffer objects.
Returns:
Tuple[LazyBuffer, ...]: A tuple of updated LazyBuffer objects with movement operations pushed to their sources.
"""
new_srcs = []
for x in srcs:
mops: List[Tuple[MovementOps, Any]] = []
@ -779,3 +1273,25 @@ MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
MovementOps.PAD: LazyBuffer.pad,
MovementOps.STRIDE: LazyBuffer.stride,
}
"""
This dictionary acts as a dispatcher for various movement operations on `LazyBuffer` objects. It maps each operation to its corresponding function in the `LazyBuffer` class.
Attributes:
MovementOps (Enum): An enumeration of different movement operations like reshape, expand, shrink, permute, pad, and stride.
LazyBuffer (Class): The `LazyBuffer` class which contains the methods corresponding to each operation in this dispatcher.
Dict[MovementOps, Callable] (Type Hinting): Type hint indicating that keys are movement operations and values are their corresponding functions in `LazyBuffer`.
The dictionary keys:
MovementOps.RESHAPE: Corresponds to the reshape operation, mapped to `LazyBuffer.reshape` method.
MovementOps.EXPAND: Corresponds to the expand operation, mapped to `LazyBuffer.expand` method.
MovementOps.SHRINK: Corresponds to the shrink operation, mapped to `LazyBuffer.shrink` method.
MovementOps.PERMUTE: Corresponds to the permute operation, mapped to `LazyBuffer.permute` method.
MovementOps.PAD: Corresponds to the pad operation, mapped to `LazyBuffer.pad` method.
MovementOps.STRIDE: Corresponds to the stride operation, mapped to `LazyBuffer.stride` method.
"""