Reduce tensor dot line count and fixed 1d tensor dot (#1045)
* fixed tensor.dot * no 1d dot for image=1 * shorter lines * add 3d dot testspull/1046/head
parent
9c6e507518
commit
6ff720103e
|
@ -279,8 +279,16 @@ class TestOps(unittest.TestCase):
|
|||
return x*torch.tanh(torch.nn.functional.softplus(x))
|
||||
helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4)
|
||||
helper_test_op([()], _mish_pytorch, Tensor.mish, atol=1e-4)
|
||||
@unittest.skipIf(IMAGE>0, "no 1d dot for images")
|
||||
def test_dot_1d(self):
|
||||
helper_test_op([(65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(65), (65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(32,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(65), (32,65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
def test_dot(self):
|
||||
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
helper_test_op([(32,45,65), (32,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
with self.assertRaises(RuntimeError):
|
||||
a = Tensor(3.14)
|
||||
a.matmul(a)
|
||||
|
|
|
@ -469,10 +469,9 @@ class Tensor:
|
|||
|
||||
def dot(self, w:Tensor) -> Tensor:
|
||||
if (n1:=len(self.shape))*(n2:=len(w.shape)) == 0: raise RuntimeError(f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D")
|
||||
x = self.reshape(*self.shape[0:-1], 1, self.shape[-1])
|
||||
w = w.reshape(*w.shape[0:-2], 1, w.shape[-2], w.shape[-1]).transpose(-1, -2)
|
||||
r = (x*w).sum(-1)
|
||||
return r.reshape((*r.shape[:-2], r.shape[-1])) if len(self.shape) == 1 else r
|
||||
x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
|
||||
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
|
||||
return (x*w).sum(-1)
|
||||
|
||||
def cumsum(self, axis=0):
|
||||
x = self.permute(*(i for i in range(self.ndim) if i != axis), axis)
|
||||
|
|
Loading…
Reference in New Issue