1
0
Fork 0
tinygrab/test/test_schedule.py

387 lines
11 KiB
Python

# This will be the new test_ops for the next level.
# Scheduler tests: confirm that the ops which should fuse actually do, by
# counting the schedule items each expression produces.
# NOTE: this has overlap with external_test_opt.py
import unittest
from typing import List, Optional
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.device import Device, Compiled
from tinygrad.helpers import DEBUG, dtypes
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.graph import log_schedule_item, print_tree
from tinygrad import nn
def check_schedule(
    t: Tensor,
    allowed: int,
    to_prerealize: Optional[List[Tensor]] = None,
    filter_loadops=True,
):
    """Assert that scheduling ``t`` produces exactly ``allowed`` schedule items.

    Args:
        t: tensor whose lazy graph is scheduled.
        allowed: expected number of schedule items (counted after the optional
            LoadOps filtering below).
        to_prerealize: tensors scheduled first; the ``out`` of each of their
            schedule items is added to the ``seen`` set that is then passed to
            the schedule of ``t``.
        filter_loadops: when True, items whose AST op is in ``LoadOps`` are
            dropped before counting.

    Every item is logged with ``log_schedule_item``. On a count mismatch (or at
    DEBUG >= 3) each counted item's AST is printed. Afterwards every counted
    non-LoadOps item is run through the Linearizer (hand-coded optimizations,
    then ``linearize``) to check that it lowers without raising.
    """
    already_done = set()
    for pre in to_prerealize or []:
        # Each pre-realized tensor is scheduled against a snapshot of the set,
        # so its schedule call cannot mutate ``already_done`` directly.
        pre_items = pre.lazydata.schedule(already_done.copy())
        for item in pre_items:
            log_schedule_item(item)
            already_done.add(item.out)

    sched = t.lazydata.schedule(already_done)
    for item in sched:
        log_schedule_item(item)
    if filter_loadops:
        sched = [item for item in sched if item.ast.op not in LoadOps]

    count = len(sched)
    mismatch = count != allowed
    if mismatch:
        print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
    if mismatch or DEBUG >= 3:
        for idx, item in enumerate(sched):
            print("op", idx)
            print_tree(item.ast)
    assert count == allowed

    # Lowering check: every counted non-LoadOps item must survive the Linearizer.
    for item in sched:
        if item.ast.op not in LoadOps:
            lin = Linearizer(item.ast)
            lin.hand_coded_optimizations()
            lin.linearize()
