# this will be the new test_ops for the next level
# schedule confirms the right things are capable of fusing

# NOTE: this has overlap with external_test_opt.py
|
import unittest
from typing import List, Optional
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.device import Device, Compiled
from tinygrad.helpers import DEBUG, dtypes
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.graph import log_schedule_item, print_tree
from tinygrad import nn
def check_schedule(
    t: Tensor,
    allowed: int,
    to_prerealize: Optional[List[Tensor]] = None,
    filter_loadops=True,
):
    """Schedule `t` and assert that exactly `allowed` kernels are produced.

    Tensors in `to_prerealize` are scheduled first and their outputs marked as
    already seen, so only the work unique to `t` is counted.  When
    `filter_loadops` is true, LoadOps items are dropped before counting.
    Every surviving (non-LoadOps) AST is also pushed through the Linearizer
    as a smoke test that it can actually be lowered.
    """
    prerealized = set()
    for pre in to_prerealize or []:
        # NOTE: schedule() receives a copy so it cannot mutate the shared seen-set
        for item in pre.lazydata.schedule(prerealized.copy()):
            log_schedule_item(item)
            prerealized.add(item.out)
    sched = t.lazydata.schedule(prerealized)
    for item in sched:
        log_schedule_item(item)
    if filter_loadops:
        sched = [item for item in sched if item.ast.op not in LoadOps]
    if len(sched) != allowed:
        print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
    if len(sched) != allowed or DEBUG >= 3:
        # dump the offending (or, under DEBUG, all) kernel ASTs for inspection
        for i, item in enumerate(sched):
            print("op", i)
            print_tree(item.ast)
    assert len(sched) == allowed
    # the remaining (non loadops) ops must also survive linearization
    for item in sched:
        if item.ast.op in LoadOps:
            continue
        lin = Linearizer(item.ast)
        lin.hand_coded_optimizations()
        lin.linearize()
