3x3 winograd convs (#1675)
* winograd * simplify local groups code * comment * respects self.opts.has_local * always simplify ones * make mypy happy * move reshape, WINO flag * wino flag, simple forward backward test for wino * extra wino test * merge oops * comments * axis_needs_valid -> axis_is_masked * don't delete needs_valid (it's unused though) * make linter happy * make linter happy * smaller test * change number * make wino tests very small
pull/1744/head
parent
c8025c319c
commit
3151d91f6e
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
import pytest
|
||||
|
||||
pytestmark = [pytest.mark.exclude_cuda]
|
||||
|
@ -64,6 +64,19 @@ class TestConv(unittest.TestCase):
|
|||
np.testing.assert_allclose(r2.numpy(), np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), atol=1e-5)
|
||||
Tensor.no_grad = False
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "TORCH", "Takes too long to compile for Compiled backends")
def test_two_overlapping_binops_no_rerun_wino(self):
  # Runs a 3x3 conv2d with Tensor.wino enabled and checks that relu() and elu()
  # taken from the same conv output both match their numpy equivalents.
  #
  # Tensor.no_grad and Tensor.wino are class-level flags shared by every test, so
  # they are saved up front and restored in a finally block: a failing assertion
  # must not leak them into later tests, and a wino value that was already
  # enabled (e.g. through the WINO environment variable) must not be forced off.
  prev_no_grad, prev_wino = Tensor.no_grad, Tensor.wino
  Tensor.no_grad = True
  Tensor.wino = True
  try:
    x = Tensor.randn(1,4,16,16)
    w = Tensor.randn(6,4,3,3)
    out = x.conv2d(w, padding=(1,1))
    r1, r2 = out.relu(), out.elu()
    np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0))
    np.testing.assert_allclose(r2.numpy(), np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), atol=1e-5)
  finally:
    Tensor.wino = prev_wino
    Tensor.no_grad = prev_no_grad
|
||||
|
||||
def test_first_three(self):
|
||||
Tensor.no_grad = True
|
||||
x = Tensor.rand(1,12,128,256)
|
||||
|
|
|
@ -129,6 +129,43 @@ class TestNN(unittest.TestCase):
|
|||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "TORCH", "Takes too long to compile for Compiled backends")
def test_conv2d_winograd(self):
  # Compares a 3x3 Conv2d evaluated with Tensor.wino enabled against
  # torch.nn.Conv2d: the forward output and the gradients with respect to the
  # weight, the bias and the input.
  BS, C1, H, W = 2, 8, 16, 16
  C2, K, S, P = 8, 3, 1, 1

  # Tensor.wino is a class-level flag shared by every test. Remember its value
  # and restore it in a finally block so a failing assertion cannot leak it and
  # a value enabled beforehand (e.g. via the WINO environment variable) survives.
  prev_wino = Tensor.wino
  Tensor.wino = True
  try:
    # create in tinygrad
    layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
    layer.weight.requires_grad = True
    layer.bias.requires_grad = True

    # create in torch, reusing the tinygrad parameters so both layers compute the same function
    torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
    torch_layer.weight = torch.nn.Parameter(torch.tensor(layer.weight.numpy(), dtype=torch.float32))
    torch_layer.bias = torch.nn.Parameter(torch.tensor(layer.bias.numpy(), dtype=torch.float32))

    # test the forward pass
    x = Tensor.uniform(BS, C1, H, W, requires_grad=True)
    z = layer(x)
    torch_x = torch.tensor(x.numpy(), requires_grad=True)
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

    # test the backward pass through mean()
    m = z.mean()
    m.backward()
    gw = layer.weight.grad.realize()
    gb = layer.bias.grad.realize()
    gx = x.grad.realize()

    torch_z.mean().backward()
    np.testing.assert_allclose(gw.numpy(), torch_layer.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(gb.numpy(), torch_layer.bias.grad.numpy(), atol=5e-4, rtol=1e-5)
    np.testing.assert_allclose(gx.numpy(), torch_x.grad.numpy(), atol=5e-4, rtol=1e-5)
  finally:
    Tensor.wino = prev_wino
|
||||
|
||||
@unittest.skipIf(getenv("CI", "") != "" and (WINDOWS or Device.DEFAULT == "WEBGPU"), "runs out of memory in CI")
|
||||
def test_conv_transpose1d(self):
|
||||
BS, C1, W = 4, 16, 224
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Tuple, List, cast
|
||||
from typing import Tuple, List, cast, Optional
|
||||
import itertools, math, os
|
||||
from tinygrad.helpers import DEBUG, prod, getenv, ImageDType, dtypes
|
||||
from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp
|
||||
|
@ -323,6 +323,20 @@ class OptimizedKernel(Kernel):
|
|||
|
||||
# **** below this line need to be optional and benchmarked ****
|
||||
|
||||
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
|
||||
# this can be made much smarter
|
||||
to_upcast = []
|
||||
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
|
||||
for axis in range(self.first_reduce):
|
||||
# we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
|
||||
if self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and prod(self.full_shape[self.shape_len - self.upcasted:]) * self.full_shape[axis] <= 7 * 7:
|
||||
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
|
||||
to_upcast.append(axis)
|
||||
for axis in to_upcast[::-1]:
|
||||
self.shift_to(axis, amount=self.full_shape[axis])
|
||||
self.upcast()
|
||||
self.simplify_ones()
|
||||
|
||||
# potentially do more upcasts of non reduce axes based on a heuristic
|
||||
upcasted_axis = set()
|
||||
while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
|
||||
|
@ -364,14 +378,14 @@ class OptimizedKernel(Kernel):
|
|||
# **** local groups ****
|
||||
|
||||
if self.opts.has_local:
|
||||
for axis in range(self.first_reduce - self.local_dims - 1, -1, -1):
|
||||
local_size = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce])
|
||||
if self.full_shape[axis] == 1: continue
|
||||
last_try = self.local_dims == 0 and axis == 0
|
||||
if any(st.views[-1].strides[axis] == 0 for st in self.sts) or last_try:
|
||||
for sz in [x for x in (([32] if last_try else []) + [16,8,4,3]) if self.full_shape[axis] % x == 0 and local_size*x <= 128]:
|
||||
self.shift_to(axis, sz, insert_before=self.first_reduce-self.local_dims)
|
||||
self.local_dims += 1
|
||||
break
|
||||
if self.local_dims >= 3: break
|
||||
# prioritize making expand axes local
|
||||
local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
|
||||
to_local: List[Tuple[int, int]] = []
|
||||
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
|
||||
local_size = prod(sz for _, sz in to_local)
|
||||
local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None)
|
||||
if local_sz is not None: to_local.append((axis, local_sz))
|
||||
for axis, local_sz in sorted(to_local[:3]):
|
||||
self.shift_to(axis, local_sz, insert_before=self.first_reduce)
|
||||
self.local_dims += 1
|
||||
self.simplify_ones()
|
||||
|
|
|
@ -9,6 +9,10 @@ class Contiguous(Function):
|
|||
def forward(self, x:LazyBuffer) -> LazyBuffer:
  # Forward pass: hand back the result of x.contiguous().
  contiguous_x = x.contiguous()
  return contiguous_x
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
  # Backward pass: the incoming gradient is returned untouched.
  return grad_output