class TestSchedule(unittest.TestCase):
    """Schedule-count tests for the lazy scheduler.

    Each test builds a lazy Tensor graph (mostly from ``Tensor.empty``) and calls
    ``check_schedule`` to assert how many schedule items it produces. Unless
    ``filter_loadops=False`` is passed, LoadOps items are excluded from the count,
    so an expected count of 1 means the expression became a single kernel.

    NOTE: several tests depend on which intermediate tensors are still referenced
    (see the explicit ``del`` statements and ``keep_me``/``e`` below); do not
    introduce or drop intermediate variables casually.
    """

    def test_basic_binop_fusion(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = Tensor.empty(10)
        d = a + b + c
        check_schedule(d, 1)

    def test_basic_binop_fusion_deep(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = Tensor.empty(10)
        d = Tensor.empty(10)
        e = a + b + c + d
        check_schedule(e, 1)

    def test_mulacc_fusion(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = (a * b).sum()
        check_schedule(c, 1)

    def test_mulacc_relu_fusion(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = (a * b).sum().relu()
        check_schedule(c, 1)

    def test_binop_reshape_fusion(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = Tensor.empty(5, 2)
        d = (a + b).reshape(5, 2) + c
        check_schedule(d, 1)

    def test_binop_permute_fusion(self):
        a = Tensor.empty(2, 5)
        b = Tensor.empty(2, 5)
        c = Tensor.empty(5, 2)
        d = (a + b).permute(1, 0) + c
        check_schedule(d, 1)

    @unittest.skipIf(
        not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM",
        "only test for compiled backends",
    )
    def test_constants_are_embedded(self):
        # LoadOps are counted here (filter_loadops=False). Presumably the 2 items
        # are the EMPTY buffer plus the multiply kernel, and a constant that was
        # not embedded would add an extra item -- TODO confirm.
        a = Tensor.empty(3, 3) * 2
        check_schedule(a, 2, filter_loadops=False)

    def test_binop_elu_fusion(self):
        a = Tensor.empty(10)
        b = a.elu()
        check_schedule(b, 1)

    def test_binop_reshape_reduce_fusion(self):
        a = Tensor.empty(100)
        b = Tensor.empty(100)
        c = (a + b).reshape(10, 10).sum(axis=0, keepdim=True)
        check_schedule(c, 1)

    def test_reduce_reshape_binop_fusion(self):
        a = Tensor.empty(10, 10)
        b = Tensor.empty(10)
        c = a.sum(axis=0) + b
        check_schedule(c, 1)

    @unittest.skip("not pushing permutes through reduces")
    def test_reduce_permute_binop_fusion(self):
        a = Tensor.empty(10, 10, 10)
        b = Tensor.empty(10, 10, 1)
        c = a.sum(axis=0, keepdim=True).permute(2, 1, 0) + b
        check_schedule(c, 1)

    def test_binop_early_reshape_reduce_fusion(self):
        a = Tensor.empty(100)
        b = Tensor.empty(100)
        c = Tensor.empty(10, 10)
        d = ((a + b).reshape(10, 10) + c).sum(axis=0)
        check_schedule(d, 1)

    def test_diamond_folded(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = Tensor.empty(10)
        d = Tensor.empty(10)
        ab = a + b
        e = (ab + c) + (ab + d)
        check_schedule(e, 1)

    def test_cache_binaryop(self):
        # d repeats the same a + b as c; once c is scheduled, d needs no new items.
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = a + b
        d = a + b
        check_schedule(d, 0, [c])

    @unittest.skip("failing in old lazy")
    def test_cache_binaryop_reshaped(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = a + b
        d = a.reshape(10, 1) + b.reshape(10, 1)
        check_schedule(d, 0, [c])

    def test_cache_binaryop_transpose(self):
        a = Tensor.empty(10, 10)
        b = Tensor.empty(10, 10)
        c = (a.T * b.T).T  # .contiguous()
        d = a * b
        check_schedule(d, 0, [c])

    def test_cache_two_reduceops(self):
        a = Tensor.empty(10)
        b = a.sum()
        c = a.sum()
        bc = b + c
        check_schedule(bc, 1)

    def test_fold_double_unary(self):
        y = Tensor.empty(2)
        out = y.sum(keepdim=True).sqrt().__neg__()
        check_schedule(out, 1)

    # @unittest.skip("may want to reconsider this")
    def test_fold_batchnorm(self):
        # Tensor.train() scopes training mode to this block and restores it on exit.
        with Tensor.train():
            img = Tensor.empty(1, 32, 4, 4)
            bn = nn.BatchNorm2d(32, track_running_stats=False)
            out = bn(img)
            check_schedule(out, 3)

    def test_fold_conv_relu(self):
        c1 = nn.Conv2d(3, 16, 3)

        # weight and bias are pre-realized, so only the conv+relu itself is counted
        img = Tensor.ones(2, 3, 64, 64)
        out = c1(img).relu()
        check_schedule(out, 1, [c1.weight, c1.bias])

    def test_fold_conv_elu(self):
        c1 = nn.Conv2d(3, 16, 3)

        # weight and bias are pre-realized, so only the conv+elu itself is counted
        img = Tensor.rand(2, 3, 64, 64)
        out = c1(img).elu()
        check_schedule(out, 1, [c1.weight, c1.bias])

    def test_two_sum(self):
        img = Tensor.empty(64, 64)
        x = img.sum(0) + img.sum(1)
        out = x.relu()
        del x  # the count is 3 instead of 2 without this del
        check_schedule(out, 2)

    @unittest.skip("failing in old lazy")
    def test_push_permute_through_reshape(self):
        a = Tensor.empty(16, 16)
        b = Tensor.empty(16, 16)
        c = (a + b).reshape(4, 4, 4, 4).permute(2, 3, 0, 1).contiguous()
        check_schedule(c, 1)

    @unittest.skip("failing in old lazy")
    def test_push_permute_through_reshape_alt(self):
        a = Tensor.empty(4, 4, 4, 4)
        b = Tensor.empty(4, 4, 4, 4)
        c = (a + b).reshape(16, 16).permute(1, 0).contiguous()
        check_schedule(c, 1)

    def test_no_binop_rerun(self):
        a = Tensor.empty(16)
        b = Tensor.empty(16)
        c = a + b
        d = (a + b).reshape(16, 1)
        check_schedule(d, 0, [c])

    def test_multi_permute_should_collapse(self):
        a = Tensor.empty(4, 4, 4, 4)
        b = Tensor.empty(16)
        c = (
            a.sum((0, 1))
            .cast(dtypes.float16)
            .permute(1, 0)
            .reshape(4, 4, 1)
            .permute(1, 0, 2)
            .reshape(16)
            + b
        )
        check_schedule(c, 1)

    @unittest.skip("failing in old lazy")
    def test_fancy_reshape_fusion(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = a + b
        d = a.reshape(10, 1) + b.reshape(10, 1)
        out = c.sum() + d.sum()
        check_schedule(out, 1)

    # NOTE: for this to pass, LazyViews must be children of LazyBuffers so the (a+b) runs first
    @unittest.skip("not real world")
    def test_children_dont_push(self):
        a = Tensor.empty(10, 10, 1)
        b = Tensor.empty(10, 10, 1)
        d = (a + b).expand(10, 10, 10)
        e = (a + b).permute(2, 1, 0)
        f = d + e
        check_schedule(f, 2)

    def test_dont_fuse_binops_with_children(self):
        a = Tensor.empty(10)
        b = Tensor.empty(10)
        c = Tensor.empty(10)
        keep_me = a + b
        # give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse);
        # e must stay referenced for the rest of the test.
        e = keep_me.sum()  # noqa: F841
        d = keep_me + c
        check_schedule(d, 2)
        check_schedule(keep_me, 0, [d])

    @unittest.skip("failing in old lazy")
    def test_permute_breaks_fusion(self):
        a = Tensor.empty(10, 10, 10)
        b = Tensor.empty(10, 10)
        c = (a.sum(axis=2) + b).permute(1, 0)
        d = c.permute(1, 0)
        check_schedule(d, 1)

    def test_some_permute_fusion(self):
        a = Tensor.empty(8192, 16)
        b = Tensor.empty(1, 16)
        d = a.T + b.expand(8192, 16).T
        c = a + b.expand(8192, 16)
        e = d.T
        check_schedule(c, 1)
        check_schedule(e, 1)

    # this is the failing case in openpilot...it's very simple like this
    @unittest.skip("failing in old lazy")
    def test_image_conv_fusion(self):
        from tinygrad.features.image import image_conv2d

        w1 = Tensor.empty(16, 16, 1, 1)
        b1 = Tensor.empty(16)
        w2 = Tensor.empty(16, 16, 1, 1)
        b2 = Tensor.empty(16)
        w3 = Tensor.empty(16, 16, 1, 1)
        b3 = Tensor.empty(16)

        x = Tensor.empty(1, 16, 32, 32)
        x = base = image_conv2d(x, w1, b1)
        x = image_conv2d(x, w2, b2) + base
        x = image_conv2d(x, w3, b3)

        # NOOP, 3 convs, contiguous
        check_schedule(x, 5)

    def test_image_conv_fusion_minimal(self):
        b1 = Tensor.empty(16)
        b2 = Tensor.empty(16)

        def p(x):
            return (
                x.permute(1, 0)
                .contiguous()
                .reshape(32, 16, 1)
                .expand(32, 16, 16)
                .sum(axis=2)
                .permute(1, 0)
            )

        x = Tensor.empty(16, 32)
        x = base = p(x) + b1.reshape(16, 1)
        x = p(x)
        x = x + b2.reshape(16, 1)
        x = x + base
        del base
        x = p(x)
        check_schedule(x, 4)

    def test_image_conv_fusion_more_minimal(self):
        b1 = Tensor.empty(16)

        def p(x):
            return (
                x.permute(1, 0)
                .contiguous()
                .reshape(32, 16, 1)
                .expand(32, 16, 16)
                .sum(axis=2)
                .permute(1, 0)
            )

        x = Tensor.empty(16, 32)
        x = base = p(x) + b1.reshape(16, 1)
        x = p(x)
        del base
        check_schedule(x, 3)

    def test_resnet_block(self):
        from extra.models.resnet import BasicBlock

        # Build and schedule the block with Tensor.training forced off, then restore
        # the previous class-level flag so this test does not leak global state
        # into tests that run after it.
        prev_training = Tensor.training
        Tensor.training = False
        try:
            bb = BasicBlock(64, 64)
            x = Tensor.empty(1, 64, 32, 32)
            out = bb(x)
            check_schedule(out, 4)
        finally:
            Tensor.training = prev_training

    def test_contiguous_while_contiguous(self):
        # LoadOps are counted here (filter_loadops=False)
        x = Tensor.empty(1, 64, 32, 32)
        out = x.contiguous()
        check_schedule(out, 1, filter_loadops=False)

    def test_contiguous_while_not_contiguous(self):
        # LoadOps are counted here (filter_loadops=False)
        x = Tensor.empty(1, 64, 32, 32)
        out = x.permute(0, 2, 3, 1).contiguous()
        check_schedule(out, 2, filter_loadops=False)

    def test_double_from(self):
        x = Tensor([1, 2, 3, 4])
        out = x.to("cpu")
        check_schedule(out, 0, filter_loadops=False)

    def test_pow_const_tensor(self):
        x = Tensor([1, 2, 3, 4])
        out = x ** Tensor(2)
        check_schedule(out, 1)

    def test_zero_size(self):
        # a zero-element tensor: nothing is scheduled, not even LoadOps
        x = Tensor.rand(2, 3, 0)
        out = x + 1
        check_schedule(out, 0, filter_loadops=False)
# Entry point for running this file directly; verbosity=2 lists each test by name.
if __name__ == "__main__":
    unittest.main(module="__main__", verbosity=2)