class TestSchedule(unittest.TestCase):
    """Each test builds a small lazy graph and asserts, via check_schedule,
    exactly how many kernels the scheduler emits for it — i.e. which ops are
    expected to fuse into a single kernel."""

    def test_basic_binop_fusion(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = Tensor.empty(10)
        d = a + b + c
        # both adds fuse into one elementwise kernel
        check_schedule(d, 1)

    def test_basic_binop_fusion_deep(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = Tensor.empty(10)
        d = Tensor.empty(10)
        e = a + b + c + d
        check_schedule(e, 1)

    def test_mulacc_fusion(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = (a * b).sum()
        # mul + reduce fuse (multiply-accumulate in one kernel)
        check_schedule(c, 1)

    def test_mulacc_relu_fusion(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = (a * b).sum().relu()
        check_schedule(c, 1)

    def test_binop_reshape_fusion(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = Tensor.empty(5, 2)
        d = (a + b).reshape(5, 2) + c
        check_schedule(d, 1)

    def test_binop_permute_fusion(self):
        a = Tensor.empty(2, 5)
        b = Tensor.empty(2, 5)
        c = Tensor.empty(5, 2)
        d = (a + b).permute(1, 0) + c
        check_schedule(d, 1)

    @unittest.skipIf(
        not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM",
        "only test for compiled backends",
    )
    def test_constants_are_embedded(self):
        a = Tensor.empty(3, 3) * 2
        # with loadops kept, only 2 items: the constant 2 is baked into the
        # kernel source rather than scheduled as its own load
        check_schedule(a, 2, filter_loadops=False)

    def test_binop_elu_fusion(self):
        a = Tensor.empty(10)
        b = a.elu()
        check_schedule(b, 1)

    def test_binop_reshape_reduce_fusion(self):
        a = Tensor.empty(100)
        b = Tensor.empty(100)
        c = (a + b).reshape(10, 10).sum(axis=0, keepdim=True)
        check_schedule(c, 1)

    def test_reduce_reshape_binop_fusion(self):
        a = Tensor.empty(10, 10)
        b = Tensor.empty(10)
        c = a.sum(axis=0) + b
        check_schedule(c, 1)

    @unittest.skip("not pushing permutes through reduces")
    def test_reduce_permute_binop_fusion(self):
        a = Tensor.empty(10, 10, 10)
        b = Tensor.empty(10, 10, 1)
        c = a.sum(axis=0, keepdim=True).permute(2, 1, 0) + b
        check_schedule(c, 1)

    def test_binop_early_reshape_reduce_fusion(self):
        a = Tensor.empty(100)
        b = Tensor.empty(100)
        c = Tensor.empty(10, 10)
        d = ((a + b).reshape(10, 10) + c).sum(axis=0)
        check_schedule(d, 1)

    def test_diamond_folded(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = Tensor.empty(10)
        d = Tensor.empty(10)
        # ab is consumed twice (a diamond) yet everything still folds to one kernel
        ab = a + b
        e = (ab + c) + (ab + d)
        check_schedule(e, 1)

    def test_cache_binaryop(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = a + b
        d = a + b
        # d is the same expression as the prerealized c -> nothing left to run
        check_schedule(d, 0, [c])

    @unittest.skip("failing in old lazy")
    def test_cache_binaryop_reshaped(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = a + b
        d = a.reshape(10, 1) + b.reshape(10, 1)
        check_schedule(d, 0, [c])

    def test_cache_binaryop_transpose(self):
        a = Tensor.empty(10, 10)
        b = Tensor.empty(10, 10)
        c = (a.T * b.T).T  # .contiguous()
        d = a * b
        check_schedule(d, 0, [c])

    def test_cache_two_reduceops(self):
        a = Tensor.empty(10)
        b = a.sum()
        c = a.sum()
        # the two identical sums dedupe; b + c is one kernel total
        bc = b + c
        check_schedule(bc, 1)

    def test_fold_double_unary(self):
        y = Tensor.empty(2)
        out = y.sum(keepdim=True).sqrt().__neg__()
        check_schedule(out, 1)

    # @unittest.skip("may want to reconsider this")
    def test_fold_batchnorm(self):
        with Tensor.train():
            img = Tensor.empty(1, 32, 4, 4)
            bn = nn.BatchNorm2d(32, track_running_stats=False)
            out = bn(img)
            # batchnorm in training mode folds to 3 kernels
            check_schedule(out, 3)

    def test_fold_conv_relu(self):
        c1 = nn.Conv2d(3, 16, 3)

        # run
        img = Tensor.ones(2, 3, 64, 64)
        out = c1(img).relu()
        # weight/bias are prerealized, so conv + bias add + relu is one kernel
        check_schedule(out, 1, [c1.weight, c1.bias])

    def test_fold_conv_elu(self):
        c1 = nn.Conv2d(3, 16, 3)

        # run
        img = Tensor.rand(2, 3, 64, 64)
        out = c1(img).elu()
        check_schedule(out, 1, [c1.weight, c1.bias])

    def test_two_sum(self):
        img = Tensor.empty(64, 64)
        x = img.sum(0) + img.sum(1)
        out = x.relu()
        del x  # is 3 without this
        check_schedule(out, 2)

    @unittest.skip("failing in old lazy")
    def test_push_permute_through_reshape(self):
        a = Tensor.empty(16, 16)
        b = Tensor.empty(16, 16)
        c = (a + b).reshape(4, 4, 4, 4).permute(2, 3, 0, 1).contiguous()
        check_schedule(c, 1)

    @unittest.skip("failing in old lazy")
    def test_push_permute_through_reshape_alt(self):
        a = Tensor.empty(4, 4, 4, 4)
        b = Tensor.empty(4, 4, 4, 4)
        c = (a + b).reshape(16, 16).permute(1, 0).contiguous()
        check_schedule(c, 1)

    def test_no_binop_rerun(self):
        a = Tensor.empty(16)
        b = Tensor.empty(16)
        c = a + b
        d = (a + b).reshape(16, 1)
        check_schedule(d, 0, [c])

    def test_multi_permute_should_collapse(self):
        a = Tensor.empty(4, 4, 4, 4)
        b = Tensor.empty(16)
        # a chain of movement ops (permute/reshape) must collapse into the one kernel
        c = (
            a.sum((0, 1))
            .cast(dtypes.float16)
            .permute(1, 0)
            .reshape(4, 4, 1)
            .permute(1, 0, 2)
            .reshape(16)
            + b
        )
        check_schedule(c, 1)

    @unittest.skip("failing in old lazy")
    def test_fancy_reshape_fusion(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = a + b
        d = a.reshape(10, 1) + b.reshape(10, 1)
        out = c.sum() + d.sum()
        check_schedule(out, 1)

    # NOTE: for this to pass, LazyViews must be children of LazyBuffers so the (a+b) runs first
    @unittest.skip("not real world")
    def test_children_dont_push(self):
        a = Tensor.empty(10, 10, 1)
        b = Tensor.empty(10, 10, 1)
        d = (a + b).expand(10, 10, 10)
        e = (a + b).permute(2, 1, 0)
        f = d + e
        check_schedule(f, 2)

    def test_dont_fuse_binops_with_children(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = Tensor.empty(10)
        keep_me = a + b
        e = (
            keep_me.sum()
        )  # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
        d = keep_me + c
        check_schedule(d, 2)
        check_schedule(keep_me, 0, [d])

    @unittest.skip("failing in old lazy")
    def test_permute_breaks_fusion(self):
        a = Tensor.empty(10, 10, 10)
        b = Tensor.empty(10, 10)
        c = (a.sum(axis=2) + b).permute(1, 0)
        d = c.permute(1, 0)
        check_schedule(d, 1)

    def test_some_permute_fusion(self):
        a = Tensor.empty(8192, 16)
        b = Tensor.empty(1, 16)
        d = a.T + b.expand(8192, 16).T
        c = a + b.expand(8192, 16)
        e = d.T
        check_schedule(c, 1)
        check_schedule(e, 1)

    # this is the failing case in openpilot...it's very simple like this
    @unittest.skip("failing in old lazy")
    def test_image_conv_fusion(self):
        from tinygrad.features.image import image_conv2d

        w1 = Tensor.empty(16, 16, 1, 1)
        b1 = Tensor.empty(16)
        w2 = Tensor.empty(16, 16, 1, 1)
        b2 = Tensor.empty(16)
        w3 = Tensor.empty(16, 16, 1, 1)
        b3 = Tensor.empty(16)

        x = Tensor.empty(1, 16, 32, 32)
        x = base = image_conv2d(x, w1, b1)
        x = image_conv2d(x, w2, b2) + base
        x = image_conv2d(x, w3, b3)

        # NOOP, 3 convs, contiguous
        check_schedule(x, 5)

    def test_image_conv_fusion_minimal(self):
        b1 = Tensor.empty(16)
        b2 = Tensor.empty(16)

        def p(x):
            # permute/contiguous/expand/sum round-trip standing in for the
            # conv's data movement (minimal repro of the image conv case)
            return (
                x.permute(1, 0)
                .contiguous()
                .reshape(32, 16, 1)
                .expand(32, 16, 16)
                .sum(axis=2)
                .permute(1, 0)
            )

        x = Tensor.empty(16, 32)
        x = base = p(x) + b1.reshape(16, 1)
        x = p(x)
        x = x + b2.reshape(16, 1)
        x = x + base
        del base
        x = p(x)
        check_schedule(x, 4)

    def test_image_conv_fusion_more_minimal(self):
        b1 = Tensor.empty(16)

        def p(x):
            return (
                x.permute(1, 0)
                .contiguous()
                .reshape(32, 16, 1)
                .expand(32, 16, 16)
                .sum(axis=2)
                .permute(1, 0)
            )

        x = Tensor.empty(16, 32)
        x = base = p(x) + b1.reshape(16, 1)
        x = p(x)
        del base
        check_schedule(x, 3)

    def test_resnet_block(self):
        from extra.models.resnet import BasicBlock

        Tensor.training = False
        bb = BasicBlock(64, 64)

        x = Tensor.empty(1, 64, 32, 32)
        out = bb(x)
        check_schedule(out, 4)

    def test_contiguous_while_contiguous(self):
        x = Tensor.empty(1, 64, 32, 32)
        # contiguous() on an already-contiguous tensor still schedules one item
        out = x.contiguous()
        check_schedule(out, 1, filter_loadops=False)

    def test_contiguous_while_not_contiguous(self):
        x = Tensor.empty(1, 64, 32, 32)
        out = x.permute(0, 2, 3, 1).contiguous()
        check_schedule(out, 2, filter_loadops=False)

    def test_double_from(self):
        x = Tensor([1, 2, 3, 4])
        # moving to the device the data is already on should schedule nothing
        out = x.to("cpu")
        check_schedule(out, 0, filter_loadops=False)

    def test_pow_const_tensor(self):
        x = Tensor([1, 2, 3, 4])
        out = x ** Tensor(2)
        check_schedule(out, 1)

    def test_zero_size(self):
        x = Tensor.rand(2, 3, 0)
        out = x + 1
        # zero-element tensors should schedule no work at all
        check_schedule(out, 0, filter_loadops=False)
# allow running this file directly: `python test_schedule.py`
if __name__ == "__main__":
    unittest.main(verbosity=2)