78 lines
2.4 KiB
Python
78 lines
2.4 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
import torch
|
|
|
|
from . import _Tensor, Tensor
|
|
from .reference import _dims, _enable_layers, llist, ltuple
|
|
|
|
|
|
class DelayedMulTensor(_Tensor):
    """Lazily evaluated elementwise product of two operands.

    The two operands are stored as-is; the product is only materialized when
    ``_batchtensor`` or ``_tensor`` (or anything built on them, such as
    ``ndim``) is read.  ``sum`` never materializes the product: it contracts
    the two operands directly with ``torch.einsum``.
    """

    def __init__(self, lhs, rhs):
        """Store the operands and reset every lazily computed cache.

        Args:
            lhs: left operand; must expose ``_has_device``, ``_levels``,
                ``_batchtensor`` and ``_tensor``.
            rhs: right operand with the same attributes as ``lhs``.
        """
        self._lhs, self._rhs = lhs, rhs
        # NOTE(review): ``_data`` is not read anywhere in this class; it is
        # kept because code outside this file may rely on it -- confirm.
        self._data = None
        self._levels_data = None
        self._has_device = lhs._has_device or rhs._has_device
        self._batchtensor_data = None
        self._tensor_data = None

    @property
    def _levels(self):
        """Levels of ``lhs`` followed by the levels of ``rhs`` not already present.

        The result is computed once and cached in ``_levels_data``.
        """
        if self._levels_data is None:
            levels = llist(self._lhs._levels)
            # Membership is tested through ``llist`` on purpose: its
            # comparison semantics are defined in ``.reference``, so this must
            # not be rewritten in terms of a set or dict.
            for level in self._rhs._levels:
                if level not in levels:
                    levels.append(level)
            self._levels_data = ltuple(levels)
        return self._levels_data

    @property
    def _batchtensor(self):
        """Materialized product ``lhs._batchtensor * rhs._batchtensor`` (cached).

        The multiplication runs inside ``_enable_layers(self._levels)``.
        """
        if self._batchtensor_data is None:
            with _enable_layers(self._levels):
                # NOTE(review): diagnostic written to stdout every time the
                # product has to be materialized.  Left untouched because
                # existing callers/tests may depend on the output; consider
                # routing it through ``logging`` or ``warnings`` instead.
                print("bt multiply fallback")
                self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor
        return self._batchtensor_data

    @property
    def _tensor(self):
        """``_tensor`` of ``Tensor.from_batched(self._batchtensor, self._has_device)`` (cached).

        Reading this property materializes the product via ``_batchtensor``.
        """
        if self._tensor_data is None:
            self._tensor_data = Tensor.from_batched(
                self._batchtensor, self._has_device
            )._tensor
        return self._tensor_data

    @property
    def ndim(self):
        """``ndim`` of the materialized product (forces ``_batchtensor``)."""
        return self._batchtensor.ndim

    @property
    def dims(self):
        """The base class's ``dims`` wrapped in an ``ltuple``."""
        return ltuple(super().dims)

    def sum(self, dim):
        """Sum the product over ``dim`` without materializing the product.

        Each level in ``self._levels`` is assigned one einsum subscript
        letter; the operands are then contracted with ``torch.einsum`` and
        every level listed in ``dim`` is dropped from the output.

        Args:
            dim: the dimension(s) to reduce.  Passed through
                ``_dims(dim, 0, False, False)``; presumably that normalizes it
                to a collection of dims -- see ``_dims`` in ``.reference``.

        Returns:
            ``Tensor.from_positional(result, remaining_levels, True)``.

        Raises:
            RuntimeError: if more than 52 distinct levels are involved, which
                exceeds the ``a-zA-Z`` subscript alphabet of ``torch.einsum``.
        """
        dims = _dims(dim, 0, False, False)
        all_levels = self._levels
        # torch.einsum accepts subscripts in [a-zA-Z].  The first 26 labels
        # match the previous ``chr(ord("a") + i)`` scheme exactly; the
        # uppercase tail extends support from 26 to 52 levels.
        labels = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

        def to_char(d):
            position = all_levels.index(d)
            if position >= len(labels):
                raise RuntimeError(
                    "DelayedMulTensor.sum supports at most "
                    f"{len(labels)} distinct levels for einsum"
                )
            return labels[position]

        plhs, levelslhs = self._lhs._tensor, self._lhs._levels
        prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
        # Levels that survive the reduction, in their original order.
        new_levels = [level for level in all_levels if level not in dims]
        fmt = "".join(
            [
                *(to_char(d) for d in levelslhs),
                ",",
                *(to_char(d) for d in levelsrhs),
                "->",
                *(to_char(d) for d in new_levels),
            ]
        )
        result_data = torch.einsum(fmt, (plhs, prhs))
        # NOTE(review): the final ``True`` is passed positionally; it looks
        # like a has-device flag -- confirm against ``Tensor.from_positional``.
        return Tensor.from_positional(result_data, new_levels, True)
|