speedups from llama branch
parent
0b03216cc3
commit
22905dd657
|
@ -172,6 +172,14 @@ class ShapeTracker:
|
|||
assert all(isinstance(x, int) and x != 0 for x in new_shape), f"shape must be ints and can't contain 0 {new_shape}"
|
||||
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
|
||||
|
||||
# check if this reshape only adds or removes dimensions of size 1
|
||||
# NOTE: this fast path is optional, but it avoids most calls to the (expensive!) merge_views
|
||||
if tuple(x for x in self.shape if x != 1) == tuple(x for x in new_shape if x != 1):
|
||||
old_strides = [y for x,y in zip(self.shape, self.strides) if x != 1]
|
||||
new_strides_tuple = tuple(0 if x == 1 else old_strides.pop(0) for x in new_shape)
|
||||
self.views[-1] = View(new_shape, new_strides_tuple, self.offset)
|
||||
return self
|
||||
|
||||
view = View(new_shape, strides_for_shape(new_shape))
|
||||
if self.contiguous: self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
|
||||
else:
|
||||
|
|
|
@ -110,6 +110,7 @@ class Node:
|
|||
nodes.append(NumNode(sum([x.b for x in num_nodes])))
|
||||
|
||||
# combine any MulNodes that factorize (big hack: the remaining nodes are wrapped as MulNode(x, 1) so they can be grouped too)
|
||||
# TODO: this is slow!
|
||||
nodes, mul_nodes = partition(nodes, lambda x: not isinstance(x, MulNode))
|
||||
mul_nodes += [MulNode(x, 1) for x in nodes]
|
||||
mul_nodes = sorted(mul_nodes, key=lambda x: x.a.render()) # group by equality (ugh, uses render!)
|
||||
|
|
Loading…
Reference in New Issue