parent
eb595588bb
commit
e9426f4fe4
|
@ -599,19 +599,19 @@ class TestGetContraction(unittest.TestCase):
|
|||
self.assertEqual(r, [[0], [1, 2], [3]])
|
||||
|
||||
r = get_contraction((1,2,3,1,4), (1,2,3,4))
|
||||
self.assertEqual(r, [[0], [1], [2], [3, 4]])
|
||||
self.assertEqual(r, [[], [0, 1], [2], [3, 4]])
|
||||
|
||||
r = get_contraction((1,2,3,1,4,1,1), (2,3,4))
|
||||
self.assertEqual(r, [[0, 1], [2], [3, 4, 5, 6]])
|
||||
|
||||
r = get_contraction((1,2,3,4), (1,2,3*4))
|
||||
self.assertEqual(r, [[0], [1], [2, 3]])
|
||||
self.assertEqual(r, [[], [0, 1], [2, 3]])
|
||||
|
||||
r = get_contraction((1,2,3,4), (2,1,3,4))
|
||||
self.assertEqual(r, [[0, 1], [], [2], [3]])
|
||||
|
||||
r = get_contraction((1,2,3,4), (1,1,2*3*4,1))
|
||||
self.assertEqual(r, [[0], [], [1,2,3], []])
|
||||
self.assertEqual(r, [[], [], [0,1,2,3], []])
|
||||
|
||||
r = get_contraction((2,1,3,4), (1,2,3,4))
|
||||
self.assertEqual(r, [[], [0], [1, 2], [3]])
|
||||
|
@ -626,7 +626,7 @@ class TestGetContraction(unittest.TestCase):
|
|||
self.assertEqual(r, [[0, 1], [2], [3, 4, 5, 6]])
|
||||
|
||||
r = get_contraction((1,2,3,4), (1,2,3,4,1))
|
||||
self.assertEqual(r, [[0], [1], [2], [3], []])
|
||||
self.assertEqual(r, [[], [0, 1], [2], [3], []])
|
||||
|
||||
r = get_contraction((14,1,384,14,1,1,1,1), (1,14,384,14))
|
||||
self.assertEqual(r, [[], [0], [1,2], [3,4,5,6,7]])
|
||||
|
@ -642,22 +642,22 @@ class TestGetContraction(unittest.TestCase):
|
|||
|
||||
def test_contraction_ones(self):
|
||||
r = get_contraction((1,), (1,1,1))
|
||||
self.assertEqual(r, [[0], [], []])
|
||||
self.assertEqual(r, [[], [], [0]])
|
||||
|
||||
r = get_contraction((1,1), (1,1,1))
|
||||
self.assertEqual(r, [[0], [1], []])
|
||||
self.assertEqual(r, [[], [], [0, 1]])
|
||||
|
||||
r = get_contraction((1,1,1,1), (1,))
|
||||
self.assertEqual(r, [[0,1,2,3]])
|
||||
|
||||
r = get_contraction((1,1,1,1), (1,1))
|
||||
self.assertEqual(r, [[0], [1,2,3]])
|
||||
self.assertEqual(r, [[], [0,1,2,3]])
|
||||
|
||||
r = get_contraction((1,1,1,1), (1,1,1))
|
||||
self.assertEqual(r, [[0], [1], [2,3]])
|
||||
self.assertEqual(r, [[], [], [0,1,2,3]])
|
||||
|
||||
r = get_contraction((1,1,1,1), (1,1,1,1))
|
||||
self.assertEqual(r, [[0], [1], [2], [3]])
|
||||
self.assertEqual(r, [[], [], [], [0,1,2,3]])
|
||||
|
||||
class TestShapeTrackerSize(unittest.TestCase):
|
||||
def test_simple_size(self):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import functools, itertools, operator
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, cast, Union, Iterable
|
||||
from tinygrad.ops import MovementOps
|
||||
|
@ -177,21 +177,7 @@ class ShapeTracker:
|
|||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
||||
# TODO: if we remove movementops from lazy.py we can delete this
|
||||
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
|
||||
# Pre-allocate all groups.
|
||||
axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
|
||||
# Index for new_shape and axis_groups.
|
||||
i: int = 0
|
||||
old_shape_i: int = 0
|
||||
while old_shape_i < len(old_shape):
|
||||
# 1s exist in new_shape only will lead to empty axes group creations.
|
||||
if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
|
||||
if i < len(new_shape) - 1: i += 1
|
||||
else:
|
||||
axis_groups[i].append(old_shape_i)
|
||||
axis_group_size = prod([old_shape[x] for x in axis_groups[i]])
|
||||
# Move to next axes group if total size of all dimensions match.
|
||||
if axis_group_size == new_shape[i]:
|
||||
if i < len(new_shape) - 1: i += 1
|
||||
elif axis_group_size > new_shape[i]: return None
|
||||
old_shape_i += 1
|
||||
return axis_groups
|
||||
acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
|
||||
try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
|
||||
except ValueError: return None
|
||||
return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
|
||||
|
|
Loading…
Reference in New Issue