1
0
Fork 0

simpler get_contraction (#2552)

* simpler get_contraction

* and test
pull/2556/head
chenyu 2023-12-01 18:02:52 -05:00 committed by GitHub
parent eb595588bb
commit e9426f4fe4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 28 deletions

View File

@ -599,19 +599,19 @@ class TestGetContraction(unittest.TestCase):
self.assertEqual(r, [[0], [1, 2], [3]])
r = get_contraction((1,2,3,1,4), (1,2,3,4))
self.assertEqual(r, [[0], [1], [2], [3, 4]])
self.assertEqual(r, [[], [0, 1], [2], [3, 4]])
r = get_contraction((1,2,3,1,4,1,1), (2,3,4))
self.assertEqual(r, [[0, 1], [2], [3, 4, 5, 6]])
r = get_contraction((1,2,3,4), (1,2,3*4))
self.assertEqual(r, [[0], [1], [2, 3]])
self.assertEqual(r, [[], [0, 1], [2, 3]])
r = get_contraction((1,2,3,4), (2,1,3,4))
self.assertEqual(r, [[0, 1], [], [2], [3]])
r = get_contraction((1,2,3,4), (1,1,2*3*4,1))
self.assertEqual(r, [[0], [], [1,2,3], []])
self.assertEqual(r, [[], [], [0,1,2,3], []])
r = get_contraction((2,1,3,4), (1,2,3,4))
self.assertEqual(r, [[], [0], [1, 2], [3]])
@ -626,7 +626,7 @@ class TestGetContraction(unittest.TestCase):
self.assertEqual(r, [[0, 1], [2], [3, 4, 5, 6]])
r = get_contraction((1,2,3,4), (1,2,3,4,1))
self.assertEqual(r, [[0], [1], [2], [3], []])
self.assertEqual(r, [[], [0, 1], [2], [3], []])
r = get_contraction((14,1,384,14,1,1,1,1), (1,14,384,14))
self.assertEqual(r, [[], [0], [1,2], [3,4,5,6,7]])
@ -642,22 +642,22 @@ class TestGetContraction(unittest.TestCase):
def test_contraction_ones(self):
r = get_contraction((1,), (1,1,1))
self.assertEqual(r, [[0], [], []])
self.assertEqual(r, [[], [], [0]])
r = get_contraction((1,1), (1,1,1))
self.assertEqual(r, [[0], [1], []])
self.assertEqual(r, [[], [], [0, 1]])
r = get_contraction((1,1,1,1), (1,))
self.assertEqual(r, [[0,1,2,3]])
r = get_contraction((1,1,1,1), (1,1))
self.assertEqual(r, [[0], [1,2,3]])
self.assertEqual(r, [[], [0,1,2,3]])
r = get_contraction((1,1,1,1), (1,1,1))
self.assertEqual(r, [[0], [1], [2,3]])
self.assertEqual(r, [[], [], [0,1,2,3]])
r = get_contraction((1,1,1,1), (1,1,1,1))
self.assertEqual(r, [[0], [1], [2], [3]])
self.assertEqual(r, [[], [], [], [0,1,2,3]])
class TestShapeTrackerSize(unittest.TestCase):
def test_simple_size(self):

View File

@ -1,6 +1,6 @@
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
import functools
import functools, itertools, operator
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast, Union, Iterable
from tinygrad.ops import MovementOps
@ -177,21 +177,7 @@ class ShapeTracker:
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
# TODO: if we remove movementops from lazy.py we can delete this
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
# Pre-allocate all groups.
axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
# Index for new_shape and axis_groups.
i: int = 0
old_shape_i: int = 0
while old_shape_i < len(old_shape):
# 1s exist in new_shape only will lead to empty axes group creations.
if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
if i < len(new_shape) - 1: i += 1
else:
axis_groups[i].append(old_shape_i)
axis_group_size = prod([old_shape[x] for x in axis_groups[i]])
# Move to next axes group if total size of all dimensions match.
if axis_group_size == new_shape[i]:
if i < len(new_shape) - 1: i += 1
elif axis_group_size > new_shape[i]: return None
old_shape_i += 1
return axis_groups
acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
except ValueError: return None
return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]