1
0
Fork 0

Reduce tensor dot line count and fix 1d tensor dot (#1045)

* fixed tensor.dot

* no 1d dot for image=1

* shorter lines

* add 3d dot tests
pull/1046/head
Francesco Castelli 2023-06-25 19:32:45 +02:00 committed by GitHub
parent 9c6e507518
commit 6ff720103e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 4 deletions

View File

@ -279,8 +279,16 @@ class TestOps(unittest.TestCase):
return x*torch.tanh(torch.nn.functional.softplus(x))
helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4)
helper_test_op([()], _mish_pytorch, Tensor.mish, atol=1e-4)
@unittest.skipIf(IMAGE>0, "no 1d dot for images")
def test_dot_1d(self):
helper_test_op([(65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(65), (65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(32,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(65), (32,65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
def test_dot(self):
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(32,45,65), (32,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
with self.assertRaises(RuntimeError):
a = Tensor(3.14)
a.matmul(a)

View File

@ -469,10 +469,9 @@ class Tensor:
def dot(self, w:Tensor) -> Tensor:
if (n1:=len(self.shape))*(n2:=len(w.shape)) == 0: raise RuntimeError(f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D")
x = self.reshape(*self.shape[0:-1], 1, self.shape[-1])
w = w.reshape(*w.shape[0:-2], 1, w.shape[-2], w.shape[-1]).transpose(-1, -2)
r = (x*w).sum(-1)
return r.reshape((*r.shape[:-2], r.shape[-1])) if len(self.shape) == 1 else r
x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
return (x*w).sum(-1)
def cumsum(self, axis=0):
x = self.permute(*(i for i in range(self.ndim) if i != axis), axis)