1
0
Fork 0

3x3 winograd convs (#1675)

* winograd

* simplify local groups code

* comment

* respects self.opts.has_local

* always simplify ones

* make mypy happy

* move reshape, WINO flag

* wino flag, simple forward backward test for wino

* extra wino test

* merge oops

* comments

* axis_needs_valid -> axis_is_masked

* don't delete needs_valid (it's unused though)

* make linter happy

* make linter happy

* smaller test

* change number

* make wino tests very small
pull/1744/head
David Hou 2023-09-03 07:29:43 -07:00 committed by GitHub
parent c8025c319c
commit 3151d91f6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 118 additions and 16 deletions

View File

@ -1,6 +1,6 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.tensor import Tensor, Device
import pytest
pytestmark = [pytest.mark.exclude_cuda]
@ -64,6 +64,19 @@ class TestConv(unittest.TestCase):
np.testing.assert_allclose(r2.numpy(), np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), atol=1e-5)
Tensor.no_grad = False
@unittest.skipIf(Device.DEFAULT != "TORCH", "Takes too long to compile for Compiled backends")
def test_two_overlapping_binops_no_rerun_wino(self):
Tensor.no_grad = True
Tensor.wino = True
x = Tensor.randn(1,4,16,16)
w = Tensor.randn(6,4,3,3)
out = x.conv2d(w, padding=(1,1))
r1, r2 = out.relu(), out.elu()
np.testing.assert_allclose(r1.numpy(), np.maximum(out.numpy(), 0))
np.testing.assert_allclose(r2.numpy(), np.where(out.numpy() > 0, out.numpy(), (np.exp(out.numpy()) - 1)), atol=1e-5)
Tensor.wino = False
Tensor.no_grad = False
def test_first_three(self):
Tensor.no_grad = True
x = Tensor.rand(1,12,128,256)

View File

@ -129,6 +129,43 @@ class TestNN(unittest.TestCase):
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
@unittest.skipIf(Device.DEFAULT != "TORCH", "Takes too long to compile for Compiled backends")
def test_conv2d_winograd(self):
BS, C1, H, W = 2, 8, 16, 16
C2, K, S, P = 8, 3, 1, 1
Tensor.wino = True
# create in tinygrad
layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
layer.weight.requires_grad = True
layer.bias.requires_grad = True
# create in torch
torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
torch_layer.weight = torch.nn.Parameter(torch.tensor(layer.weight.numpy(), dtype=torch.float32))
torch_layer.bias = torch.nn.Parameter(torch.tensor(layer.bias.numpy(), dtype=torch.float32))
# test
x = Tensor.uniform(BS, C1, H, W, requires_grad=True)
z = layer(x)
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
m = z.mean()
m.backward()
gw = layer.weight.grad.realize()
gb = layer.bias.grad.realize()
gx = x.grad.realize()
torch_z.mean().backward()
np.testing.assert_allclose(gw.numpy(), torch_layer.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
np.testing.assert_allclose(gb.numpy(), torch_layer.bias.grad.numpy(), atol=5e-4, rtol=1e-5)
np.testing.assert_allclose(gx.numpy(), torch_x.grad.numpy(), atol=5e-4, rtol=1e-5)
Tensor.wino = False
@unittest.skipIf(getenv("CI", "") != "" and (WINDOWS or Device.DEFAULT == "WEBGPU"), "runs out of memory in CI")
def test_conv_transpose1d(self):
BS, C1, W = 4, 16, 224

View File

@ -1,4 +1,4 @@
from typing import Tuple, List, cast
from typing import Tuple, List, cast, Optional
import itertools, math, os
from tinygrad.helpers import DEBUG, prod, getenv, ImageDType, dtypes
from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp
@ -323,6 +323,20 @@ class OptimizedKernel(Kernel):
# **** below this line need to be optional and benchmarked ****
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
# this can be made much smarter
to_upcast = []
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
for axis in range(self.first_reduce):
# we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
if self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and prod(self.full_shape[self.shape_len - self.upcasted:]) * self.full_shape[axis] <= 7 * 7:
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
to_upcast.append(axis)
for axis in to_upcast[::-1]:
self.shift_to(axis, amount=self.full_shape[axis])
self.upcast()
self.simplify_ones()
# potentially do more upcasts of non reduce axes based on a heuristic
upcasted_axis = set()
while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
@ -364,14 +378,14 @@ class OptimizedKernel(Kernel):
# **** local groups ****
if self.opts.has_local:
for axis in range(self.first_reduce - self.local_dims - 1, -1, -1):
local_size = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce])
if self.full_shape[axis] == 1: continue
last_try = self.local_dims == 0 and axis == 0
if any(st.views[-1].strides[axis] == 0 for st in self.sts) or last_try:
for sz in [x for x in (([32] if last_try else []) + [16,8,4,3]) if self.full_shape[axis] % x == 0 and local_size*x <= 128]:
self.shift_to(axis, sz, insert_before=self.first_reduce-self.local_dims)
self.local_dims += 1
break
if self.local_dims >= 3: break
# prioritize making expand axes local
local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
to_local: List[Tuple[int, int]] = []
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
local_size = prod(sz for _, sz in to_local)
local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None)
if local_sz is not None: to_local.append((axis, local_sz))
for axis, local_sz in sorted(to_local[:3]):
self.shift_to(axis, local_sz, insert_before=self.first_reduce)
self.local_dims += 1
self.simplify_ones()

View File

@ -9,6 +9,10 @@ class Contiguous(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output
class ContiguousBackward(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer: return x
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()
class Cast(Function):
def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
self.input_dtype, self.bitcast = x.dtype, bitcast

View File

@ -197,6 +197,10 @@ class ShapeTracker:
def needs_valid(self) -> bool:
return any(v.mask is not None for v in self.views)
def axis_is_masked(self, axis) -> bool:
_, valid = self.expr_idxs()
return f'idx{axis}' in [v.expr for v in valid.vars()]
# *** under this line are the movement ops ***
def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...], mask=None):

View File

@ -483,6 +483,7 @@ class Tensor:
padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)
wino = int(getenv("WINO", "0"))
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
@ -492,11 +493,39 @@ class Tensor:
# conv2d is a pooling op (with padding)
x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
rcout, oyx = cout//groups, x.shape[2:-len(HW)]
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not Tensor.wino:
# normal conv
x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
# conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mat[i][j] * t[j] for j in range(len(mat[i]))), dim=dim+1) for i in range(len(mat))])
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order almost doubles compilation time
# todo: stride == dilation
# use padding to round up to 4x4 output tiles
d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # (bs, cin_, tyx, HWI)
d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))).contiguous_backward() # move HW to the front: # (HWI, bs, cin_, tyx)
tyx = d.shape[-len(HWI):] # dim of tiling
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
# compute 6x6 winograd tiles: GgGt, BtdB
gfactors = apply_matrix(winograd_G, g).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx) # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW))) # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]]) # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx])) # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
def dot(self, w:Tensor) -> Tensor:
n1, n2 = len(self.shape), len(w.shape)
@ -511,6 +540,7 @@ class Tensor:
# ***** mlops (unary) *****
def contiguous(self): return mlops.Contiguous.apply(self)
def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
def log(self): return mlops.Log.apply(self)
def log2(self): return mlops.Log.apply(self)/math.log(2)
def exp(self): return mlops.Exp.apply(self)