|
||||
|
||||
class ContiguousBackward(Function):
  # Counterpart of Contiguous with the roles swapped: forward returns its input
  # unchanged, backward returns grad_output.contiguous().
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    return x

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    contiguous_grad = grad_output.contiguous()
    return contiguous_grad
|
||||
|
||||
class Cast(Function):
|
||||
def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
|
||||
self.input_dtype, self.bitcast = x.dtype, bitcast
|
||||
|
|
|
@ -197,6 +197,10 @@ class ShapeTracker:
|
|||
def needs_valid(self) -> bool:
  # True as soon as one entry of self.views has a mask that is not None.
  for view in self.views:
    if view.mask is not None:
      return True
  return False
|
||||
|
||||
def axis_is_masked(self, axis) -> bool:
  # Reports whether the variable named f'idx{axis}' occurs among the variables of
  # the second value returned by self.expr_idxs() (presumably the validity
  # expression -- confirm against expr_idxs).
  _, valid_expr = self.expr_idxs()
  wanted = f'idx{axis}'
  var_names = [var.expr for var in valid_expr.vars()]
  return wanted in var_names
|
||||
|
||||
# *** under this line are the movement ops ***
|
||||
|
||||
def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...], mask=None):
|
||||
|
|
|
@ -483,6 +483,7 @@ class Tensor:
|
|||
padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
|
||||
return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)
|
||||
|
||||
wino = int(getenv("WINO", "0"))
|
||||
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
|
||||
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
||||
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
|
||||
|
@ -492,11 +493,39 @@ class Tensor:
|
|||
# conv2d is a pooling op (with padding)
|
||||
x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
|
||||
rcout, oyx = cout//groups, x.shape[2:-len(HW)]
|
||||
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
|
||||
if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not Tensor.wino:
|
||||
# normal conv
|
||||
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
|
||||
|
||||
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
|
||||
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
|
||||
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
|
||||
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
|
||||
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
|
||||
|
||||
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
|
||||
def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mat[i][j] * t[j] for j in range(len(mat[i]))), dim=dim+1) for i in range(len(mat))])
|
||||
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
|
||||
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
|
||||
winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
|
||||
winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order almost doubles compilation time
|
||||
|
||||
# todo: stride == dilation
|
||||
# use padding to round up to 4x4 output tiles
|
||||
d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # (bs, cin_, tyx, HWI)
|
||||
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))).contiguous_backward() # move HW to the front: # (HWI, bs, cin_, tyx)
|
||||
tyx = d.shape[-len(HWI):] # dim of tiling
|
||||
|
||||
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
|
||||
|
||||
# compute 6x6 winograd tiles: GgGt, BtdB
|
||||
gfactors = apply_matrix(winograd_G, g).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
|
||||
dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx) # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
|
||||
|
||||
ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW))) # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
|
||||
|
||||
ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]]) # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
|
||||
ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx])) # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
|
||||
|
||||
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
|
||||
|
||||
def dot(self, w:Tensor) -> Tensor:
|
||||
n1, n2 = len(self.shape), len(w.shape)
|
||||
|
@ -511,6 +540,7 @@ class Tensor:
|
|||
# ***** mlops (unary) *****
|
||||
|
||||
def contiguous(self):
  # Apply mlops.Contiguous to this tensor.
  return mlops.Contiguous.apply(self)
|
||||
def contiguous_backward(self):
  # Apply mlops.ContiguousBackward to this tensor.
  return mlops.ContiguousBackward.apply(self)
|
||||
def log(self):
  # Apply mlops.Log to this tensor.
  return mlops.Log.apply(self)
|
||||
def log2(self):
  # mlops.Log of this tensor, divided by math.log(2).
  logged = mlops.Log.apply(self)
  return logged / math.log(2)
|
||||
def exp(self):
  # Apply mlops.Exp to this tensor.
  return mlops.Exp.apply(self)
|
||||
|
|
Loading…
Reference in New Issue