codegen kernel docstrings
parent
fe8d9753f0
commit
bd47ccaf04
@@ -0,0 +1,12 @@
tinygrad codegen.kernel
------------------------

.. note:: You likely want the upstream tinygrad, not tinygrab.
   Tinygrab contains AI-generated docstrings for a tinygrad snapshot.
   Upstream: https://tinygrad.org

.. automodule:: tinygrad.codegen.kernel
   :members:
   :undoc-members:
   :show-inheritance:
@@ -20,6 +20,7 @@ tinygrad
   tinygrad-realize
   tinygrad-tensor
   tinygrad-codegen-kernel
   tinygrad-codegen-linearizer
   tinygrad-runtime-ops_clang
   tinygrad-runtime-ops_cpu
   tinygrad-runtime-ops_cuda
@@ -34,6 +34,21 @@ from enum import Enum, auto


class OptOps(Enum):
    """
    This class represents an enumeration of optimization operations.

    Attributes:
        UPCAST (auto()): Represents the operation to upcast a data type.
        UPCASTMID (auto()): Represents the operation to upcast data types in the middle of a sequence.
        UNROLL (auto()): Represents the operation to unroll a loop.
        LOCAL (auto()): Represents the operation to make a variable local.
        LASTLOCAL (auto()): Represents the operation to make the last variable in a sequence local.
        GROUP (auto()): Represents the operation to group variables.
        GROUPTOP (auto()): Represents the operation to group variables at the top level.
        NOLOCALS (auto()): Represents the operation to remove all local variables.
        PADTO (auto()): Represents the operation to pad a sequence to a specific length.
    """

    UPCAST = auto()
    UPCASTMID = auto()
    UNROLL = auto()
@@ -45,21 +60,61 @@ class OptOps:
    PADTO = auto()  # noqa: E702

    def __lt__(self, x: OptOps):
        """
        Compares this instance's value with another instance's value.

        Args:
            x (OptOps): The other instance of the OptOps enumeration to compare against.

        Returns:
            bool: True if this instance's value is less than the other instance's value, False otherwise.
        """
        return self.value < x.value
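Because __lt__ compares the underlying auto() values, optimization ops sort in declaration order. A minimal standalone sketch of the same pattern (not an import from tinygrad):

from enum import Enum, auto

class Op(Enum):
    UPCAST = auto()    # value 1
    UNROLL = auto()    # value 2

    def __lt__(self, x):
        # auto() assigns 1, 2, ... in declaration order
        return self.value < x.value

# sorting now works on the enum directly
assert sorted([Op.UNROLL, Op.UPCAST]) == [Op.UPCAST, Op.UNROLL]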
@dataclass(frozen=True, order=True)
class Opt:
    """
    Data class for operation options.

    Attributes:
        op (OptOps): The operation to perform.
        axis (Optional[int]): The axis along which the operation is performed. Defaults to None.
        amt (Optional[int]): The amount or value used in the operation. Defaults to None.
    """

    op: OptOps
    axis: Optional[int] = None
    amt: Optional[int] = None

    def __repr__(self):
        """
        Return a string representation of the object.

        Returns:
            str: A string in the format "Opt(op=<op>, axis=<axis>, amt=<amt>)".
        """
        return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
@dataclass(frozen=True)
class TensorCore:
    """
    Data class for the Tensor Core.

    Attributes:
        device (str): The device on which the tensor core will be used.
        dims (List[int]): List of integers representing dimensions.
        dtype_in (DType): Input data type.
        dtype_out (DType): Output data type.
        threads (List[Tuple[int, int]]): List of tuples where each tuple contains a TC dimension and an amount that constructs the warp thread structure.
        upcast_dim (int): The TC dimension to upcast.
        thread_local_aliases (List[List[List[int]]]): A list of lists of lists of integers defining an alias for each TC dimension.
            For example: [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)], where 1 denotes warp threads, -1 upcast, and 0 unrolled.
        thread_local_sizes (List[int]): List of integers representing the number of elements stored in registers for each TC dimension in each thread.
        arch (Optional[str]): Optional architecture parameter. Default is None.
    """

    device: str
    dims: List[int]
    dtype_in: DType
@@ -147,6 +202,16 @@ tensor_cores: Dict[str, List[TensorCore]] = {


class LocalBuffer(NamedTuple):
    """
    A named tuple for representing a local buffer in memory.

    Attributes:
        name (str): The name of the local buffer.
        size (int): The size of the local buffer.
        dtype (DType, optional): The data type of the elements in the local buffer. Defaults to dtypes.float32.
        realized (None, optional): A placeholder for future functionality. Defaults to None.
    """

    name: str
    size: int
    dtype: DType = dtypes.float32
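Usage sketch, assuming LocalBuffer and dtypes are in scope; the field defaults follow the definitions above:

lb = LocalBuffer(name="temp0", size=64)
assert lb.dtype == dtypes.float32 and lb.realized is None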
@@ -157,6 +222,19 @@ class LocalBuffer(NamedTuple):


class LinearizerOptions(NamedTuple):
    """
    A named tuple for representing options related to linearizing memory accesses.

    Attributes:
        device (str, optional): The target device for the linearization. Defaults to "".
        supports_float4 (bool, optional): Whether the target device supports the float4 data type. Defaults to True.
        supports_float4_alu (bool, optional): Whether the target device supports float4 ALU operations. Defaults to True.
        has_local (bool, optional): Whether the target device has local memory. Defaults to True.
        has_shared (bool, optional): Whether the target device has shared memory. Defaults to True.
        global_max (Optional[List[int]], optional): The maximum global dimensions for linearization. Defaults to None.
        local_max (Optional[List[int]], optional): The maximum local dimensions for linearization. Defaults to None.
    """

    device: str = ""
    # TODO: make this generic with a list of supported types
    supports_float4: bool = True
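Construction sketch (the keyword arguments are illustrative; "CLANG" is just an example device string):

# a device without local/shared memory, e.g. a plain CPU backend
opts = LinearizerOptions(device="CLANG", has_local=False, has_shared=False)
assert opts.supports_float4  # still the default from the field definition above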
@@ -169,6 +247,31 @@ class LinearizerOptions(NamedTuple):


class Kernel:
    """
    The Kernel class represents a single kernel in the linearizer. It contains information about
    the AST (Abstract Syntax Tree), options, and various buffers and shape trackers used during
    the linearization process. This class also provides methods for simplifying and optimizing
    the linearized code.

    Attributes:
        ast (LazyOp): The abstract syntax tree representing the kernel's operations.
        opts (Optional[LinearizerOptions]): The options used during the linearization process.
        info (FlopCounter): Information about the floating-point operations in the kernel.
        reduceop (Optional[Any]): The single allowed reduce operation in an AST, if it exists.
        bufs (List[Union[MemBuffer, ConstBuffer, LocalBuffer]]): The list of unique buffers used by the kernel.
        earlybufs (List[Any]): The list of buffers used before the reduce operation, if any.
        full_buf_index (int): The index of the buffer with all axes.
        sts (List[ShapeTracker]): The shape trackers for each buffer in the kernel.
        applied_opts (List[Opt]): The list of optimization options that have been applied to the kernel.
        group_for_reduce (List[int]): The sizes of the axes grouped for reduction (populated by GROUP/GROUPTOP optimizations).
        upcasted (int): The number of dimensions that have been upcasted.
        local_dims (int): The number of local dimensions in the kernel.
        local_alias (Dict[int, LocalBuffer]): A dictionary mapping buffer indices to local buffers.
        tensor_core (Optional[TensorCore]): Information about the tensor core being used, if any.
        dont_use_locals (bool): A flag indicating that local dimensions should not be used in the kernel.
        applied_opts_cache (Optional[List[Opt]]): A cache of optimization options that have been applied to the kernel.
    """

    def __init__(self, ast: LazyOp, opts: Optional[LinearizerOptions] = None):
        self.opts = (
            opts
@@ -240,6 +343,13 @@ class Kernel:
        self.applied_opts_cache: Optional[List[Opt]] = None

    def copy(self):
        """
        Creates a deep copy of the current Kernel object. This can be useful for creating new kernels
        based on existing ones without modifying the original kernel.

        Returns:
            A deep copy of the current Kernel object.
        """
        ret = type(self).__new__(type(self))

        # base linearizer params
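Usage sketch, assuming a constructed Kernel `k`: copy() is the natural way to fork a kernel before trying an optimization path:

trial = k.copy()
trial.apply_opt(Opt(OptOps.UPCAST, axis=0, amt=4))
# the original `k` is untouched, so other option sequences can be tried from it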
@@ -282,10 +392,25 @@ class Kernel:

    @property
    def membufs(self) -> List[MemBuffer]:
        """
        Return the memory buffers used by the kernel.

        Returns:
            List[MemBuffer]: The MemBuffer instances among the kernel's buffers.
        """
        return [x for x in self.bufs if isinstance(x, MemBuffer)]

    # TODO: these need more tests or it might silently be no-op
    def shape_offsets(self, i: int):
        """
        Compute the offsets of the shape for buffer i.

        Args:
            i (int): The buffer index for which to compute the offsets.

        Returns:
            itertools.product: An iterator over the cartesian product of the upcasted dimension ranges.
        """
        return (
            itertools.product(
                *[
@@ -298,6 +423,15 @@ class Kernel:
        )

    def float4_axis(self, i: int):
        """
        Compute the float4 axes for buffer i.

        Args:
            i (int): The buffer index for which to compute the float4 axes.

        Returns:
            List[int]: A list of integers representing the float4 axes.
        """
        return [
            x - (self.shape_len - self.upcasted)
            for x in self.sts[i].unit_stride_axes()
@@ -305,6 +439,15 @@ class Kernel:
        ]

    def upcasted_axis(self, i: int):
        """
        Compute the upcasted axis for buffer i.

        Args:
            i (int): The buffer index for which to compute the upcasted axis.

        Returns:
            List[Tuple[int, int, bool]]: A list of (size, stride, is_reduce) tuples for the upcasted dimensions.
        """
        return list(
            zip(
                self.sts[i].shape[self.shape_len - self.upcasted :],
@@ -321,6 +464,15 @@ class Kernel:

    # TODO: is there a better way to write this?
    def acc_offsets(self, i: int) -> List[int]:
        """
        Calculate accumulator offsets for a given buffer index.

        Args:
            i (int): The buffer index to calculate accumulator offsets for.

        Returns:
            List[int]: A list of calculated accumulator offsets.
        """
        if self.upcasted == 0:
            return [0]
        upcasted_i = self.upcasted_axis(i)
@@ -341,6 +493,15 @@ class Kernel:
        ]

    def get_upcast_dim(self, i: int) -> List[int]:
        """
        Get dimensions that need to be upcasted.

        Args:
            i (int): The buffer index to check for dimensions that need to be upcasted.

        Returns:
            List[int]: A list of dimensions that need to be upcasted.
        """
        should_upcast = self.opts.supports_float4 and (
            self.bufs[i].dtype in [dtypes.float32, dtypes.float16]
            or isinstance(self.bufs[i].dtype, ImageDType)
@@ -355,6 +516,17 @@ class Kernel:

    @property
    def first_reduce(self) -> int:
        """
        Calculate the index of the first reduction axis.

        Attributes:
            self.sts (List[ShapeTracker]): A list of shape trackers, each with a shape attribute.
            self.shape_len (int): The length of the shapes in `self.sts`.
            self.upcasted (int): The number of upcasted dimensions.

        Returns:
            int: The index of the first reduction axis.
        """
        return [
            x != y
            for x, y in zip(
@@ -365,22 +537,72 @@ class Kernel:

    @property
    def output_shape(self) -> Tuple[sint, ...]:
        """
        Get the shape of the first shape tracker in `self.sts`.

        Attributes:
            self.sts (List[ShapeTracker]): A list of shape trackers, each with a shape attribute.

        Returns:
            Tuple[sint, ...]: The shape of the first shape tracker in `self.sts`.
        """
        return self.sts[0].shape

    @property
    def full_shape(self) -> Tuple[sint, ...]:
        """
        Get the shape of the shape tracker at index `self.full_buf_index` in `self.sts`.

        Attributes:
            self.sts (List[ShapeTracker]): A list of shape trackers, each with a shape attribute.
            self.full_buf_index (int): The index of the shape tracker to get the shape from.

        Returns:
            Tuple[sint, ...]: The shape of the shape tracker at `self.full_buf_index` in `self.sts`.
        """
        return self.sts[self.full_buf_index].shape

    @property
    def full_unupcasted_shape(self) -> Tuple[sint, ...]:
        """
        Get the full shape with the upcasted dimensions stripped from the end.

        Attributes:
            self.full_shape (Tuple[sint, ...]): The shape of the shape tracker at `self.full_buf_index` in `self.sts`.
            self.upcasted (int): The number of upcasted dimensions.

        Returns:
            Tuple[sint, ...]: `self.full_shape` without its trailing upcasted dimensions.
        """
        return self.full_shape[: self.shape_len - self.upcasted]

    @property
    def shape_len(self) -> int:
        """
        Get the number of dimensions tracked by the kernel's shape trackers.

        Attributes:
            self.sts (List[ShapeTracker]): A list of shape trackers, each with a shape attribute.

        Returns:
            int: The length of the shape of the first shape tracker in `self.sts`.
        """
        return len(self.sts[0].shape)
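These four properties are related by simple slicing; a plain-tuple illustration with made-up numbers:

full_shape = (16, 32, 8, 4)                                 # hypothetical
upcasted = 1
shape_len = len(full_shape)                                 # 4
full_unupcasted_shape = full_shape[:shape_len - upcasted]   # trailing dim dropped
assert full_unupcasted_shape == (16, 32, 8)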

    @property
    def upcast_in_mid_reduce_axes(self) -> List[int]:
        """
        Get the indices, within the grouped-reduce axes, where the dimensions of `self.full_shape` and `self.sts[0].shape` are equal.

        Attributes:
            self.first_reduce (int): The index of the first reduction axis.
            self.group_for_reduce (List[int]): A list of integers representing groups for reduction.
            self.full_shape (Tuple[sint, ...]): The shape of the shape tracker at `self.full_buf_index` in `self.sts`.
            self.sts[0].shape (Tuple[sint, ...]): The shape of the first shape tracker in `self.sts`.

        Returns:
            List[int]: The indices where the dimensions are equal in both `self.full_shape` and `self.sts[0].shape`.
        """
        return [
            j
            for j in range(
@@ -391,6 +613,20 @@ class Kernel:

    @property
    def global_dims(self) -> int:
        """
        Calculate and return the difference between the first_reduce and local_dims attributes.

        Attributes:
            self.first_reduce (int): The index of the first reduction axis.
            self.local_dims (int): The number of local dimensions.

        Returns:
            int: The difference between self.first_reduce and self.local_dims.

        Notes:
            This method is a property, so it is accessed like an attribute on an instance of the class.
            The shape is organized into eight chunks (see the color legend below).
        """
        return self.first_reduce - self.local_dims
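Illustration with made-up numbers: if the first reduction axis is at index 3 and one axis is local, two global dimensions remain:

first_reduce, local_dims = 3, 1
global_dims = first_reduce - local_dims
assert global_dims == 2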

    # there's eight chunks of the shape
@@ -404,6 +640,23 @@ class Kernel:
    # purple -- reduce upcasted
    # yellow -- normal upcasted dimensions
    def colors(self) -> List[str]:
        """
        Generate a list of color codes based on the dimensions of the object.

        Attributes:
            global_dims (int): Number of global dimensions
            local_dims (int): Number of local dimensions
            first_reduce (int): Index of the first reduce dimension
            group_for_reduce (list): List of grouped dimensions for reduction
            upcast_in_mid_reduce_axes (List[int]): Axes where upcasting occurs during mid-reduction
            shape_len (int): Length of the shape vector
            upcasted (int): Number of upcasted dimensions
            full_shape (list): Full shape of the object
            sts (list): List of shape trackers

        Returns:
            list: A list of color codes representing different types of dimensions.
        """
        # first non local non reduce dims are global (blue)
        colors = (
            ["blue"] * self.global_dims
@@ -433,6 +686,21 @@ class Kernel:
        return colors

    def colored_shape(self, pad: Optional[int] = None, dense=False) -> str:
        """
        Generate a string representation of the shape with each dimension colored according to its position.

        :param pad: The number of characters to pad the resulting string to. If not provided, no padding is added.
        :type pad: Optional[int]
        :param dense: Whether to render integer dimensions compactly; when False they are padded to a width of 4. Defaults to False.
        :type dense: bool
        :return: A string representation of the shape with each dimension colored according to its position.
        :rtype: str

        Attributes:
            self.full_shape (List[Union[int, str]]): The full shape of the object.
            self.colors (Callable[[], List[str]]): A function that returns a list of colors for each dimension.
            ansilen (Callable[[str], int]): A function that calculates the length of a string in terminal characters.
        """
        ret = " ".join(
            colored(s, color)
            for s, color in zip(
@@ -451,6 +719,16 @@ class Kernel:

    # apply reshape and permute to all shapetrackers
    def reshape_and_permute(self, new_shape_fxn, axis):
        """
        Apply reshape and permute to all shapetrackers.

        Parameters:
            new_shape_fxn (Optional[Callable]): A function mapping the current shape to the new shape, or None to skip the reshape.
            axis (Optional[List[int]]): The permutation to apply to the axes, or None to skip the permute.

        Returns:
            None
        """
        new_sts = []
        for st in self.sts:
            if new_shape_fxn is not None:
@@ -462,6 +740,18 @@ class Kernel:

    # drops the final dimension
    def upcast(self):
        """
        Drop the final dimension.

        Parameters:
            None

        Returns:
            None

        Raises:
            AssertionError: If the final dimension size is 1, as it cannot be upcasted.
        """
        assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1"
        self.upcasted += 1
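Effect sketch, assuming a Kernel `k` whose trailing dimension is larger than 1:

before = k.upcasted
k.upcast()                       # peels the last shape dimension
assert k.upcasted == before + 1  # it is now counted as upcasted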
@@ -470,6 +760,18 @@ class Kernel:
    # top : if you want to pull that amount from the top
    # insert_before : place to insert the new stuff
    def shift_to(self, axis, amount, top=False, insert_before=None):
        """
        Split `amount` off of an axis and move the new sub-axis to a specified location.

        Parameters:
            axis (int): The axis to pull from.
            amount (int): The amount to take.
            top (bool): Whether to pull that amount from the top of the axis. Default is False.
            insert_before (int): The place to insert the new axis. Default is None, which means the end of the list.

        Returns:
            None
        """
        if insert_before is None:
            insert_before = self.shape_len
        move_axis = axis if top else axis + 1
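A standalone sketch of the reshape half of this operation (illustrative only; it ignores the permute/insert_before step that the real method also performs):

def shift_shape(shape, axis, amount, top=False):
    # split `axis` by `amount`; `top` controls which factor comes first
    s = shape[axis]
    split = (amount, s // amount) if top else (s // amount, amount)
    return shape[:axis] + split + shape[axis + 1:]

assert shift_shape((6, 5), axis=0, amount=2) == (3, 2, 5)
assert shift_shape((6, 5), axis=0, amount=2, top=True) == (2, 3, 5)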
@@ -491,6 +793,13 @@ class Kernel:
    # ******************** complex simplifiers ********************

    def simplify_ones(self) -> bool:
        """
        Simplify by removing places where the shape is all ones. If shape_len is 0 the function returns False
        immediately; otherwise it updates the local_dims and upcasted counts accordingly and reshapes and permutes
        the shape trackers. Returns True if any axis was all ones, else False.

        :return: bool
        """
        # remove places where the shape is all ones
        # TODO: this should be factored in to multi shape stride
        if self.shape_len == 0:
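A standalone sketch of the core idea (illustrative; the real method works through the shape trackers and the local/upcast bookkeeping):

def drop_ones(shape):
    # remove axes whose extent is 1, report whether anything was dropped
    all_ones = [s == 1 for s in shape]
    return tuple(s for s in shape if s != 1), any(all_ones)

assert drop_ones((1, 4, 1, 8)) == ((4, 8), True)
assert drop_ones((4, 8)) == ((4, 8), False)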
@@ -506,6 +815,13 @@ class Kernel:
        return any(all_ones)

    def simplify_merge_adjacent(self):
        """
        Simplify by merging adjacent dimensions. The function returns early if shape_len is 0; otherwise it merges
        adjacent dimensions when possible, handling special cases for image dtypes and updating the shapes and
        strides accordingly.

        :return: None
        """
        if self.shape_len == 0:
            return
        shapes, strides = [x.shape for x in self.sts], [
@@ -561,6 +877,16 @@ class Kernel:
    # ******************** GPU simplifiers ********************

    def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
        """
        Limit the size of tensor dimensions.

        :param x: Tuple of integers representing the shape of a tensor.
        :type x: Tuple[int]
        :param max_size: List of maximum allowed sizes for each dimension.
        :type max_size: List
        :return: Tuple of integers representing the new shape with dimensions limited by max_size.
        :rtype: Tuple[int, ...]
        """
        new_shape, dims = list(x), len(x)
        for i in range(dims):
            next_idx = (i + 1) % dims
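A standalone sketch of the idea (simplified; the real method also advances next_idx when the neighbour itself overflows):

def limit_size(shape, max_size):
    # halve an oversized dim and push the factor of 2 into the next dim
    new_shape = list(shape)
    for i in range(len(new_shape)):
        nxt = (i + 1) % len(new_shape)
        while new_shape[i] > max_size[i]:
            new_shape[i] //= 2
            new_shape[nxt] *= 2
    return tuple(new_shape)

assert limit_size((256, 2), [64, 64]) == (64, 8)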
@@ -574,6 +900,14 @@ class Kernel:
        return tuple(new_shape)

    def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
        """
        Limit dimensions to their maximum allowed sizes.

        :param global_max: List of maximum allowed global dimension sizes.
        :type global_max: List[int]
        :param local_max: List of maximum allowed local dimension sizes.
        :type local_max: List[int]
        """
        # Check the global allocation limit; currently the global_size will be flipped during codegen
        # and then padded right with 1s if its length < 3, which makes this part a bit awkward to write
        global_dims = self.first_reduce - self.local_dims
@@ -605,6 +939,14 @@ class Kernel:
        )

    def alias_buffer(self, i, pattern):
        """
        Alias a buffer.

        :param i: Index of the buffer to be aliased.
        :type i: int
        :param pattern: List containing a pattern entry for each shape dimension.
        :type pattern: List
        """
        assert len(pattern) == len(
            self.sts[i].shape
        ), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
@@ -633,6 +975,30 @@ class Kernel:
    def apply_tensor_cores(
        self, use_tensor_cores=1, extra_opts: Optional[List[Opt]] = None
    ):
        """
        Apply tensor cores to the computation.

        Attributes:
            use_tensor_cores (int): Flag indicating whether to apply tensor cores. Default is 1.
            extra_opts (Optional[List[Opt]]): Optional list of extra options. Default is None.

        This function checks that the following conditions are met before applying tensor cores:
        1) The use_tensor_cores flag is set.
        2) The current device has local memory support.
        3) A reduction operation exists and it is a summation (ReduceOps.SUM).
        4) The current device supports tensor cores.

        If these conditions are met, the function iterates over all available tensor cores for the current device.
        It then checks that the following hold before applying a tensor core:
        1) The tensor core architecture is compatible with the current system.
        2) The reduction operation's source is a LazyOp whose operation is UnaryOps.CAST with the correct dtype_out.
        3) A multiplication operation (LazyOp with BinaryOps.MUL) exists and its sources are two LazyOps with
           BufferOps.LOAD operations and dtypes compatible with the tensor core configuration.
        4) The strides of both source buffers for the multiplication are zero for the first reduction dimension.
        5) The shape of the buffers is compatible with the tensor core dimensions.

        If all these conditions are met, it selects the axes for buffer 0 and buffer 1 and applies the tensor core.
        """
        if (
            use_tensor_cores
            and self.opts.has_local
@@ -714,6 +1080,18 @@ class Kernel:
        )

        def fix(needed, ax):
            """
            Fix function for tensor core operations.

            This function is responsible for unrolling the reduce dimension and upcasting the input tensor data type.
            It then creates a thread pattern based on the specified conditions.

            Attributes:
                needed (bool): A flag indicating whether this operation is necessary.
                ax (int): The axis along which the reduction is performed.
                s0, s1 (int): Axis indices for the two input buffers, shared via `nonlocal`.
                s0_exists, s1_exists (bool): Flags indicating whether `s0` and `s1`, respectively, are still valid.
            """
            nonlocal s0, s1, s0_exists, s1_exists
            if not needed:
                return
@@ -788,6 +1166,23 @@ class Kernel:
        return False

    def apply_opt(self, opt: Opt):
        """
        Apply an optimization to the current object.

        This method checks that the optimization operation is applicable given the 'dont_use_locals' attribute and
        the type of the operation, then appends the optimization to the list of applied optimizations. The axis for
        the optimization is calculated based on certain conditions and defaults to -1 if no specific axis is given.

        Args:
            opt (Opt): The optimization operation to apply.

        Raises:
            AssertionError: If the 'dont_use_locals' attribute is True and the optimization operation is one of
                LOCAL, LASTLOCAL, GROUP, GROUPTOP, or UPCASTMID.

        Attributes:
            applied_opts (List[Opt]): A list of previously applied optimization operations.
            dont_use_locals (bool): If True, some optimization operations are not allowed.
            first_reduce (int): The index of the first reduction operation.
            group_for_reduce (list): A list of groups for reduction operations.
        """
        assert not self.dont_use_locals or opt.op not in {
            OptOps.LOCAL,
            OptOps.LASTLOCAL,
@@ -920,6 +1315,14 @@ class Kernel:
        return self.simplify_ones()

    def hand_coded_optimizations(self):
        """
        This method handles the application of hand-coded optimizations.

        Attributes:
            MV_BLOCKSIZE (int): The block size for matrix-vector multiplication.
            MV_THREADS_PER_ROW (int): The number of threads per row for matrix-vector multiplication.
            MV_ROWS_PER_THREAD (int): The number of rows per thread for matrix-vector multiplication.
        """
        # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
        MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = (
            getenv("MV_BLOCKSIZE", 4),