torch.nested
============

.. automodule:: torch.nested

Introduction
++++++++++++

.. warning::

  The PyTorch API of nested tensors is in prototype stage and will change in the near future.

NestedTensor allows the user to pack a list of Tensors into a single, efficient data structure.
The only constraint on the input Tensors is that they must all have the same number of dimensions.
This enables more efficient metadata representations and access to purpose-built kernels.

One application of NestedTensors is to express sequential data in various domains.
While the conventional approach is to pad variable-length sequences, NestedTensor
enables users to bypass padding. The API for calling operations on a nested tensor is no different
from that of a regular ``torch.Tensor``, which should allow seamless integration with existing models,
with the main difference being :ref:`construction of the inputs <construction>`.

As this is a prototype feature, the :ref:`operations supported <supported operations>` are still
limited. However, we welcome issues, feature requests and contributions. More information on contributing can be found
`in this Readme <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/nested/README.md>`_.

.. _construction:

Construction
++++++++++++

Construction is straightforward and involves passing a list of Tensors to the ``torch.nested.nested_tensor``
constructor.

>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested.nested_tensor([a, b])
>>> nt
nested_tensor([
  tensor([0, 1, 2]),
  tensor([3, 4, 5, 6, 7])
])

Data type, device and whether gradients are required can be chosen via the usual keyword arguments.

>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32, device="cuda", requires_grad=True)
>>> nt
nested_tensor([
  tensor([0., 1., 2.], device='cuda:0', requires_grad=True),
  tensor([3., 4., 5., 6., 7.], device='cuda:0', requires_grad=True)
], device='cuda:0', requires_grad=True)

In the vein of ``torch.as_tensor``, ``torch.nested.as_nested_tensor`` can be used to preserve autograd
history from the tensors passed to the constructor. For more information, refer to the section on
:ref:`constructor functions`.

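The difference between the two constructors can be sketched as follows. This is a minimal illustration on CPU, assuming the default (strided) layout; the variable names are chosen for this example only.

```python
import torch

# Two leaf tensors that require gradients.
a = torch.arange(3, dtype=torch.float32, requires_grad=True)
b = torch.arange(5, dtype=torch.float32, requires_grad=True)

# nested_tensor copies the data and returns a new leaf with no link to a and b.
nt_leaf = torch.nested.nested_tensor([a, b], requires_grad=True)

# as_nested_tensor keeps a and b in the autograd graph.
nt_graph = torch.nested.as_nested_tensor([a, b])

print(nt_leaf.is_leaf)
print(nt_graph.is_leaf)

# Backpropagate a hand-made gradient that has the same nested structure.
grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)])
nt_graph.backward(grad)

# The gradient reaches the original inputs.
print(a.grad)
print(b.grad)
```

Here ``nt_leaf`` is expected to be a leaf, while ``nt_graph`` is not, so only the latter propagates gradients back to ``a`` and ``b``.
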
In order to form a valid NestedTensor, all the passed Tensors need to have the same number of dimensions, but none of their other attributes need to match.

>>> a = torch.randn(3, 50, 70) # image 1
>>> b = torch.randn(3, 128, 64) # image 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt.dim()
4

If the number of dimensions differs between the inputs, the constructor throws an error.

>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(3, 128, 64) # image 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: All Tensors given to nested_tensor must have the same dimension. Found dimension 3 for Tensor at index 1 and dimension 2 for Tensor at index 0.

Note that the passed Tensors are copied into a contiguous piece of memory. The resulting
NestedTensor allocates new memory to store them and does not keep a reference to the inputs.

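The copy semantics can be checked with a short sketch: mutate an input in place after construction and confirm that the nested tensor is unaffected. The variable names below are illustrative only.

```python
import torch

a = torch.arange(3, dtype=torch.float32)
b = torch.arange(5, dtype=torch.float32)
nt = torch.nested.nested_tensor([a, b])

# Mutate the original input in place after construction.
a.add_(10)

# The first constituent still holds the values that were copied in.
first = nt.unbind()[0]
unchanged = torch.equal(first, torch.tensor([0.0, 1.0, 2.0]))
print(unchanged)
```
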
At this moment we only support one level of nesting, i.e. a simple, flat list of Tensors. In the future
we can add support for multiple levels of nesting, such as a list that consists entirely of lists of Tensors.
Note that for this extension it is important to maintain an even level of nesting across entries so that the resulting NestedTensor
has a well-defined dimension. If you need this feature, please open a feature request so that
we can track it and plan accordingly.

size
+++++++++++++++++++++++++

Even though a NestedTensor does not support ``.size()`` (or ``.shape``), it supports ``.size(i)`` if dimension ``i`` is regular.
A dimension is regular when its size does not vary between constituents.

>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Given dimension 1 is irregular and does not have a size.
>>> nt.size(2)
128

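To inspect all dimensions at once without triggering the error above, the sizes can be derived from the constituents. The helper ``describe_sizes`` below is hypothetical (it is not part of ``torch.nested``) and is only a sketch built on ``unbind``.

```python
import torch

def describe_sizes(nt):
    """Hypothetical helper: per-dimension sizes, with None for irregular dimensions."""
    parts = nt.unbind()      # constituents along dim 0
    sizes = [len(parts)]     # dim 0 is the number of constituents
    for dim_sizes in zip(*(p.shape for p in parts)):
        # A dimension is regular when every constituent agrees on its size.
        sizes.append(dim_sizes[0] if len(set(dim_sizes)) == 1 else None)
    return sizes

a = torch.randn(50, 128)  # text 1
b = torch.randn(32, 128)  # text 2
nt = torch.nested.nested_tensor([a, b])

sizes = describe_sizes(nt)
print(sizes)  # [2, None, 128]
```
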
If all dimensions are regular, the NestedTensor is intended to be semantically indistinguishable from a regular ``torch.Tensor``.

>>> a = torch.randn(20, 128) # text 1
>>> nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
20
>>> nt.size(2)
128
>>> torch.stack(nt.unbind()).size()
torch.Size([2, 20, 128])
>>> torch.stack([a, a]).size()
torch.Size([2, 20, 128])
>>> torch.equal(torch.stack(nt.unbind()), torch.stack([a, a]))
True

In the future we might make it easier to detect this condition and convert seamlessly.

Please open a feature request if you have a need for this (or any other related feature for that matter).

unbind
+++++++++++++++++++++++++

``unbind`` allows you to retrieve a view of the constituents.

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt
nested_tensor([
  tensor([[ 1.2286, -1.2343, -1.4842],
          [-0.7827,  0.6745,  0.0658]]),
  tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
          [-0.2871, -0.2980,  0.5559,  1.9885],
          [ 0.4074,  2.4855,  0.0733,  0.8285]])
])
>>> nt.unbind()
(tensor([[ 1.2286, -1.2343, -1.4842],
        [-0.7827,  0.6745,  0.0658]]), tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
        [-0.2871, -0.2980,  0.5559,  1.9885],
        [ 0.4074,  2.4855,  0.0733,  0.8285]]))
>>> nt.unbind()[0] is not a
True
>>> nt.unbind()[0].mul_(3)
tensor([[ 3.6858, -3.7030, -4.4525],
        [-2.3481,  2.0236,  0.1975]])
>>> nt
nested_tensor([
  tensor([[ 3.6858, -3.7030, -4.4525],
          [-2.3481,  2.0236,  0.1975]]),
  tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
          [-0.2871, -0.2980,  0.5559,  1.9885],
          [ 0.4074,  2.4855,  0.0733,  0.8285]])
])

Note that ``nt.unbind()[0]`` is not a copy, but rather a slice of the underlying memory, which represents the first entry or constituent of the NestedTensor.

.. _constructor functions:

Nested tensor constructor and conversion functions
++++++++++++++++++++++++++++++++++++++++++++++++++

The following functions are related to nested tensors:

.. currentmodule:: torch.nested

.. autofunction:: nested_tensor
.. autofunction:: as_nested_tensor
.. autofunction:: to_padded_tensor

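As an illustration of the conversion direction, the sketch below turns a nested tensor back into a regular, padded ``torch.Tensor`` with ``to_padded_tensor``. The padding value ``0.0`` and the shapes are arbitrary choices for this example.

```python
import torch

a = torch.randn(50, 128)  # text 1
b = torch.randn(32, 128)  # text 2
nt = torch.nested.nested_tensor([a, b])

# Pad every constituent up to the largest size along each irregular dimension.
padded = torch.nested.to_padded_tensor(nt, 0.0)

print(padded.shape)                        # torch.Size([2, 50, 128])
print(torch.equal(padded[1, :32], b))      # the original rows of b are preserved
print(torch.count_nonzero(padded[1, 32:])) # rows beyond b's length hold only padding
```

The padded result is an ordinary dense tensor, so it can be passed to operations that do not yet accept nested inputs, at the cost of storing and computing on the padding.
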
.. _supported operations:

Supported operations
++++++++++++++++++++++++++

In this section, we summarize the operations that are currently supported on
NestedTensor and any constraints they have.

.. csv-table::
   :header: "PyTorch operation", "Constraints"
   :widths: 30, 55
   :delim: ;

   :func:`torch.matmul`; "Supports matrix multiplication between two (>= 3d) nested tensors where
   the last two dimensions are matrix dimensions and the leading (batch) dimensions have the same size
   (i.e. no broadcasting support for batch dimensions yet)."
   :func:`torch.bmm`; "Supports batch matrix multiplication of two 3-d nested tensors."
   :func:`torch.nn.Linear`; "Supports 3-d nested input and a dense 2-d weight matrix."
   :func:`torch.nn.functional.softmax`; "Supports softmax along all dims except dim=0."
   :func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors."
   :func:`torch.Tensor.masked_fill`; "Behavior is the same as on regular tensors."
   :func:`torch.relu`; "Behavior is the same as on regular tensors."
   :func:`torch.gelu`; "Behavior is the same as on regular tensors."
   :func:`torch.silu`; "Behavior is the same as on regular tensors."
   :func:`torch.abs`; "Behavior is the same as on regular tensors."
   :func:`torch.sgn`; "Behavior is the same as on regular tensors."
   :func:`torch.logical_not`; "Behavior is the same as on regular tensors."
   :func:`torch.neg`; "Behavior is the same as on regular tensors."
   :func:`torch.sub`; "Supports elementwise subtraction of two nested tensors."
   :func:`torch.add`; "Supports elementwise addition of two nested tensors. Supports addition of a scalar to a nested tensor."
   :func:`torch.mul`; "Supports elementwise multiplication of two nested tensors. Supports multiplication of a nested tensor by a scalar."
   :func:`torch.select`; "Supports selecting along all dimensions."
   :func:`torch.clone`; "Behavior is the same as on regular tensors."
   :func:`torch.detach`; "Behavior is the same as on regular tensors."
   :func:`torch.unbind`; "Supports unbinding along ``dim=0`` only."
   :func:`torch.reshape`; "Supports reshaping with size of ``dim=0`` preserved (i.e. number of tensors nested cannot be changed).
   Unlike regular tensors, a size of ``-1`` here means that the existing size is inherited.
   In particular, the only valid size for an irregular dimension is ``-1``.
   Size inference is not implemented yet and hence for new dimensions the size cannot be ``-1``."
   :func:`torch.Tensor.reshape_as`; "Similar constraint as for ``reshape``."
   :func:`torch.transpose`; "Supports transposing of all dims except ``dim=0``."
   :func:`torch.Tensor.view`; "Rules for the new shape are similar to those of ``reshape``."
   :func:`torch.empty_like`; "Behavior is analogous to that of regular tensors; returns a new empty nested tensor (i.e. with uninitialized values) matching the nested structure of the input."
   :func:`torch.randn_like`; "Behavior is analogous to that of regular tensors; returns a new nested tensor with values randomly initialized according to a standard normal distribution matching the nested structure of the input."
   :func:`torch.zeros_like`; "Behavior is analogous to that of regular tensors; returns a new nested tensor with all zero values matching the nested structure of the input."
   :func:`torch.nn.LayerNorm`; "The ``normalized_shape`` argument is restricted to not extend into the irregular dimensions of the NestedTensor."

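A few of the operations listed above can be combined as in the following sketch, which applies ``torch.nn.Linear`` and a softmax over the last dimension to a 3-d nested tensor and compares the result against processing each sequence separately. The layer sizes and the tolerance are arbitrary choices for this example, and it assumes a CPU build in which these nested kernels are available.

```python
import torch

torch.manual_seed(0)

a = torch.randn(3, 8)  # sequence of length 3
b = torch.randn(5, 8)  # sequence of length 5
nt = torch.nested.nested_tensor([a, b])  # 3-d nested tensor, no padding

linear = torch.nn.Linear(8, 4)  # dense 2-d weight matrix

with torch.no_grad():
    # Nested path: one call handles both sequences.
    out = torch.softmax(linear(nt), dim=-1)
    # Reference path: process each sequence on its own.
    expected = [torch.softmax(linear(t), dim=-1) for t in (a, b)]

matches = all(
    torch.allclose(o, e, atol=1e-5) for o, e in zip(out.unbind(), expected)
)
print(out.is_nested, matches)
```
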