nn init docstrings
parent
45dbac0a02
commit
35b37b1a5a
|
@ -6,6 +6,17 @@ from tinygrad.nn import optim, state # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
class BatchNorm2d:
|
class BatchNorm2d:
|
||||||
|
"""
|
||||||
|
This class is for 2D batch normalization. It scales and shifts the input tensor x using the calculated mean and variance.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
sz (int): The size of the tensor.
|
||||||
|
eps (float): A small constant added to the variance to prevent division by zero. Default is 1e-5.
|
||||||
|
affine (bool): A boolean value that when set to True, the constructor learns and applies scale and bias. Default is True.
|
||||||
|
track_running_stats (bool): A boolean value that when set to True, this module tracks the running mean and variance. Default is True.
|
||||||
|
momentum (float): The value used for the moving average of the mean and variance. Default is 0.1.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, sz: int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1
|
self, sz: int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1
|
||||||
):
|
):
|
||||||
|
@ -26,6 +37,15 @@ class BatchNorm2d:
|
||||||
self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
|
self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
|
||||||
|
|
||||||
def __call__(self, x: Tensor):
|
def __call__(self, x: Tensor):
|
||||||
|
"""
|
||||||
|
This function performs the forward pass of the batch normalization layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): The input tensor to be normalized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: The normalized output tensor.
|
||||||
|
"""
|
||||||
if Tensor.training:
|
if Tensor.training:
|
||||||
# This requires two full memory accesses to x
|
# This requires two full memory accesses to x
|
||||||
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
|
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
|
||||||
|
@ -73,6 +93,34 @@ def Conv1d(
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True,
|
bias=True,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Function Signature: Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
|
||||||
|
|
||||||
|
**Description**:
|
||||||
|
This function is a wrapper for the PyTorch convolutional layer (Conv2d). It takes in parameters like in_channels, out_channels, kernel_size, stride, padding, dilation, groups, and bias. The kernel_size is automatically converted into a tuple of size 1 for the Conv2d function.
|
||||||
|
|
||||||
|
**Parameters**:
|
||||||
|
|
||||||
|
:param in_channels: Number of channels in the input image
|
||||||
|
:type in_channels: int
|
||||||
|
:param out_channels: Number of channels produced by the convolution
|
||||||
|
:type out_channels: int
|
||||||
|
:param kernel_size: Size of the convolving kernel
|
||||||
|
:type kernel_size: int
|
||||||
|
:param stride: Stride of the convolution, default is 1
|
||||||
|
:type stride: int, optional
|
||||||
|
:param padding: Zero-padding added to both sides of the input, default is 0
|
||||||
|
:type padding: int, optional
|
||||||
|
:param dilation: Spacing between the kernel points, default is 1
|
||||||
|
:type dilation: int, optional
|
||||||
|
:param groups: Number of blocked connections from input channels to output channels, default is 1
|
||||||
|
:type groups: int, optional
|
||||||
|
:param bias: If True, adds a learnable bias to the output, default is True
|
||||||
|
:type bias: bool, optional
|
||||||
|
|
||||||
|
**Returns**:
|
||||||
|
The Conv2d layer with converted parameters
|
||||||
|
"""
|
||||||
return Conv2d(
|
return Conv2d(
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
@ -86,6 +134,20 @@ def Conv1d(
|
||||||
|
|
||||||
|
|
||||||
class Conv2d:
|
class Conv2d:
|
||||||
|
"""
|
||||||
|
This class defines a 2D convolutional layer.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
in_channels (int): Number of input channels.
|
||||||
|
out_channels (int): Number of output channels.
|
||||||
|
kernel_size (tuple): Size of the convolving kernel.
|
||||||
|
stride (int): Stride of the convolution. Default is 1.
|
||||||
|
padding (int): Padding added to both sides of the input. Default is 0.
|
||||||
|
dilation (int): Spacing between kernel elements. Default is 1.
|
||||||
|
groups (int): Number of blocked connections from input channels to output channels. Default is 1.
|
||||||
|
bias (bool): If True, adds a learnable bias to the output. Default is True.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels,
|
in_channels,
|
||||||
|
@ -116,6 +178,15 @@ class Conv2d:
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x: Tensor):
|
def __call__(self, x: Tensor):
|
||||||
|
"""
|
||||||
|
Forward pass through the layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor of shape (batch_size, in_channels, height, width).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor after convolution of shape (batch_size, out_channels, new_height, new_width).
|
||||||
|
"""
|
||||||
return x.conv2d(
|
return x.conv2d(
|
||||||
self.weight,
|
self.weight,
|
||||||
self.bias,
|
self.bias,
|
||||||
|
@ -126,6 +197,17 @@ class Conv2d:
|
||||||
)
|
)
|
||||||
|
|
||||||
def initialize_weight(self, out_channels, in_channels, groups):
|
def initialize_weight(self, out_channels, in_channels, groups):
|
||||||
|
"""
|
||||||
|
Initialize the weight tensor using the Kaiming uniform initialization method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
out_channels (int): Number of output channels.
|
||||||
|
in_channels (int): Number of input channels per group.
|
||||||
|
groups (int): Number of blocked connections from input channels to output channels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Initialized weight tensor of shape (out_channels, in_channels // groups, *kernel_size).
|
||||||
|
"""
|
||||||
return Tensor.kaiming_uniform(
|
return Tensor.kaiming_uniform(
|
||||||
out_channels, in_channels // groups, *self.kernel_size, a=math.sqrt(5)
|
out_channels, in_channels // groups, *self.kernel_size, a=math.sqrt(5)
|
||||||
)
|
)
|
||||||
|
@ -142,6 +224,36 @@ def ConvTranspose1d(
|
||||||
groups=1,
|
groups=1,
|
||||||
bias=True,
|
bias=True,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Function Signature: ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True)
|
||||||
|
|
||||||
|
**Description**:
|
||||||
|
This function creates a 1D convolutional transpose layer.
|
||||||
|
|
||||||
|
**Parameters**:
|
||||||
|
Attributes:
|
||||||
|
in_channels : int
|
||||||
|
Number of input channels.
|
||||||
|
out_channels : int
|
||||||
|
Number of output channels.
|
||||||
|
kernel_size : int
|
||||||
|
Size of the convolving kernel.
|
||||||
|
stride : int, default=1
|
||||||
|
Stride of the convolution.
|
||||||
|
padding : int, default=0
|
||||||
|
Zero-padding added to both sides of the input.
|
||||||
|
output_padding : int, default=0
|
||||||
|
Additional size added to one side of the output shape.
|
||||||
|
dilation : int, default=1
|
||||||
|
Spacing between kernel elements.
|
||||||
|
groups : int, default=1
|
||||||
|
Number of blocked connections from input channels to output channels.
|
||||||
|
bias : bool, default=True
|
||||||
|
If True, adds a learnable bias to the output.
|
||||||
|
|
||||||
|
**Returns**:
|
||||||
|
ConvTranspose2d object with specified attributes.
|
||||||
|
"""
|
||||||
return ConvTranspose2d(
|
return ConvTranspose2d(
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
@ -156,6 +268,21 @@ def ConvTranspose1d(
|
||||||
|
|
||||||
|
|
||||||
class ConvTranspose2d(Conv2d):
|
class ConvTranspose2d(Conv2d):
|
||||||
|
"""
|
||||||
|
This class defines a 2D transposed convolution layer, often used in image segmentation tasks.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
in_channels (int): The number of input channels.
|
||||||
|
out_channels (int): The number of output channels.
|
||||||
|
kernel_size (tuple): The size of the convolving kernel.
|
||||||
|
stride (int): Stride of the convolution. Defaults to 1.
|
||||||
|
padding (int): Zero-padding added to both sides of the input. Defaults to 0.
|
||||||
|
output_padding (int): Additional size added to one side of the output shape. Defaults to 0.
|
||||||
|
dilation (int): Spacing between kernel elements. Defaults to 1.
|
||||||
|
groups (int): Number of blocked connections from input channels to output channels. Defaults to 1.
|
||||||
|
bias (bool): If True, adds a learnable bias to the output. Defaults to True.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels,
|
in_channels,
|
||||||
|
@ -181,6 +308,17 @@ class ConvTranspose2d(Conv2d):
|
||||||
self.output_padding = output_padding
|
self.output_padding = output_padding
|
||||||
|
|
||||||
def __call__(self, x: Tensor):
|
def __call__(self, x: Tensor):
|
||||||
|
"""
|
||||||
|
Forward pass method for the ConvTranspose2d layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor of shape (batch_size, in_channels, height, width).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor of shape (batch_size, out_channels, height', width'), where height' and width' are computed by:
|
||||||
|
output_height = (height - 1) * stride + kernel_size - 2 * padding + output_padding
|
||||||
|
output_width = (width - 1) * stride + kernel_size - 2 * padding + output_padding
|
||||||
|
"""
|
||||||
return x.conv_transpose2d(
|
return x.conv_transpose2d(
|
||||||
self.weight,
|
self.weight,
|
||||||
self.bias,
|
self.bias,
|
||||||
|
@ -192,13 +330,43 @@ class ConvTranspose2d(Conv2d):
|
||||||
)
|
)
|
||||||
|
|
||||||
def initialize_weight(self, out_channels, in_channels, groups):
|
def initialize_weight(self, out_channels, in_channels, groups):
|
||||||
|
"""
|
||||||
|
Initializes the weight tensor using the Kaiming uniform initialization method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
out_channels (int): The number of output channels.
|
||||||
|
in_channels (int): The number of input channels.
|
||||||
|
groups (int): Number of blocked connections from input channels to output channels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Initialized weight tensor with shape (out_channels, in_channels // groups, *kernel_size).
|
||||||
|
"""
|
||||||
return Tensor.kaiming_uniform(
|
return Tensor.kaiming_uniform(
|
||||||
in_channels, out_channels // groups, *self.kernel_size, a=math.sqrt(5)
|
in_channels, out_channels // groups, *self.kernel_size, a=math.sqrt(5)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Linear:
|
class Linear:
|
||||||
|
"""
|
||||||
|
A class representing a linear transformation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
in_features (int): The number of input features.
|
||||||
|
out_features (int): The number of output features.
|
||||||
|
bias (bool, optional): If True, adds a learnable bias to the output. Defaults to True.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, in_features, out_features, bias=True):
|
def __init__(self, in_features, out_features, bias=True):
|
||||||
|
"""
|
||||||
|
Initializes the Linear class with the given input and output features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_features (int): The number of input features.
|
||||||
|
out_features (int): The number of output features.
|
||||||
|
bias (bool, optional): If True, adds a learnable bias to the output. Defaults to True.
|
||||||
|
|
||||||
|
"""
|
||||||
self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
|
self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5))
|
||||||
# TODO: remove this once we can represent Tensor with int shape in typing
|
# TODO: remove this once we can represent Tensor with int shape in typing
|
||||||
assert isinstance(self.weight.shape[1], int), "does not support symbolic shape"
|
assert isinstance(self.weight.shape[1], int), "does not support symbolic shape"
|
||||||
|
@ -208,10 +376,30 @@ class Linear:
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x: Tensor):
|
def __call__(self, x: Tensor):
|
||||||
|
"""
|
||||||
|
Applies the linear transformation to the input tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): The input tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: The output tensor after applying the linear transformation.
|
||||||
|
|
||||||
|
"""
|
||||||
return x.linear(self.weight.transpose(), self.bias)
|
return x.linear(self.weight.transpose(), self.bias)
|
||||||
|
|
||||||
|
|
||||||
class GroupNorm:
|
class GroupNorm:
|
||||||
|
"""
|
||||||
|
Apply group normalization to the input tensor.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
num_groups (int): The number of groups to separate the channels into.
|
||||||
|
num_channels (int): The number of channels in the input tensor.
|
||||||
|
eps (float): A small constant added to the variance to prevent division by zero. Default is 1e-5.
|
||||||
|
affine (bool): A boolean value that when set to True, this module has learnable affine parameters. Default is True.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True
|
self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True
|
||||||
):
|
):
|
||||||
|
@ -220,6 +408,15 @@ class GroupNorm:
|
||||||
self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
|
self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
|
||||||
|
|
||||||
def __call__(self, x: Tensor):
|
def __call__(self, x: Tensor):
|
||||||
|
"""
|
||||||
|
Normalize the input tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): The input tensor of shape [batch_size, num_channels, ...].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: The normalized tensor with the same shape as the input tensor.
|
||||||
|
"""
|
||||||
# reshape for layernorm to work as group norm
|
# reshape for layernorm to work as group norm
|
||||||
# subtract mean and divide stddev
|
# subtract mean and divide stddev
|
||||||
x = (
|
x = (
|
||||||
|
@ -237,12 +434,30 @@ class GroupNorm:
|
||||||
|
|
||||||
|
|
||||||
class InstanceNorm:
|
class InstanceNorm:
|
||||||
|
"""
|
||||||
|
Class that implements Instance Normalization for tensors.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
num_features (int): The number of features in the tensor.
|
||||||
|
eps (float): A small constant added to the standard deviation to avoid division by zero. Default is 1e-5.
|
||||||
|
affine (bool): If True, apply learned scale and shift parameters. Default is True.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, num_features: int, eps: float = 1e-5, affine: bool = True):
|
def __init__(self, num_features: int, eps: float = 1e-5, affine: bool = True):
|
||||||
self.num_features, self.eps = num_features, eps
|
self.num_features, self.eps = num_features, eps
|
||||||
self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
|
self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
|
||||||
self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None
|
self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None
|
||||||
|
|
||||||
def __call__(self, x: Tensor):
|
def __call__(self, x: Tensor):
|
||||||
|
"""
|
||||||
|
Normalize the input tensor using instance normalization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): The input tensor to be normalized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: The normalized tensor. If affine is False, returns the original tensor after normalizing its channels.
|
||||||
|
"""
|
||||||
x = (
|
x = (
|
||||||
x.reshape(x.shape[0], self.num_features, -1)
|
x.reshape(x.shape[0], self.num_features, -1)
|
||||||
.layernorm(eps=self.eps)
|
.layernorm(eps=self.eps)
|
||||||
|
@ -256,12 +471,29 @@ class InstanceNorm:
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm:
|
class LayerNorm:
|
||||||
|
"""
|
||||||
|
Implements layer normalization.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
normalized_shape (Union[int, Tuple[int, ...]]): The shape of the input tensor to be normalized.
|
||||||
|
eps (float): A small constant added to the variance to prevent division by zero. Default is 1e-5.
|
||||||
|
elementwise_affine (bool): If True, apply learned scale and shift parameters to the output. Default is True.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
normalized_shape: Union[int, Tuple[int, ...]],
|
normalized_shape: Union[int, Tuple[int, ...]],
|
||||||
eps: float = 1e-5,
|
eps: float = 1e-5,
|
||||||
elementwise_affine: bool = True,
|
elementwise_affine: bool = True,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initializes the layer normalization instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
normalized_shape (Union[int, Tuple[int, ...]]): The shape of the input tensor to be normalized.
|
||||||
|
eps (float): A small constant added to the variance to prevent division by zero. Default is 1e-5.
|
||||||
|
elementwise_affine (bool): If True, apply learned scale and shift parameters to the output. Default is True.
|
||||||
|
"""
|
||||||
self.normalized_shape = (
|
self.normalized_shape = (
|
||||||
(normalized_shape,)
|
(normalized_shape,)
|
||||||
if isinstance(normalized_shape, int)
|
if isinstance(normalized_shape, int)
|
||||||
|
@ -279,6 +511,19 @@ class LayerNorm:
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x: Tensor):
|
def __call__(self, x: Tensor):
|
||||||
|
"""
|
||||||
|
Normalizes the input tensor using layer normalization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): The input tensor to be normalized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: The normalized tensor. If elementwise_affine is True, then the output tensor is multiplied by
|
||||||
|
learned scale and added to the shift parameter.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the last dimensions of x do not match normalized_shape.
|
||||||
|
"""
|
||||||
assert (
|
assert (
|
||||||
self.normalized_shape == x.shape[-len(self.normalized_shape) :]
|
self.normalized_shape == x.shape[-len(self.normalized_shape) :]
|
||||||
), f"last dimensions of {x.shape} must match {self.normalized_shape}"
|
), f"last dimensions of {x.shape} must match {self.normalized_shape}"
|
||||||
|
@ -289,16 +534,72 @@ class LayerNorm:
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm2d(LayerNorm):
|
class LayerNorm2d(LayerNorm):
|
||||||
|
"""
|
||||||
|
LayerNorm2d class for 2D Layer Normalization.
|
||||||
|
|
||||||
|
This class is a subclass of the LayerNorm class and is used to normalize the input tensor x in the __call__ method.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
x (Tensor): The input tensor to be normalized.
|
||||||
|
"""
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
|
"""
|
||||||
|
Normalize the input tensor x.
|
||||||
|
|
||||||
|
This method is used to normalize the input tensor x by permuting its dimensions and calling the parent class's
|
||||||
|
__call__ method on the result. The dimensions are then permuted back for the output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): The input tensor to be normalized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: The normalized tensor after permuting its dimensions back.
|
||||||
|
"""
|
||||||
return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
|
||||||
class Embedding:
|
class Embedding:
|
||||||
|
"""
|
||||||
|
Embedding class for word embeddings.
|
||||||
|
|
||||||
|
This class is used to convert input indices into their corresponding weighted vectors. It has methods for initializing
|
||||||
|
the weights and performing the embedding lookups.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
vocab_size (int): The size of the vocabulary.
|
||||||
|
embed_size (int): The size of the embedding vectors.
|
||||||
|
weight (Tensor): The tensor containing the embedding weights.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, vocab_size: int, embed_size: int):
|
def __init__(self, vocab_size: int, embed_size: int):
|
||||||
|
"""
|
||||||
|
Initialize the Embedding object.
|
||||||
|
|
||||||
|
This method initializes the vocab_size and embed_size attributes, as well as the weight tensor that holds the
|
||||||
|
embedding weights. The weight tensor is initialized using Glorot uniform initialization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (int): The size of the vocabulary.
|
||||||
|
embed_size (int): The size of the embedding vectors.
|
||||||
|
"""
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
|
self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
|
||||||
|
|
||||||
def __call__(self, idx: Tensor) -> Tensor:
|
def __call__(self, idx: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Perform the embedding lookup for the input indices.
|
||||||
|
|
||||||
|
This method performs an embedding lookup on the input tensor idx by creating a binary matrix that has ones at
|
||||||
|
positions where the vocabulary index matches the corresponding index in idx, and zeros elsewhere. This matrix is
|
||||||
|
then used to select weight vectors from the weight tensor using matrix multiplication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx (Tensor): The input tensor of indices to be looked up in the embedding.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: The output tensor after performing the embedding lookup.
|
||||||
|
"""
|
||||||
if not hasattr(self, "vocab_counter"):
|
if not hasattr(self, "vocab_counter"):
|
||||||
self.vocab_counter = Tensor.arange(
|
self.vocab_counter = Tensor.arange(
|
||||||
self.vocab_size, requires_grad=False
|
self.vocab_size, requires_grad=False
|
||||||
|
|
Loading…
Reference in New Issue