1
0
Fork 0

speedups from llama branch

pull/682/head
George Hotz 2023-03-10 22:01:32 -08:00
parent 0b03216cc3
commit 22905dd657
2 changed files with 9 additions and 0 deletions

View File

@ -172,6 +172,14 @@ class ShapeTracker:
assert all(isinstance(x, int) and x != 0 for x in new_shape), f"shape must be ints and can't contain 0 {new_shape}"
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
# check if this is adding or removing 1s (only)
# NOTE: this is optional, but removes most calls to (expensive!) merge_views
if tuple(x for x in self.shape if x != 1) == tuple(x for x in new_shape if x != 1):
old_strides = [y for x,y in zip(self.shape, self.strides) if x != 1]
new_strides_tuple = tuple(0 if x == 1 else old_strides.pop(0) for x in new_shape)
self.views[-1] = View(new_shape, new_strides_tuple, self.offset)
return self
view = View(new_shape, strides_for_shape(new_shape))
if self.contiguous: self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
else:

View File

@ -110,6 +110,7 @@ class Node:
nodes.append(NumNode(sum([x.b for x in num_nodes])))
# combine any MulNodes that factorize (big hack sticking the MulNode(x, 1) on things)
# TODO: this is slow!
nodes, mul_nodes = partition(nodes, lambda x: not isinstance(x, MulNode))
mul_nodes += [MulNode(x, 1) for x in nodes]
mul_nodes = sorted(mul_nodes, key=lambda x: x.a.render()) # group by equality (ugh, uses render!)